Skip to content
Snippets Groups Projects
Select Git revision
  • 9c1b96ceb094d9198b6c3f4626383f5284900d4f
  • master default protected
  • sub1
3 results

bayesian_dropout.jl

Blame
  • bayesian_dropout.jl 1.72 KiB
    using Flux, IterTools, ValueHistories, MLDataUtils
    """
        update_plot(p; max_history = 10, attribute = :markercolor)

    Trim subplot `p` to at most `max_history` series and fade all but the last
    one toward mid-gray.

    Does nothing while `p` holds at most one series. Otherwise the first
    `length - max_history` series are deleted (presumably the oldest, since new
    series are appended by `plot!`). Then every remaining series except the last
    has its `attribute` colour blended with gray `(0.5, 0.5, 0.5)` using weight
    `k / max_history` for the series at position `k`; the alpha channel is kept.
    The blend is applied to the current colour on every call, so older curves
    become progressively grayer. Returns `nothing`.
    """
    function update_plot(p; max_history = 10, attribute = :markercolor)
        series = p.series_list
        length(series) > 1 || return nothing

        # Remove series from the front of the list beyond the history limit.
        excess = length(series) - max_history
        if excess > 0
            deleteat!(series, 1:excess)
        end

        gray = 0.5
        # The last series keeps its colour; everything before it is faded.
        for k in 1:(min(max_history, length(series)) - 1)
            w = k / max_history
            col = p[k][attribute]
            p[k][attribute] = RGBA(w * col.r + (1 - w) * gray,
                                   w * col.g + (1 - w) * gray,
                                   w * col.b + (1 - w) * gray,
                                   col.alpha)
        end
        return nothing
    end
    
    # Script-level settings, read as (non-const) globals by `train`:
    #   iters — how many times the batch view is cycled for Flux.train!
    #   N     — number of grid points sampled on [-10, 10]
    #   n     — width of each hidden Dense layer in the default model
    iters, N, n = 4000, 30, 50
    
    """
        train(m = Chain(Dense(1,n,relu), Dropout(0.01), …, Dense(n,1)))

    Fit `m` to `N` samples of `sin(x)/x` on `[-10, 10]` with
    `ADAM(params(m), 0.01, decay = 0)`, cycling the data `iters` times.
    After every training step a callback records the full-data loss and
    redraws the current fit (subplot 1) and the loss history (subplot 2).

    The default model is an MLP with four hidden `Dense` layers of width `n`,
    each followed by `Dropout(0.01)`. Reads the globals `N`, `n` and `iters`.

    Returns `(m, trace, x, y)`: the trained model, the `History` of losses
    pushed by the callback, and the `1×N` inputs and targets.
    """
    function train(m = Chain(Dense(1, n, relu), Dropout(0.01),
                             Dense(n, n, relu), Dropout(0.01),
                             Dense(n, n, relu), Dropout(0.01),
                             Dense(n, n, relu), Dropout(0.01),
                             Dense(n, 1)))
        # Inputs as a 1×N row: one feature (matches Dense(1, n)), one column per sample.
        x = linspace(-10,10,N)'
        # sin(x)/x written via sinc(t) = sin(πt)/(πt). Away from zero this is the
        # same function, but at x == 0 (a grid point whenever N is odd) it yields
        # the limit 1 instead of 0/0 = NaN, which would poison the loss.
        y = sinc.(x ./ π)
        trace = History(Float64)
        # Batch size N equals the number of samples, so each batch is the whole
        # (shuffled) data set; ncycle repeats the batch view `iters` times.
        batcher = batchview(shuffleobs((x,y)), N)
        dataset = ncycle(batcher, iters)
        opt = ADAM(params(m), 0.01, decay = 0)
        # Mean squared error, normalised by the batch's own size rather than the
        # global N so it stays a mean for any batch size.
        loss(x,y) = sum((m(x).-y).^2)/length(y)
        fig = plot(layout=2, reuse=true)
        # Passed unthrottled to Flux.train!, so it runs after every step: record the
        # full-data loss, plot the current fit, then trim/fade older curves.
        # NOTE(review): redrawing on each of `iters` steps is slow; Flux.throttle
        # would help but would also change how many points end up in `trace`.
        function cb()
            push!(trace, loss(x,y).data[])
            plot!(vec(x), [vec(y) m(x).data'], ylims=[-0.3,1], c = [:blue :red], subplot=1)
            plot!(trace, subplot=2, c=:blue, yscale=:log10)
            update_plot(fig[1], max_history = 10, attribute=:linecolor)
            update_plot(fig[2], max_history=1)
            gui()
        end
        Flux.train!(loss, dataset, opt, cb=cb)

        return m, trace, x, y
    end
    # Train the default network and keep model, loss history and data as globals.
    (m, trace, x, y) = train()
    
    ##
    # Monte-Carlo dropout: nothing here switches the Dropout layers off, so —
    # assuming they stay active at prediction time in this Flux version (TODO
    # confirm) — repeated forward passes m(x) differ, and their spread is used
    # as an uncertainty estimate.
    mc = 1000   # number of stochastic forward passes
    # One column per pass (length(x) × mc); reduce(hcat, …) avoids splatting mc arrays.
    samples = reduce(hcat, [m(x).data[:] for i = 1:mc])
    # A few individual draws (mc ÷ 100 of them), plotted as thin lines.
    yh = samples[:, 1:mc÷100]
    # Mean and standard deviation across passes, taken from the same draws.
    yhm = vec(mean(samples, 2))
    yhs = std(samples, 2)

    closeall()
    plot(vec(x),vec(y),lab="y", linewidth=3)
    # Shaded band of ±2 standard deviations around the MC mean.
    plot!(vec(x),yhm,lab="yhm", ribbon=2yhs, linewidth=3, alpha=0.2)
    plot!(vec(x),yh,lab="yh")

    gui()
    ##