From 85d4e08386012f3fd8a35431680d7d3a16ad5ea0 Mon Sep 17 00:00:00 2001
From: Fredrik Bagge Carlson <cont-frb@ulund.org>
Date: Fri, 12 Jan 2018 06:26:03 +0100
Subject: [PATCH] bayesian dropout nn prior

---
 flux/nn_prior/nn_prior.jl          |   6 +-
 flux/nn_prior/nn_prior_bayesian.jl | 199 +++++++++++++++++++++++++++++
 flux/nn_prior/utilities.jl         |  20 +--
 3 files changed, 214 insertions(+), 11 deletions(-)
 create mode 100644 flux/nn_prior/nn_prior_bayesian.jl

diff --git a/flux/nn_prior/nn_prior.jl b/flux/nn_prior/nn_prior.jl
index 6107bd7..22567be 100644
--- a/flux/nn_prior/nn_prior.jl
+++ b/flux/nn_prior/nn_prior.jl
@@ -6,7 +6,7 @@ using OrdinaryDiffEq, LTVModels
     using Flux: back!, truncate!, treelike, train!, mse, testmode!, combine, params, jacobian
     using Flux.Optimise: Param, optimiser, RMSProp, expdecay
     wdecay         = 0.0
-    num_params     = 20
+    num_params     = 50
     stepsize       = 0.05
     n_bootstrap    = 1
     num_montecarlo = 20
@@ -119,11 +119,13 @@ end
 # Produce figures
 # pyplot(reuse=false, show=false)
 # figs1 = fit_system(1, sys, SystemD, 4, doplot=false);     #savefigures(figs1..., 1)
-# figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=false); #savefigures(figs2..., 2)
+figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=true); #savefigures(figs2..., 2)
 # figs3 = fit_system(1, sys, VelSystemD, 2, doplot=false);  #savefigures(figs3..., 3)
 # figs4 = fit_system(1, sys, AffineSystem, doplot=true);  #savefigures(figs4..., 4)
 # error()
 
+##
+
 # # Monte-Carlo evaluation 1
 # res = map(1:num_montecarlo) do it
 #     r1 = @spawn fit_system(it, sys, System)
