Skip to content
Snippets Groups Projects
Select Git revision
  • master
  • sub1
2 results

pf_vec.jl

Blame
  • pf_vec.jl 5.21 KiB
    using StatsBase, Plots, Distributions, StaticArrays, Base.Test, TimerOutputs
    
    """
        init_pf(x0, N, p0)

    Allocate the particle-filter state for `N` particles around `x0`.

    Returns `(x, xprev, w)`:
    - `xprev`: `N` particles `x0 .+ rand(p0)`, stored as `SVector`s of length
      `length(x0)` with element type `float(T)`, so an integer `x0` works too;
    - `x`: an uninitialised buffer with the same element type and length as `xprev`;
    - `w`: uniform log-weights `log(1/N)`.
    """
    function init_pf(x0::AbstractVector{T}, N, p0) where T
        D = length(x0)
        S = float(T)   # particles are x0 plus random draws, so they need a float eltype
        xprev = Vector{SVector{D,S}}([x0 .+ rand(p0) for k = 1:N])
        x = similar(xprev)
        w = fill(log(1/N), N)
        return x, xprev, w
    end
    
    """
        pf!(x, xprev, w, u, y, g, f)

    Advance the particle filter by one step, updating `x`, `xprev` and `w` in place.

    `shouldresample(w)` decides whether to resample. If it returns `true`, ancestor
    indices are drawn with `resample(w)`, the particles are propagated from those
    ancestors and the log-weights are reset to `log(1/N)`. Otherwise every particle
    is propagated from its own predecessor (indices `1:N`). Propagation is done by
    `f(x, xprev, u, indices)`. `g(w, y, x)` then adds the measurement term to `w`,
    and `logsumexp!` normalises it. Finally `x` is copied into `xprev` and `x` is
    returned.
    """
    function pf!(x, xprev, w, u, y, g, f)
        nparticles = length(x)
        resampling = shouldresample(w)
        ancestors = resampling ? resample(w) : (1:nparticles)
        f(x, xprev, u, ancestors)
        # After resampling, all surviving particles carry equal weight.
        resampling && fill!(w, log(1/nparticles))
        g(w, y, x)
        logsumexp!(w)
        copy!(xprev, x)
        return x
    end
    
    """
        shouldresample(w, threshold = 0.5)

    Decide whether particles with log-weights `w` should be resampled. Return `true`
    when the effective sample size `ESS = sum(exp, w)^2 / sum(exp(2w[i]))` falls
    below `threshold * length(w)`.

    The largest log-weight is factored out first, so `w` need not be normalised and
    large log-weights do not overflow. An empty `w` returns `false`.
    """
    function shouldresample(w, threshold = 0.5)
        isempty(w) && return false
        # The decision must depend on the weights. Resampling at random either adds
        # needless resampling noise or lets the weights degenerate onto a few particles.
        offset = maximum(w)
        s1 = sum(wi -> exp(wi - offset), w)
        s2 = sum(wi -> exp(2 * (wi - offset)), w)
        ess = s1^2 / s2
        return ess < threshold * length(w)
    end
    
    """
        weigthed_mean(x, w)

    Return the weighted mean `sum_i exp(w[i]) * x[i]` of the particles `x` (a vector
    of equally sized arrays) under the log-weights `w`.

    The weights are used as given, so they should already be normalised
    (`sum(exp, w) ≈ 1`, e.g. via `logsumexp!`). The result is a new Float64 array
    shaped like `x[1]`.

    Throws `ArgumentError` if `x` is empty and `DimensionMismatch` if `x` and `w`
    differ in length.
    """
    function weigthed_mean(x,w)
        isempty(x) && throw(ArgumentError("weigthed_mean: need at least one particle"))
        length(x) == length(w) || throw(DimensionMismatch(
            "weigthed_mean: $(length(x)) particles but $(length(w)) weights"))
        xh = zeros(size(x[1]))
        # Skipping bounds checks is safe: i ranges over x, and w was checked to match.
        @inbounds for i = eachindex(x)
            xh .+= x[i] .* exp(w[i])
        end
        return xh
    end
    
    @testset "weigthed_mean" begin
        x = [randn(3) for i = 1:10000]
        # Uniform normalised log-weights are written out directly. logsumexp! is
        # defined further down this file, so it does not exist yet when this
        # testset runs during a top-to-bottom include.
        w = fill(log(1/10000), 10000)
        # With uniform weights the weighted mean equals the plain sample mean. This
        # comparison is deterministic, unlike a bound on the mean of random data.
        @test weigthed_mean(x,w) ≈ sum(x)/length(x)
        # A small exact case with unequal weights.
        @test weigthed_mean([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]], log.([0.25, 0.75])) ≈ [2.5, 2.0, 1.5]
    end
    
    """
        logsumexp!(w)

    Normalise the log-weights `w` in place so that `sum(exp, w) ≈ 1`, by subtracting
    `log(sum(exp, w))` from every entry, and return `w`.

    The maximum entry is factored out before exponentiating, so large log-weights
    such as `[1000.0, 1000.0]` do not overflow. An empty `w` is returned unchanged.
    """
    function logsumexp!(w)
        isempty(w) && return w
        offset = maximum(w)
        normConstant = sum(wi -> exp(wi - offset), w)
        w .-= log(normConstant) + offset
        return w
    end
    
    @testset "logsumexp" begin
        # Compare the in-place, max-shifted version with the direct formula.
        v = randn(10)
        reference = v .- log(sum(exp, v))
        @test logsumexp!(v) ≈ reference
        # Equal log-weights normalise to log(1/10) each.
        @test logsumexp!(ones(10)) ≈ fill(log(1/10), 10)
    end
    
    """
        resample(w)

    Systematic resampling. Given log-weights `w`, return `length(w)` ancestor
    indices in `1:length(w)`, in non-decreasing order.

    A single uniform draw `u = rand()` places the grid points `(u + i - 1)/N`,
    `i = 1:N`, on the cumulative weight distribution. Up to floating-point rounding
    at bin edges, particle `k` is therefore picked either `floor(N*p_k)` or
    `ceil(N*p_k)` times, where `p_k = exp(w[k]) / sum(exp, w)`.

    `w` does not have to be normalised: the cumulative weights are rescaled by
    their total, and the largest log-weight is factored out before exponentiating.
    An empty `w` gives an empty result.
    """
    function resample(w)
        N = length(w)
        j = zeros(Int, N)
        N == 0 && return j
        offset = maximum(w)
        bins = cumsum(exp.(w .- offset))   # unnormalised cumulative weights
        total = bins[end]
        u = rand()
        b = 1
        for i = 1:N
            s = (u + (i - 1)) / N * total
            # Advance to the first bin whose upper edge exceeds s. b is capped at N,
            # so rounding in the cumulative sum (bins[end] falling just below s)
            # cannot leave j[i] unassigned or push it out of range.
            while b < N && s >= bins[b]
                b += 1
            end
            j[i] = b
        end
        return j
    end
    
    @testset "resample" begin
        w = logsumexp!(ones(10))
        # Uniform weights: systematic resampling keeps every particle exactly once.
        @test resample(w) ≈ 1:10
        # Non-uniform weights: particle k must get floor(N*p_k) or ceil(N*p_k) copies,
        # whatever the random offset. A fixed lower bound on the index sum (e.g.
        # sum(j) >= 56) is not safe here: the sum is 55 whenever resample's uniform
        # draw falls below about 0.004.
        wb = logsumexp!([1.,1,1,2,2,2,3,3,3])
        jb = resample(wb)
        expected = length(wb) .* exp.(wb)
        counts = [count(k -> k == i, jb) for i = 1:length(wb)]
        @test all(floor.(expected) .<= counts)
        @test all(counts .<= ceil.(expected))
        @test length(resample(w)) == length(w)
        for i = 1:10000
            j = randn(100) |> logsumexp! |> resample
            @test maximum(j) <= 100
            @test minimum(j) >= 1
        end
    end
    
    # Model dimensions: state (n), input (m) and measurement (p). They are const
    # because run_test reads them as globals, and non-const globals are untyped, so
    # every such access would be dynamically dispatched.
    const n = 2
    const m = 2
    const p = 2

    const pg = Distributions.MvNormal(p,1.0)   # measurement noise (used by g and the simulator)
    const pf = Distributions.MvNormal(n,1.0)   # process noise (used by both f methods)
    const p0 = Distributions.MvNormal(n,2.0)   # initial-state spread

    # System matrix A = V*D/V with D = diagm(linspace(0.5,0.99,n)), so its eigenvalues
    # are exactly those diagonal values. V lives in a `let` so no stray global is left
    # behind (the previous global `T` also shadowed run_test's name for the horizon).
    const A = let V = randn(n,n)
        SMatrix{n,n}(V*diagm(linspace(0.5,0.99,n))/V)
    end
    const B = @SMatrix randn(n,m)
    const C = @SMatrix randn(p,n)
    
    # Propagate particles through the linear dynamics
    #     x[i] = A*xp[j[i]] + B*u + v,    v ~ pf drawn separately for each particle,
    # where j[i] is the ancestor index of new particle i (1:N when not resampling).
    # Writes into x and returns it.
    function f(x,xp,u,j)
        drive = B*u                 # identical for every particle, so computed once
        @inbounds for i in eachindex(x)
            ancestor = xp[j[i]]
            x[i] = A*ancestor + drive + rand(pf)
        end
        return x
    end
    # Step each state in x forward in place:
    #     x[k] = A*x[k] + B*u + v,    v ~ pf drawn separately for each entry.
    # Used by run_test to simulate the true trajectory (a one-element vector). Returns x.
    function f(x,u)
        drive = B*u
        @inbounds for k in eachindex(x)
            current = x[k]
            x[k] = A*current .+ drive .+ rand(pf)
        end
        return x
    end
    
    """
        g(w, y, x)

    Measurement update: add the log-likelihood `logpdf(pg, y - C*x[i])` of the
    measurement `y` under particle `x[i]` to each log-weight `w[i]`, then floor each
    weight at -1000. Mutates and returns `w`.

    Throws `DimensionMismatch` if `w` and `x` differ in length.
    """
    function g(w,y,x)
        # The loop skips bounds checks, so first make sure every weight has a particle.
        length(w) == length(x) || throw(DimensionMismatch(
            "g: $(length(w)) weights but $(length(x)) particles"))
        @inbounds for i = eachindex(w)
            w[i] += logpdf(pg, Vector(y-C*x[i]))
            # NOTE(review): the -1000 floor presumably keeps particles with a vanishing
            # likelihood away from -Inf; confirm the intent. Comparing against a bound
            # of w's own element type keeps the expression type-stable.
            w[i] = max(w[i], oftype(w[i], -1000))
        end
        return w
    end
    
    # One Monte Carlo trial: simulate a trajectory of length T from the model, filter it
    # with N particles and return the RMS estimation error. `x` (n×T) and `y` (p×T) are
    # scratch buffers for the true states and measurements, and are overwritten.
    function _single_run_rmse!(x, y, N, T)
        u = randn(m,T)
        x[:,1] = rand(p0)
        y[:,1] = C*x[:,1] + rand(pg)
        # NOTE(review): the particles are centred on the true initial state x[:,1];
        # confirm this is intended rather than centring them on the prior mean.
        xh,xhprev,w = init_pf(x[:,1], N, p0)
        sqerr = 0.0
        @timeit "pf" @inbounds for t = 1:T-1
            x[:,t+1] = f([x[:,t]],u[:,t])[]
            y[:,t+1] = C*x[:,t+1]  + rand(pg)
            # pf! propagates the particles with u[:,t], so afterwards they describe the
            # state at t+1. They must be weighted with y[:,t+1] and scored against x[:,t+1].
            pf!(xh, xhprev, w, u[:,t], y[:,t+1], g, f)
            sqerr += sum(abs2, x[:,t+1]-weigthed_mean(xh,w))
        end
        return √(sqerr/(T-1))   # sqerr accumulates T-1 filtering steps
    end

    """
        run_test()

    Monte Carlo study of the particle filter's accuracy. For every horizon `T` and
    particle count `N` in the grids below, average the RMS error of independent runs.
    The number of runs is chosen so that runs × N × T is approximately constant
    (equal to the largest N times the largest T). Prints progress and the total
    number of propagated particles. Returns the
    `length(particle_count) × length(time_steps)` matrix of mean RMS errors.
    """
    function run_test()
        particle_count = [5, 10, 20, 50, 100, 200, 500, 1000, 10_000]
        time_steps = [20, 50, 100, 200]
        RMSE = zeros(length(particle_count),length(time_steps)) # Store the RMS errors
        propagated_particles = 0
        budget = maximum(particle_count)*maximum(time_steps)
        for (Ti,T) in enumerate(time_steps)
            for (Ni, N) in enumerate(particle_count)
                # Integer division keeps the run count an Int, so 1:montecarlo_runs is an
                # integer range and propagated_particles stays an integer count.
                montecarlo_runs = max(1, div(budget, T*N))
                x = zeros(n,T)
                y = zeros(p,T)
                E = sum(1:montecarlo_runs) do mc_run
                    _single_run_rmse!(x, y, N, T)
                end
                RMSE[Ni,Ti] = E/montecarlo_runs
                propagated_particles += montecarlo_runs*N*T
                @show N
            end # N
            @show T
        end # T
        println("Propagated $propagated_particles particles")
        return RMSE
    end
    
    # Clear TimerOutputs' default timer (used by the "pf" section in run_test), then run
    # the Monte Carlo study. @time reports the wall-clock time and allocations of the run.
    reset_timer!()
    RMSE = @time run_test()
    
    """
        plotting(RMSE; time_steps, particle_count)

    Plot the RMS errors against the number of particles on a log10 x axis, one line
    per time horizon. Row `i` of `RMSE` corresponds to `particle_count[i]` and column
    `k` to `time_steps[k]`. The default grids are the ones used in `run_test`.
    Throws `DimensionMismatch` if `RMSE` does not match the grids.
    """
    function plotting(RMSE;
                      time_steps     = [20, 50, 100, 200],
                      particle_count = [5, 10, 20, 50, 100, 200, 500, 1000, 10_000])
        size(RMSE) == (length(particle_count), length(time_steps)) || throw(DimensionMismatch(
            "RMSE has size $(size(RMSE)), expected ($(length(particle_count)), $(length(time_steps)))"))
        nT  = length(time_steps)
        leg = reshape(["$(time_steps[i]) time steps" for i = 1:nT], 1,:)
        plot(particle_count,RMSE,xscale=:log10, ylabel="RMS errors", xlabel=" Number of particles", lab=leg)
    end
    
    # Draw the RMS-error curves for the study computed above.
    RMSE |> plotting
    
    # Debug helper: scatter coordinates 1 and 2 of every particle in `x` on a fixed
    # [-15, 15]² window, and mark `y` in red and `xt` in green. `w` is not used.
    # Sleeps 0.2 s after drawing.
    # Assumes each particle is an isbits vector of Float64 (e.g. an SVector), which the
    # reinterpret into a dim × N matrix requires — TODO confirm for other callers.
    function plot_particles(x,w,y,xt)
        dim = length(x[1])
        coords = reinterpret(Float64, x, (dim, length(x)))
        scatter(coords[1,:], coords[2,:], title="Particles", reuse=true, xlims=(-15,15), ylims=(-15,15), grid=false, size=(1000,1000))
        scatter!([y[1]], [y[2]], m = (:red, 5))
        scatter!([xt[1]], [xt[2]], m = (:green, 5))
        sleep(0.2)
    end
    
    # Debug helper: plot `xt'` and `y'` (red) in a 2×1 layout, then overlay particle
    # coordinate 1 on subplot 1 and coordinate 2 on subplot 2 at time `t`. `w` is not
    # used. Sleeps 0.2 s after drawing. Presumably `xt`/`y` are the true-state and
    # measurement matrices from run_test — verify against callers.
    function plot_particles2(x,w,y,xt,t)
        coords = reinterpret(Float64, x, (length(x[1]), length(x)))
        tcol = t*ones(size(coords,2))   # every particle is drawn at the same time t
        plot(xt', title="Particles", reuse=true,  grid=false, size=(1000,1000), layout=(2,1), ylims=(-15,15))
        plot!(y', l = (:red, 2))
        scatter!(tcol, coords[1,:], subplot=1)
        scatter!(tcol, coords[2,:], subplot=2)
        sleep(0.2)
    end