Skip to content
Snippets Groups Projects
Commit 85d4e083 authored by Fredrik Bagge Carlson's avatar Fredrik Bagge Carlson
Browse files

bayesian dropout nn prior

parent 9c1b96ce
Branches
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ using OrdinaryDiffEq, LTVModels
using Flux: back!, truncate!, treelike, train!, mse, testmode!, combine, params, jacobian
using Flux.Optimise: Param, optimiser, RMSProp, expdecay
wdecay = 0.0
num_params = 20
num_params = 50
stepsize = 0.05
n_bootstrap = 1
num_montecarlo = 20
......@@ -119,11 +119,13 @@ end
# Produce figures
# pyplot(reuse=false, show=false)
# figs1 = fit_system(1, sys, SystemD, 4, doplot=false); #savefigures(figs1..., 1)
# figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=false); #savefigures(figs2..., 2)
figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=true); #savefigures(figs2..., 2)
# figs3 = fit_system(1, sys, VelSystemD, 2, doplot=false); #savefigures(figs3..., 3)
# figs4 = fit_system(1, sys, AffineSystem, doplot=true); #savefigures(figs4..., 4)
# error()
##
# # Monte-Carlo evaluation 1
# res = map(1:num_montecarlo) do it
# r1 = @spawn fit_system(it, sys, System)
......
cd(@__DIR__) # run relative to this script's directory
# import Iterators
using OrdinaryDiffEq, LTVModels
# Load dependencies and experiment hyper-parameters on every worker process.
@everywhere begin
using Flux, ValueHistories, IterTools, MLDataUtils, OrdinaryDiffEq, Parameters
using Flux: back!, truncate!, treelike, train!, mse, testmode!, combine, params, jacobian
using Flux.Optimise: Param, optimiser, RMSProp, expdecay
wdecay = 0.0 # weight-decay coefficient (0.0 disables the expdecay optimiser term)
num_params = 50 # hidden-layer width of the networks
stepsize = 0.05 # ADAM learning rate
n_bootstrap = 1
num_montecarlo = 5 # number of Monte-Carlo seeds in the evaluation below
end
# True when running on one of the office machines (used below to pick local
# vs. network include paths). The `∈` operator was dropped during extraction;
# restored here.
inoffice() = gethostname() ∈ ["billman", "Battostation"]
# On office machines, load the two-link robot dynamics from the locally installed package.
inoffice() && !isdefined(:time_derivative) && include(Pkg.dir("DynamicMovementPrimitives","src","two_link.jl"))
if !isdefined(:AbstractEnsembleSystem) # include helper files only once per session
inoffice() || @everywhere include("/var/tmp/fredrikb/v0.6/DynamicMovementPrimitives/src/two_link.jl")
@everywhere include("utilities.jl")
@everywhere include("two_link_sys.jl")
@everywhere include("linear_sys.jl")
end
# Two-link robot system: N = 1000 samples, sample interval h = 0.02, noise σ0 = 0.1.
sys = TwoLinkSys(N=1000, h=0.02, σ0 = 0.1)#; warn("Added noise")
# sys = LinearSys(1,N=100, h=0.2, σ0 = 1, n = 10, ns = 10)
# Train model `m` with optimiser `opt` on shuffled minibatches of (x, y).
# Validation data (xv, yv) is only evaluated, never trained on. Returns a
# Dict{Symbol,Any} containing the trained model, the train/validation loss
# traces, the model type and the system.
# NOTE(review): `u` and `uv` are accepted but unused in this body — presumably
# kept for signature symmetry with other fitters; confirm before removing.
@everywhere function fit_model(opt, loss, m, x, y, u, xv, yv, uv, sys, modeltype;
iters = 2000,
doplot = true,
batch_size = 50)
trace = History(Float64) # training loss vs. iteration
vtrace = History(Float64) # validation loss vs. iteration
batcher = batchview(shuffleobs((x,y)), batch_size)
dataset = ncycle(batcher, iters) # repeat the batch view `iters` times
iter = 0
function evalcallback()
iter += 1
testmode!(m) # put the model in test mode (dropout off) while measuring losses
l = loss(x, y).data[1]
push!(trace,iter,l)
push!(vtrace, iter, loss(xv,yv).data[1])
if iter % 50 == 0
if doplot
println("Iter: $iter, Loss: ", l)
plot(trace,reuse=true,show=false, lab="Train", layout=3, subplot=1, size=(1400,1000))
plot!(vtrace,show=false, lab="Validation", subplot=1, yscale=:log10)
plot!(y', subplot=2:3, lab="y")
plot!(m(x).data', l=:dash, subplot=2:3, show=true, lab="Pred")
end
end
testmode!(m, false) # back to training mode (dropout on)
truncate!(m)
end
train!(loss, dataset, opt, cb = evalcallback)
results = Dict{Symbol, Any}()
@pack results = m, trace, vtrace, modeltype, sys
results
end
# Generate data from `sys` for the given `seed`, optionally augment it with
# `jacprop` perturbed copies whose targets come from a fitted LTV Kalman model
# ("jacobian propagation"), then train a network of type `modeltype`.
# Returns the figures (fig_jac, fig_time, fig_eig) when `doplot` is true,
# otherwise a one-element results vector.
@everywhere function fit_system(seed, sys, modeltype, jacprop = 0; doplot=false)
    x, y, u = generate_data(sys, modeltype, seed)
    iters = modeltype == System ? 2000 : 1000
    if jacprop > 0
        # Fewer gradient steps since the dataset grows by a factor jacprop+1.
        iters = round(Int, iters / sqrt(jacprop+1))
        n, m = sys.ns, sys.n
        R1, R2 = 0.01eye(n*(n+m)), 10eye(n) # Fast
        P0 = 10000R1
        model = LTVModels.fit_model(LTVModels.KalmanModel, x, u, R1, R2, P0, extend=true)
        xt, ut, yt = [x;u], u, y
        for i = 1:jacprop
            # Perturb state/input by ~10% of their per-dimension std and let the
            # LTV model predict the corresponding targets.
            xa = x .+ std(x,2)/10 .* randn(size(x))
            ua = u .+ std(u,2)/10 .* randn(size(u))
            ya = LTVModels.predict(model, xa, ua)
            xt = [xt [xa;ua]]
            ut = [ut ua]
            # FIX: `∈` was dropped during extraction. Velocity models train only
            # on the velocity rows (3:4) of the predicted state.
            yt = modeltype ∈ [VelSystem, VelSystemD] ? [yt ya[3:4,:]] : [yt ya]
        end
        x = [x;u]
    else
        x = [x;u]
        xt, ut, yt = x, u, y
    end
    xv, yv, uv = generate_data(sys, modeltype, seed+1)
    xv = [xv;uv]
    activation = relu
    m = modeltype(sys, num_params, activation)
    opt = [ADAM(params(m), stepsize, decay=0.005); [expdecay(Param(p), wdecay) for p in params(m) if p isa AbstractMatrix]]
    # FIX: the previous `÷ jacprop` threw DivideError for the default jacprop = 0;
    # divide by at least 1 so the un-augmented path works.
    d = max(jacprop, 1)
    results = fit_model(opt, loss(m), m, xt, yt, ut, xv, yv, uv, sys, modeltype,
                        iters = iters ÷ d, doplot = doplot, batch_size = size(yt,2) ÷ d)
    @pack results = x, u, y, xv, uv, yv
    results = [results]
    if doplot
        Jm, Js, Jtrue = all_jacobians(results)
        fig_jac = plot_jacobians(Jm, Js, Jtrue); #gui();
        fig_time = plotresults(results); #gui();
        fig_eig = plot_eigvals(results); #gui();
        return fig_jac, fig_time, fig_eig
    end
    results
end
"""
    savefigures(f1, f2, f3, n; dir = "/local/home/fredrikb/papers/nn_prior/figs")

Save the three result figures (Jacobians, time series, eigenvalues) as PDFs
suffixed with index `n`. The output directory defaults to the original
hard-coded path but can now be overridden via the `dir` keyword.
"""
function savefigures(f1, f2, f3, n; dir = "/local/home/fredrikb/papers/nn_prior/figs")
    savefig(f1, joinpath(dir, "jacs$(n).pdf"))
    savefig(f2, joinpath(dir, "timeseries$(n).pdf"))
    savefig(f3, joinpath(dir, "eigvals$(n).pdf"))
end
# Produce figures
# pyplot(reuse=false, show=false)
# figs1 = fit_system(1, sys, SystemD, 4, doplot=false); #savefigures(figs1..., 1)
figs2 = fit_system(1, sys, DiffSystemD, 2, doplot=true); #savefigures(figs2..., 2)
# figs3 = fit_system(1, sys, VelSystemD, 2, doplot=false); #savefigures(figs3..., 3)
# figs4 = fit_system(1, sys, AffineSystem, doplot=true); #savefigures(figs4..., 4)
# error()
##
# Monte-Carlo evaluation: for each seed, train the two standard models
# (System/DiffSystem) and the two dropout "Bayesian" variants
# (SystemD/DiffSystemD) on worker processes.
res = map(1:num_montecarlo) do it
r1 = @spawn fit_system(it, sys, System, 2)
r2 = @spawn fit_system(it, sys, DiffSystem, 2)
r3 = @spawn fit_system(it, sys, SystemD, 2)
r4 = @spawn fit_system(it, sys, DiffSystemD, 2)
println("Done with montecarlo run $it")
r1,r2,r3,r4
end
res = [(fetch.(rs)...) for rs in res] # collect the four results per seed from the workers
open(file->serialize(file, res), "res","w") # Save
error() # deliberate stop: the analysis cells below are run manually
##
# Analysis: load the serialized Monte-Carlo results and compare the standard
# networks (first half of each result tuple) against the Bayesian-dropout
# networks (second half) using split violin plots.
sleep(60*3) # allow workers time to finish before the results file is read
##
using StatPlots
res = open(file->deserialize(file), "res") # Load
sys = res[1][1][1][:sys]
nr = length(res[1]) ÷ 2 # first nr entries: standard models; last nr: Bayesian (dropout)
labelvec = ["f" "g"]
infostring = @sprintf("Num hidden: %d, sigma: %2.2f, Montecarlo: %d", num_params, sys.σ0, num_montecarlo)
# pgfplots(size=(600,300), show=true)
# FIX: the `∈` operators in the comprehensions below were dropped during
# extraction (cf. the intact commented line further down); restored here.
pred = hcat([eval_pred.(get_res(res,i), true) for i ∈ 1:nr]...)
sim = hcat([eval_sim.(get_res(res,i), true) for i ∈ 1:nr]...)
predp = hcat([eval_pred.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
simp = hcat([eval_sim.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
vio1 = violin(log10.(pred), side=:left, lab=["Standard" "" ""],xticks=(1:nr,labelvec), ylabel="log(Prediction RMS)", reuse=false, c=:red)
violin!((1:nr)',log10.(predp), side=:right, lab=["Bayesian" "" ""],xticks=(1:nr,labelvec), c=:blue)
vio2 = violin(min.(log10.(sim),2), side=:left, lab=["Standard" "" ""],xticks=(1:nr,labelvec), ylabel="log(Simulation RMS)", c=:red)
violin!((1:nr)',min.(log10.(simp),2), side=:right, lab=["Bayesian" "" ""],xticks=(1:nr,labelvec), c=:blue, legend=false)
plot(vio1,vio2,title=infostring); gui()
# savefig2("/local/home/fredrikb/papers/nn_prior/figs/valerrB.tex")
# jacerrtrain = hcat([eval_jac.(get_res(res,i), false) for i ∈ 1:nr]...)
jacerr = hcat([eval_jac.(get_res(res,i), true) for i ∈ 1:nr]...)
jacerrp = hcat([eval_jac.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
violin(identity.(jacerr), side=:left, xticks=(1:nr,labelvec), title="Jacobian error (validation data)"*infostring, ylabel="norm of Jacobian error", reuse=true, lab=["Standard" "" ""], c=:red)
violin!((1:nr)',identity.(jacerrp), side=:right, xticks=(1:nr,labelvec), ylabel="Jacobian RMS", reuse=true, lab=["Bayesian" "" ""], c=:blue); gui()
# savefig2("/local/home/fredrikb/papers/nn_prior/figs/jacerrB.tex")
##
# TODO: ====================================================================================
# Simulation performance
# Prediction performance
# Jacobian error
# The above on both train and test data
# Compare likelihood (weight by confidence)
# Learn time difference
# Learn states directly
# Learn only velocity
# LayerNorm on/off
# Weight decay on/off
# Compare true vs ad-hoc bootstrap
# Add worker processes on the lab machines (one worker per listed host).
num_workers = 1; addprocs([(@sprintf("philon-%2.2d",i),num_workers) for i in [2:4; 6:10]]);addprocs([(@sprintf("ktesibios-%2.2d",i),num_workers) for i in 1:9]);addprocs([(@sprintf("heron-%2.2d",i),num_workers) for i in [1,2,3,4,6,11]])
# Same violin plots as in the analysis cell above, without x-axis tick labels.
vio1 = violin(log10.(pred), side=:left, lab=["Standard" "" ""], ylabel="log(Prediction RMS)", reuse=false, c=:red)
violin!((1:nr)',log10.(predp), side=:right, lab=["Bayesian" "" ""], c=:blue)
vio2 = violin(min.(log10.(sim),2), side=:left, lab=["Standard" "" ""], ylabel="log(Simulation RMS)", c=:red)
violin!((1:nr)',min.(log10.(simp),2), side=:right, lab=["Bayesian" "" ""], c=:blue, legend=false)
plot(vio1,vio2,title=infostring); gui()
# FIX: `∈` was dropped during extraction in the two comprehensions below; restored.
jacerr = hcat([eval_jac.(get_res(res,i), true) for i ∈ 1:nr]...)
jacerrp = hcat([eval_jac.(get_res(res,i), true) for i ∈ nr+1:2nr]...)
violin(identity.(jacerr), side=:left, title="Jacobian error (validation data)"*infostring, ylabel="norm of Jacobian error", reuse=true, lab=["Standard" "" ""], c=:red)
violin!((1:nr)',identity.(jacerrp), side=:right, ylabel="Jacobian RMS", reuse=true, lab=["Bayesian" "" ""], c=:blue); gui()
......@@ -11,7 +11,7 @@ function System(sys, num_params, activation)
@unpack n,ns = sys
ny = ns
np = num_params
m = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny))
m = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny))
System(m, sys)
end
(m::System)(x) = m.m(x)
......@@ -25,7 +25,7 @@ function DiffSystem(sys, num_params, activation)
@unpack n,ns = sys
ny = ns
np = num_params
m = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny))
m = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny))
DiffSystem(m, sys)
end
(m::DiffSystem)(x) = m.m(x)+x[1:m.sys.ns,:]
......@@ -39,7 +39,7 @@ function VelSystem(sys, num_params, activation)
@unpack n,ns = sys
ny = n
np = num_params
m = Chain(Dense(ns+n,np, activation), Dropout(0.5), Dense(np, ny))
m = Chain(Dense(ns+n,np, activation), Dropout(0.01), Dense(np,np, activation), Dropout(0.01), Dense(np, ny))
VelSystem(m, sys)
end
(m::VelSystem)(x) = m.m(x)
......@@ -56,9 +56,9 @@ end
for (S1,S2) in zip([:SystemD, :DiffSystemD, :VelSystemD],[:System, :DiffSystem, :VelSystem])
@eval $(S1)(args...) = $(S1)($(S2)(args...))
# @show S1
@eval StatsBase.predict(ms::Vector{$S1}, x, mc=10) = (mean(i->predict(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[predict(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
@eval simulate(ms::Vector{$S1}, x, mc=10) = mean(i->simulate(getfield.(ms,:s), x, false), 1:mc)
@eval Flux.jacobian(ms::Vector{$S1}, x, mc=10) = (mean(i->jacobian(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[jacobian(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
@eval StatsBase.predict(ms::Vector{$S1}, x, mc=100) = (mean(i->predict(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[predict(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
@eval simulate(ms::Vector{$S1}, x, mc=100) = mean(i->simulate(getfield.(ms,:s), x, false), 1:mc)
@eval Flux.jacobian(ms::Vector{$S1}, x, mc=100) = (mean(i->jacobian(getfield.(ms,:s), x, false)[1], 1:mc), std(cat(3,[jacobian(getfield.(ms,:s), x, false)[1] for i = 1:mc]...), 3))
@eval (s::$(S1))(x) = s.s(x)
end
......@@ -177,7 +177,8 @@ function Flux.jacobian(ms::EnsembleVelSystem, x, testmode=true)
squeeze(mean(jacmat, 3), 3), squeeze(std(jacmat, 3), 3)
end
models(results) = [r[:m] for r in results]
models(results::AbstractVector) = [r[:m] for r in results]
models(results::Associative) = [r[:m]]
function all_jacobians(results, eval=false)
@unpack x,u,xv,uv,sys,modeltype = results[1]
......@@ -194,6 +195,7 @@ function all_jacobians(results, eval=false)
Jm, Js, Jtrue
end
function plot_jacobians(Jm, Js, Jtrue)
N = size(Jm,1)
colors = [HSV(h,1,0.8) for h in linspace(0,254,N)]
......@@ -271,8 +273,8 @@ function plot_prediction(fig, results, eval=false)
ms = models(results)
yh, bounds = predict(ms, x)
for i = 1:size(yh,1)
if size(bounds[1],1) >= i
plot!(fig, yh[i,:], fillrange = getindex.(bounds,i,:), fillalpha=0.3, subplot=i, lab="Prediction")
if size(bounds,1) >= i
plot!(fig, yh[i,:], ribbon = 2bounds[i,:], fillalpha=0.3, subplot=i, lab="Prediction")
else
plot!(fig, yh[i,:], subplot=i, lab="Prediction")
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment