# bayesian_dropout.jl
# Fredrik Bagge Carlson
# NOTE: written against Julia 0.6-era APIs (linspace, tracked .data fields,
# the old Flux ADAM(params, lr) constructor); newer versions need adjustments.
using Flux, IterTools, ValueHistories, MLDataUtils, Plots
# Fade the colors of older series toward gray so that only the most recent
# `max_history` curves remain clearly visible while the plot is updated.
function update_plot(p; max_history = 10, attribute = :markercolor)
    num_series = length(p.series_list)
    if num_series > 1
        if num_series > max_history
            deleteat!(p.series_list, 1:num_series-max_history)
        end
        for i = 1:min(max_history, length(p.series_list))-1
            alpha = i/max_history
            c = p[i][attribute]
            c = RGBA(
                alpha*c.r + (1-alpha)*0.5,
                alpha*c.g + (1-alpha)*0.5,
                alpha*c.b + (1-alpha)*0.5,
                c.alpha)
            p[i][attribute] = c
        end
    end
end
iters = 4000   # number of training iterations (passes over the single batch)
N = 30         # number of training samples
n = 50         # width of each hidden layer
# Train a small MLP with dropout after every hidden layer on y = sin(x)/x.
# Keeping the Dropout layers active at prediction time later gives
# Monte-Carlo (Bayesian) dropout uncertainty estimates.
function train(m = Chain(Dense(1,n,relu), Dropout(0.01), Dense(n,n,relu), Dropout(0.01),
                         Dense(n,n,relu), Dropout(0.01), Dense(n,n,relu), Dropout(0.01), Dense(n,1)))
    x = linspace(-10,10,N)'
    y = sin.(x)./(x)
    trace = History(Float64)
    batcher = batchview(shuffleobs((x,y)), N)   # one batch holding all N samples
    dataset = ncycle(batcher, iters)            # repeat that batch iters times
    opt = ADAM(params(m), 0.01, decay = 0)
    loss(x,y) = sum((m(x).-y).^2)/N             # mean squared error
    fig = plot(layout=2, reuse=true)
    function cb()
        push!(trace, loss(x,y).data[])
        plot!(vec(x), [vec(y) m(x).data'], ylims=[-0.3,1], c = [:blue :red], subplot=1)
        plot!(trace, subplot=2, c=:blue, yscale=:log10)
        update_plot(fig[1], max_history = 10, attribute=:linecolor)
        update_plot(fig[2], max_history=1)
        gui()
    end
    Flux.train!(loss, dataset, opt, cb=cb)
    m, trace, x, y
end
m, trace, x, y = train()
##
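# Monte-Carlo dropout evaluation: the Dropout layers are left active, so every
# forward pass m(x) draws a fresh dropout mask. Repeating the prediction many
# times yields an approximate posterior predictive distribution, whose mean
# and spread are estimated below.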
mc  = 1000                                           # number of Monte-Carlo dropout samples
yh  = hcat([m(x).data[:] for i = 1:mc÷100]...)       # a few individual sampled predictions for plotting
yhm = mean(i->m(x).data[:], 1:mc)                    # MC estimate of the predictive mean
yhs = std(hcat([m(x).data[:] for i = 1:mc]...), 2)   # MC estimate of the predictive std per input point
closeall()
plot(vec(x), vec(y), lab="y", linewidth=3)                           # true function
plot!(vec(x), yhm, lab="yhm", ribbon=2yhs, linewidth=3, alpha=0.2)   # predictive mean with a 2-std ribbon
plot!(vec(x), yh, lab="yh")                                          # individual MC samples
gui()
##
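# Hedged sketch (not from the original script): roughly the same MC-dropout
# prediction written against a newer Flux (0.13+) / Julia 1.x style API. The
# calls below (Flux.trainmode!, Statistics mean/std with dims) are my
# assumptions about the newer API and are left commented out so the old-API
# script above still stands on its own.
#
# using Flux, Statistics
# Flux.trainmode!(m)                              # keep dropout active while sampling
# Y   = reduce(hcat, [vec(m(x)) for _ = 1:mc])    # one column per stochastic forward pass
# yhm = vec(mean(Y, dims=2))                      # predictive mean
# yhs = vec(std(Y, dims=2))                       # predictive standard deviation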