From 85d4e08386012f3fd8a35431680d7d3a16ad5ea0 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson <cont-frb@ulund.org> Date: Fri, 12 Jan 2018 06:26:03 +0100 Subject: [PATCH] bayesian dropout nn prior --- flux/nn_prior/nn_prior.jl | 6 +- flux/nn_prior/nn_prior_bayesian.jl | 199 +++++++++++++++++++++++++++++ flux/nn_prior/utilities.jl | 20 +-- 3 files changed, 214 insertions(+), 11 deletions(-) create mode 100644 flux/nn_prior/nn_prior_bayesian.jl diff --git a/flux/nn_prior/nn_prior.jl b/flux/nn_prior/nn_prior.jl index 6107bd7..22567be 100644 --- a/flux/nn_prior/nn_prior.jl +++ b/flux/nn_prior/nn_prior.jl @@ -6,7 +6,7 @@ using OrdinaryDiffEq, LTVModels using Flux: back!, truncate!, treelike, train!, mse, testmode!, combine, params, jacobian using Flux.Optimise: Param, optimiser, RMSProp, expdecay wdecay = 0.0 - num_params = 20 + num_params = 50 stepsize = 0.05 n_bootstrap = 1 num_montecarlo = 20 @@ -119,11 +119,13 @@ end # Produce figures # pyplot(reuse=false, show=false) # figs1 = fit_system(1, sys, SystemD, 4, doplot=false); #savefigures(figs1..., 1) -# figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=false); #savefigures(figs2..., 2) +figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=true); #savefigures(figs2..., 2) # figs3 = fit_system(1, sys, VelSystemD, 2, doplot=false); #savefigures(figs3..., 3) # figs4 = fit_system(1, sys, AffineSystem, doplot=true); #savefigures(figs4..., 4) # error() +## + # # Monte-Carlo evaluation 1 # res = map(1:num_montecarlo) do it # r1 = @spawn fit_system(it, sys, System) diff --git a/flux/nn_prior/nn_prior_bayesian.jl b/flux/nn_prior/nn_prior_bayesian.jl new file mode 100644 index 0000000..4189531 --- /dev/null +++ b/flux/nn_prior/nn_prior_bayesian.jl @@ -0,0 +1,199 @@ +cd(@__DIR__) +# import Iterators +using OrdinaryDiffEq, LTVModels +@everywhere begin + using Flux, ValueHistories, IterTools, MLDataUtils, OrdinaryDiffEq, Parameters + using Flux: back!, truncate!, treelike, train!, mse, testmode!, combine, params, jacobian + using Flux.Optimise: Param, optimiser, RMSProp, expdecay + wdecay = 0.0 + num_params = 50 + stepsize = 0.05 + n_bootstrap = 1 + num_montecarlo = 5 +end + +inoffice() = gethostname() ∈ ["billman", "Battostation"] +inoffice() && !isdefined(:time_derivative) && include(Pkg.dir("DynamicMovementPrimitives","src","two_link.jl")) +if !isdefined(:AbstractEnsembleSystem) + inoffice() || @everywhere include("/var/tmp/fredrikb/v0.6/DynamicMovementPrimitives/src/two_link.jl") + @everywhere include("utilities.jl") + @everywhere include("two_link_sys.jl") + @everywhere include("linear_sys.jl") +end + +sys = TwoLinkSys(N=1000, h=0.02, σ0 = 0.1)#; warn("Added noise") +# sys = LinearSys(1,N=100, h=0.2, σ0 = 1, n = 10, ns = 10) + +@everywhere function fit_model(opt, loss, m, x, y, u, xv, yv, uv, sys, modeltype; + iters = 2000, + doplot = true, + batch_size = 50) + + trace = History(Float64) + vtrace = History(Float64) + batcher = batchview(shuffleobs((x,y)), batch_size) + dataset = ncycle(batcher, iters) + iter = 0 + function evalcallback() + iter += 1 + testmode!(m) + l = loss(x, y).data[1] + push!(trace,iter,l) + push!(vtrace, iter, loss(xv,yv).data[1]) + if iter % 50 == 0 + if doplot + println("Iter: $iter, Loss: ", l) + plot(trace,reuse=true,show=false, lab="Train", layout=3, subplot=1, size=(1400,1000)) + plot!(vtrace,show=false, lab="Validation", subplot=1, yscale=:log10) + plot!(y', subplot=2:3, lab="y") + plot!(m(x).data', l=:dash, subplot=2:3, show=true, lab="Pred") + end + end + testmode!(m, false) + truncate!(m) + end + train!(loss, dataset, opt, cb = evalcallback) + results = Dict{Symbol, Any}() + @pack results = m, trace, vtrace, modeltype, sys + results +end + + + +@everywhere function fit_system(seed, sys, modeltype, jacprop = 0; doplot=false) + x,y,u = generate_data(sys, modeltype, seed) + iters = modeltype==System ? 2000 : 1000 + if jacprop > 0 + iters = round(Int, iters / sqrt(jacprop+1)) + n,m = sys.ns,sys.n + R1,R2 = 0.01eye(n*(n+m)), 10eye(n) # Fast + P0 = 10000R1 + model = LTVModels.fit_model(LTVModels.KalmanModel, x,u,R1,R2,P0, extend=true) + xt,ut,yt = [x;u], u, y + for i = 1:jacprop + xa = x .+ std(x,2)/10 .* randn(size(x)) + ua = u .+ std(u,2)/10 .* randn(size(u)) + ya = LTVModels.predict(model, xa,ua) + xt = [xt [xa;ua]] + ut = [ut ua] + yt = modeltype ∈ [VelSystem, VelSystemD] ? [yt ya[3:4,:]] : [yt ya] + end + x = [x;u] + else + x = [x;u] + xt,ut,yt = x,u,y + end + + xv,yv,uv = generate_data(sys, modeltype, seed+1) + xv = [xv;uv] + + activation = relu + + m = modeltype(sys, num_params, activation) + opt = [ADAM(params(m), stepsize, decay=0.005); [expdecay(Param(p), wdecay) for p in params(m) if p isa AbstractMatrix]] + results = fit_model(opt, loss(m), m, xt, yt, ut, xv, yv, uv, sys,modeltype, iters=iters ÷ jacprop, doplot=doplot, batch_size = size(yt,2) ÷ jacprop) + @pack results = x, u, y, xv, uv, yv + results = [results] + if doplot + Jm, Js, Jtrue = all_jacobians(results) + fig_jac = plot_jacobians(Jm, Js, Jtrue); #gui(); + fig_time = plotresults(results); #gui(); + fig_eig = plot_eigvals(results); #gui(); + return fig_jac, fig_time, fig_eig + end + results +end + + +function savefigures(f1,f2,f3,n) + savefig(f1,"/local/home/fredrikb/papers/nn_prior/figs/jacs$(n).pdf") + savefig(f2,"/local/home/fredrikb/papers/nn_prior/figs/timeseries$(n).pdf") + savefig(f3,"/local/home/fredrikb/papers/nn_prior/figs/eigvals$(n).pdf") +end + + +# Produce figures +# pyplot(reuse=false, show=false) +# figs1 = fit_system(1, sys, SystemD, 4, doplot=false); #savefigures(figs1..., 1) +figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=true); #savefigures(figs2..., 2) +# figs3 = fit_system(1, sys, VelSystemD, 2, doplot=false); #savefigures(figs3..., 3) +# figs4 = fit_system(1, sys, AffineSystem, doplot=true); #savefigures(figs4..., 4) +# error() + +## + +res = map(1:num_montecarlo) do it + r1 = @spawn fit_system(it, sys, System, 2) + r2 = @spawn fit_system(it, sys, DiffSystem, 2) + r3 = @spawn fit_system(it, sys, SystemD, 2) + r4 = @spawn fit_system(it, sys, DiffSystemD, 2) + + + println("Done with montecarlo run $it") + r1,r2,r3,r4 +end +res = [(fetch.(rs)...) for rs in res] + +open(file->serialize(file, res), "res","w") # Save +error() + +## +sleep(60*3) +## +using StatPlots +res = open(file->deserialize(file), "res") # Load +sys = res[1][1][1][:sys] +nr = length(res[1]) ÷ 2 +labelvec = ["f" "g"] +infostring = @sprintf("Num hidden: %d, sigma: %2.2f, Montecarlo: %d", num_params, sys.σ0, num_montecarlo) + +# pgfplots(size=(600,300), show=true) + +pred = hcat([eval_pred.(get_res(res,i), true) for i ∈ 1:nr]...) +sim = hcat([eval_sim.(get_res(res,i), true) for i ∈ 1:nr]...) +predp = hcat([eval_pred.(get_res(res,i), true) for i ∈ nr+1:2nr]...) +simp = hcat([eval_sim.(get_res(res,i), true) for i ∈ nr+1:2nr]...) +vio1 = violin(log10.(pred), side=:left, lab=["Standard" "" ""],xticks=(1:nr,labelvec), ylabel="log(Prediction RMS)", reuse=false, c=:red) +violin!((1:nr)',log10.(predp), side=:right, lab=["Bayesian" "" ""],xticks=(1:nr,labelvec), c=:blue) +vio2 = violin(min.(log10.(sim),2), side=:left, lab=["Standard" "" ""],xticks=(1:nr,labelvec), ylabel="log(Simulation RMS)", c=:red) +violin!((1:nr)',min.(log10.(simp),2), side=:right, lab=["Bayesian" "" ""],xticks=(1:nr,labelvec), c=:blue, legend=false) +plot(vio1,vio2,title=infostring); gui() +# savefig2("/local/home/fredrikb/papers/nn_prior/figs/valerrB.tex") + +# jacerrtrain = hcat([eval_jac.(get_res(res,i), false) for i ∈ 1:nr]...) +jacerr = hcat([eval_jac.(get_res(res,i), true) for i ∈ 1:nr]...) +jacerrp = hcat([eval_jac.(get_res(res,i), true) for i ∈ nr+1:2nr]...) +violin(identity.(jacerr), side=:left, xticks=(1:nr,labelvec), title="Jacobian error (validation data)"*infostring, ylabel="norm of Jacobian error", reuse=true, lab=["Standard" "" ""], c=:red) +violin!((1:nr)',identity.(jacerrp), side=:right, xticks=(1:nr,labelvec), ylabel="Jacobian RMS", reuse=true, lab=["Bayesian" "" ""], c=:blue); gui() +# savefig2("/local/home/fredrikb/papers/nn_prior/figs/jacerrB.tex") + +## + +# TODO: ==================================================================================== +# Simulation performance +# Prediction performance +# Jacobian error +# The above on both train and test data +# Compare likelihood (weight by confidence) +# Learn time difference +# Learn states directly +# Learn only velocity +# LayerNorm on/off +# Weight decay on/off + +# Compare true vs ad-hoc bootstrap +num_workers = 1; addprocs([(@sprintf("philon-%2.2d",i),num_workers) for i in [2:4; 6:10]]);addprocs([(@sprintf("ktesibios-%2.2d",i),num_workers) for i in 1:9]);addprocs([(@sprintf("heron-%2.2d",i),num_workers) for i in [1,2,3,4,6,11]]) + + + +vio1 = violin(log10.(pred), side=:left, lab=["Standard" "" ""], ylabel="log(Prediction RMS)", reuse=false, c=:red) +violin!((1:nr)',log10.(predp), side=:right, lab=["Bayesian" "" ""], c=:blue) +vio2 = violin(min.(log10.(sim),2), side=:left, lab=["Standard" "" ""], ylabel="log(Simulation RMS)", c=:red) +violin!((1:nr)',min.(log10.(simp),2), side=:right, lab=["Bayesian" "" ""], c=:blue, legend=false) +plot(vio1,vio2,title=infostring); gui() + + +jacerr = hcat([eval_jac.(get_res(res,i), true) for i ∈ 1:nr]...) +jacerrp = hcat([eval_jac.(get_res(res,i), true) for i ∈ nr+1:2nr]...) +violin(identity.(jacerr), side=:left, title="Jacobian error (validation data)"*infostring, ylabel="norm of Jacobian error", reuse=true, lab=["Standard" "" ""], c=:red) +violin!((1:nr)',identity.(jacerrp), side=:right, ylabel="Jacobian RMS", reuse=true, lab=["Bayesian" "" ""], c=:blue); gui() diff --git a/flux/nn_prior/utilities.jl b/flux/nn_prior/utilities.jl index ac31e44..dab3e0c 100644 --- a/flux/nn_prior/utilities.jl +++ b/flux/nn_prior/utilities.jl @@ -11,7 +11,7 @@ function System(sys, num_params, activation) @unpack n,ns = sys ny = ns np = num_params - m = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny)) + m = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny)) System(m, sys) end (m::System)(x) = m.m(x) @@ -25,7 +25,7 @@ function DiffSystem(sys, num_params, activation) @unpack n,ns = sys ny = ns np = num_params - m = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny)) + m = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny)) DiffSystem(m, sys) end (m::DiffSystem)(x) = m.m(x)+x[1:m.sys.ns,:] @@ -39,7 +39,7 @@ function VelSystem(sys, num_params, activation) @unpack n,ns = sys ny = n np = num_params - m = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny)) + m = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny)) VelSystem(m, sys) end (m::VelSystem)(x) = m.m(x) @@ -56,9 +56,9 @@ end for (S1,S2) in zip([:SystemD, :DiffSystemD, :VelSystemD],[:System, :DiffSystem, :VelSystem]) @eval $(S1)(args...) = $(S1)($(S2)(args...)) # @show S1 - @eval StatsBase.predict(ms::Vector{$S1}, x, mc=10) = (mean(i->predict(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[predict(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3)) - @eval simulate(ms::Vector{$S1}, x, mc=10) = mean(i->simulate(getfield.(ms,:s), x, false), 1:mc) - @eval Flux.jacobian(ms::Vector{$S1}, x, mc=10) = (mean(i->jacobian(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[jacobian(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3)) + @eval StatsBase.predict(ms::Vector{$S1}, x, mc=100) = (mean(i->predict(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[predict(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3)) + @eval simulate(ms::Vector{$S1}, x, mc=100) = mean(i->simulate(getfield.(ms,:s), x, false), 1:mc) + @eval Flux.jacobian(ms::Vector{$S1}, x, mc=100) = (mean(i->jacobian(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[jacobian(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3)) @eval (s::$(S1))(x) = s.s(x) end @@ -177,7 +177,8 @@ function Flux.jacobian(ms::EnsembleVelSystem, x, testmode=true) squeeze(mean(jacmat, 3), 3), squeeze(std(jacmat, 3), 3) end -models(results) = [r[:m] for r in results] +models(results::AbstractVector) = [r[:m] for r in results] +models(results::Associative) = [r[:m]] function all_jacobians(results, eval=false) @unpack x,u,xv,uv,sys,modeltype = results[1] @@ -194,6 +195,7 @@ function all_jacobians(results, eval=false) Jm, Js, Jtrue end + function plot_jacobians(Jm, Js, Jtrue) N = size(Jm,1) colors = [HSV(h,1,0.8) for h in linspace(0,254,N)] @@ -271,8 +273,8 @@ function plot_prediction(fig, results, eval=false) ms = models(results) yh, bounds = predict(ms, x) for i = 1:size(yh,1) - if size(bounds[1],1) >= i - plot!(fig, yh[i,:], fillrange = getindex.(bounds,i,:), fillalpha=0.3, subplot=i, lab="Prediction") + if size(bounds,1) >= i + plot!(fig, yh[i,:], ribbon = 2bounds[i,:], fillalpha=0.3, subplot=i, lab="Prediction") else plot!(fig, yh[i,:], subplot=i, lab="Prediction") end -- GitLab