diff --git a/flux/nn_prior/nn_prior_bayesian.jl b/flux/nn_prior/nn_prior_bayesian.jl
new file mode 100644
index 0000000..4189531
--- /dev/null
+++ b/flux/nn_prior/nn_prior_bayesian.jl
@@ -0,0 +1,199 @@
+cd(@__DIR__)
+# import Iterators
+using OrdinaryDiffEq, LTVModels
+@everywhere begin
+    using Flux, ValueHistories, IterTools, MLDataUtils, OrdinaryDiffEq, Parameters
+    using Flux: back!, truncate!, treelike, train!, mse, testmode!, combine, params, jacobian
+    using Flux.Optimise: Param, optimiser, RMSProp, expdecay
+    wdecay         = 0.0
+    num_params     = 50
+    stepsize       = 0.05
+    n_bootstrap    = 1
+    num_montecarlo = 5
+end
+
+inoffice() = gethostname() ∈ ["billman", "Battostation"]
+inoffice() && !isdefined(:time_derivative) && include(Pkg.dir("DynamicMovementPrimitives","src","two_link.jl"))
+if !isdefined(:AbstractEnsembleSystem)
+    inoffice() || @everywhere include("/var/tmp/fredrikb/v0.6/DynamicMovementPrimitives/src/two_link.jl")
+    @everywhere include("utilities.jl")
+    @everywhere include("two_link_sys.jl")
+    @everywhere include("linear_sys.jl")
+end
+
+sys = TwoLinkSys(N=1000, h=0.02, σ0 = 0.1)#; warn("Added noise")
+# sys = LinearSys(1,N=100, h=0.2, σ0 = 1, n = 10, ns = 10)
+
+@everywhere function fit_model(opt, loss, m, x, y, u, xv, yv, uv, sys, modeltype;
+    iters                 = 2000,
+    doplot                = true,
+    batch_size            = 50)
+
+    trace   = History(Float64)
+    vtrace  = History(Float64)
+    batcher = batchview(shuffleobs((x,y)), batch_size)
+    dataset = ncycle(batcher, iters)
+    iter    = 0
+    function evalcallback()
+        iter += 1
+        testmode!(m)
+        l = loss(x, y).data[1]
+        push!(trace,iter,l)
+        push!(vtrace, iter, loss(xv,yv).data[1])
+        if iter % 50 == 0
+            if doplot
+                println("Iter: $iter, Loss: ", l)
+                plot(trace,reuse=true,show=false, lab="Train", layout=3, subplot=1, size=(1400,1000))
+                plot!(vtrace,show=false, lab="Validation", subplot=1, yscale=:log10)
+                plot!(y', subplot=2:3, lab="y")
+                plot!(m(x).data', l=:dash, subplot=2:3, show=true, lab="Pred")
+            end
+        end
+        testmode!(m, false)
+        truncate!(m)
+    end
+    train!(loss, dataset, opt, cb = evalcallback)
+    results           = Dict{Symbol, Any}()
+    @pack results = m, trace, vtrace, modeltype, sys
+    results
+end
+
+
+
+@everywhere function fit_system(seed, sys, modeltype, jacprop = 0; doplot=false)
+    x,y,u    = generate_data(sys, modeltype, seed)
+    iters = modeltype==System ? 2000 : 1000
+    if jacprop > 0
+        iters = round(Int, iters / sqrt(jacprop+1))
+        n,m = sys.ns,sys.n
+        R1,R2 = 0.01eye(n*(n+m)), 10eye(n) # Fast
+        P0 = 10000R1
+        model = LTVModels.fit_model(LTVModels.KalmanModel, x,u,R1,R2,P0, extend=true)
+        xt,ut,yt = [x;u], u, y
+        for i = 1:jacprop
+            xa = x .+ std(x,2)/10 .* randn(size(x))
+            ua = u .+ std(u,2)/10 .* randn(size(u))
+            ya = LTVModels.predict(model, xa,ua)
+            xt = [xt [xa;ua]]
+            ut = [ut ua]
+            yt = modeltype ∈ [VelSystem, VelSystemD] ? [yt ya[3:4,:]] : [yt ya]
+        end
+        x  = [x;u]
+    else
+        x        = [x;u]
+        xt,ut,yt = x,u,y
+    end
+
+    xv,yv,uv = generate_data(sys, modeltype, seed+1)
+    xv       = [xv;uv]
+
+    activation = relu
+
+    m          = modeltype(sys, num_params, activation)
+    opt        = [ADAM(params(m), stepsize, decay=0.005); [expdecay(Param(p), wdecay) for p in params(m) if p isa AbstractMatrix]]
+    results    = fit_model(opt, loss(m), m, xt, yt, ut, xv, yv, uv, sys,modeltype, iters=iters ÷ jacprop, doplot=doplot, batch_size = size(yt,2) ÷ jacprop)
+    @pack results = x, u, y, xv, uv, yv
+    results = [results]
+    if doplot
+        Jm, Js, Jtrue = all_jacobians(results)
+        fig_jac = plot_jacobians(Jm, Js, Jtrue); #gui();
+        fig_time = plotresults(results); #gui();
+        fig_eig = plot_eigvals(results); #gui();
+        return fig_jac, fig_time, fig_eig
+    end
+    results
+end
+
+
+function savefigures(f1,f2,f3,n)
+    savefig(f1,"/local/home/fredrikb/papers/nn_prior/figs/jacs$(n).pdf")
+    savefig(f2,"/local/home/fredrikb/papers/nn_prior/figs/timeseries$(n).pdf")
+    savefig(f3,"/local/home/fredrikb/papers/nn_prior/figs/eigvals$(n).pdf")
+end
+
+
+# Produce figures
+# pyplot(reuse=false, show=false)
+# figs1 = fit_system(1, sys, SystemD, 4, doplot=false);     #savefigures(figs1..., 1)
+figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=true); #savefigures(figs2..., 2)
+# figs3 = fit_system(1, sys, VelSystemD, 2, doplot=false);  #savefigures(figs3..., 3)
+# figs4 = fit_system(1, sys, AffineSystem, doplot=true);  #savefigures(figs4..., 4)
+# error()
+
+##
+
+res = map(1:num_montecarlo) do it
+    r1 = @spawn fit_system(it, sys, System, 2)
+    r2 = @spawn fit_system(it, sys, DiffSystem, 2)
+    r3 = @spawn fit_system(it, sys, SystemD, 2)
+    r4 = @spawn fit_system(it, sys, DiffSystemD, 2)
+
+
+    println("Done with montecarlo run $it")
+    r1,r2,r3,r4
+end
+res = [(fetch.(rs)...) for rs in res]
+
+open(file->serialize(file, res), "res","w") # Save
+error()
+
+##
+sleep(60*3)
+##
+using StatPlots
+res = open(file->deserialize(file), "res") # Load
+sys = res[1][1][1][:sys]
+nr = length(res[1]) ÷ 2
+labelvec = ["f" "g"]
+infostring = @sprintf("Num hidden: %d, sigma: %2.2f, Montecarlo: %d", num_params, sys.σ0, num_montecarlo)
+
+# pgfplots(size=(600,300), show=true)
+
+pred  = hcat([eval_pred.(get_res(res,i), true) for i ∈ 1:nr]...)
+sim   = hcat([eval_sim.(get_res(res,i), true) for i ∈ 1:nr]...)
+predp = hcat([eval_pred.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
+simp  = hcat([eval_sim.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
+vio1 = violin(log10.(pred), side=:left, lab=["Standard" "" ""],xticks=(1:nr,labelvec), ylabel="log(Prediction RMS)", reuse=false, c=:red)
+violin!((1:nr)',log10.(predp), side=:right, lab=["Bayesian" "" ""],xticks=(1:nr,labelvec), c=:blue)
+vio2 = violin(min.(log10.(sim),2), side=:left, lab=["Standard" "" ""],xticks=(1:nr,labelvec), ylabel="log(Simulation RMS)", c=:red)
+violin!((1:nr)',min.(log10.(simp),2), side=:right, lab=["Bayesian" "" ""],xticks=(1:nr,labelvec), c=:blue, legend=false)
+plot(vio1,vio2,title=infostring); gui()
+# savefig2("/local/home/fredrikb/papers/nn_prior/figs/valerrB.tex")
+
+# jacerrtrain = hcat([eval_jac.(get_res(res,i), false) for i ∈ 1:nr]...)
+jacerr = hcat([eval_jac.(get_res(res,i), true) for i ∈ 1:nr]...)
+jacerrp = hcat([eval_jac.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
+violin(identity.(jacerr), side=:left, xticks=(1:nr,labelvec), title="Jacobian error (validation data)"*infostring, ylabel="norm of Jacobian error", reuse=true, lab=["Standard" "" ""], c=:red)
+violin!((1:nr)',identity.(jacerrp), side=:right, xticks=(1:nr,labelvec), ylabel="Jacobian RMS", reuse=true, lab=["Bayesian" "" ""], c=:blue); gui()
+# savefig2("/local/home/fredrikb/papers/nn_prior/figs/jacerrB.tex")
+
+##
+
+# TODO: ====================================================================================
+# Simulation performance
+# Prediction performance
+# Jacobian error
+# The above on both train and test data
+# Compare likelihood (weight by confidence)
+# Learn time difference
+# Learn states directly
+# Learn only velocity
+# LayerNorm on/off
+# Weight decay on/off
+
+# Compare true vs ad-hoc bootstrap
+num_workers = 1; addprocs([(@sprintf("philon-%2.2d",i),num_workers) for i in [2:4; 6:10]]);addprocs([(@sprintf("ktesibios-%2.2d",i),num_workers) for i in 1:9]);addprocs([(@sprintf("heron-%2.2d",i),num_workers) for i in [1,2,3,4,6,11]])
+
+
+
+vio1 = violin(log10.(pred), side=:left, lab=["Standard" "" ""], ylabel="log(Prediction RMS)", reuse=false, c=:red)
+violin!((1:nr)',log10.(predp), side=:right, lab=["Bayesian" "" ""], c=:blue)
+vio2 = violin(min.(log10.(sim),2), side=:left, lab=["Standard" "" ""], ylabel="log(Simulation RMS)", c=:red)
+violin!((1:nr)',min.(log10.(simp),2), side=:right, lab=["Bayesian" "" ""], c=:blue, legend=false)
+plot(vio1,vio2,title=infostring); gui()
+
+
+jacerr = hcat([eval_jac.(get_res(res,i), true) for i ∈ 1:nr]...)
+jacerrp = hcat([eval_jac.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
+violin(identity.(jacerr), side=:left, title="Jacobian error (validation data)"*infostring, ylabel="norm of Jacobian error", reuse=true, lab=["Standard" "" ""], c=:red)
+violin!((1:nr)',identity.(jacerrp), side=:right, ylabel="Jacobian RMS", reuse=true, lab=["Bayesian" "" ""], c=:blue); gui()
diff --git a/flux/nn_prior/utilities.jl b/flux/nn_prior/utilities.jl
index ac31e44..dab3e0c 100644
--- a/flux/nn_prior/utilities.jl
+++ b/flux/nn_prior/utilities.jl
@@ -11,7 +11,7 @@ function System(sys, num_params, activation)
     @unpack n,ns = sys
     ny = ns
     np = num_params
-    m  = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny))
+    m  = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny))
     System(m, sys)
 end
 (m::System)(x) = m.m(x)
@@ -25,7 +25,7 @@ function DiffSystem(sys, num_params, activation)
     @unpack n,ns = sys
     ny = ns
     np = num_params
-    m  = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny))
+    m  = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny))
     DiffSystem(m, sys)
 end
 (m::DiffSystem)(x) = m.m(x)+x[1:m.sys.ns,:]
@@ -39,7 +39,7 @@ function VelSystem(sys, num_params, activation)
     @unpack n,ns = sys
     ny = n
     np = num_params
-    m = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny))
+    m = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny))
     VelSystem(m, sys)
 end
 (m::VelSystem)(x) = m.m(x)
@@ -56,9 +56,9 @@ end
 for (S1,S2) in zip([:SystemD, :DiffSystemD, :VelSystemD],[:System, :DiffSystem, :VelSystem])
     @eval $(S1)(args...) = $(S1)($(S2)(args...))
     # @show S1
-    @eval StatsBase.predict(ms::Vector{$S1}, x, mc=10) = (mean(i->predict(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[predict(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
-    @eval simulate(ms::Vector{$S1}, x, mc=10) = mean(i->simulate(getfield.(ms,:s), x, false), 1:mc)
-    @eval Flux.jacobian(ms::Vector{$S1}, x, mc=10) = (mean(i->jacobian(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[jacobian(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
+    @eval StatsBase.predict(ms::Vector{$S1}, x, mc=100) = (mean(i->predict(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[predict(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
+    @eval simulate(ms::Vector{$S1}, x, mc=100) = mean(i->simulate(getfield.(ms,:s), x, false), 1:mc)
+    @eval Flux.jacobian(ms::Vector{$S1}, x, mc=100) = (mean(i->jacobian(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[jacobian(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
     @eval (s::$(S1))(x) = s.s(x)
 end
 
@@ -177,7 +177,8 @@ function Flux.jacobian(ms::EnsembleVelSystem, x, testmode=true)
     squeeze(mean(jacmat, 3), 3), squeeze(std(jacmat, 3), 3)
 end
 
-models(results) = [r[:m] for r in results]
+models(results::AbstractVector) = [r[:m] for r in results]
+models(results::Associative) = [r[:m]]
 
 function all_jacobians(results, eval=false)
     @unpack x,u,xv,uv,sys,modeltype = results[1]
@@ -194,6 +195,7 @@ function all_jacobians(results, eval=false)
     Jm, Js, Jtrue
 end
 
+
 function plot_jacobians(Jm, Js, Jtrue)
     N = size(Jm,1)
     colors = [HSV(h,1,0.8) for h in linspace(0,254,N)]
@@ -271,8 +273,8 @@ function plot_prediction(fig, results, eval=false)
     ms = models(results)
     yh, bounds = predict(ms, x)
     for i = 1:size(yh,1)
-        if size(bounds[1],1) >= i
-            plot!(fig, yh[i,:], fillrange = getindex.(bounds,i,:), fillalpha=0.3, subplot=i, lab="Prediction")
+        if size(bounds,1) >= i
+            plot!(fig, yh[i,:], ribbon = 2bounds[i,:], fillalpha=0.3, subplot=i, lab="Prediction")
         else
             plot!(fig, yh[i,:], subplot=i, lab="Prediction")
         end
-- 
GitLab