# learn_variance.jl
# Fit a straight line and its noise level by maximum likelihood:
# first with ReverseDiff and a hand-rolled ADAM, then with Flux.
"""
State for the ADAM optimizer (Kingma & Ba, 2015). Calling the object with a
gradient and an iteration counter performs one update step on Θ in place.
"""
mutable struct ADAMOptimizer{T, VecType <: AbstractArray}
    Θ::VecType  # parameter vector being optimized
    α::T        # step size (learning rate)
    β1::T       # decay rate of the first-moment estimate
    β2::T       # decay rate of the second-moment estimate
    ɛ::T        # small constant preventing division by zero
    m::VecType  # first-moment (mean of gradients) accumulator
    v::VecType  # second-moment (mean of squared gradients) accumulator
end
# Outer constructor. T cannot be deduced from keyword arguments alone, so it is
# taken from α; the moment accumulators default to zero arrays shaped like Θ.
ADAMOptimizer(Θ::VecType; α = 0.005, β1 = 0.9, β2 = 0.999, ɛ = 1e-8, m = zeros(Θ), v = zeros(Θ)) where {VecType <: AbstractArray} =
    ADAMOptimizer{typeof(α), VecType}(Θ, α, β1, β2, ɛ, m, v)
# One ADAM step: update the biased moment estimates m and v from the gradient
# g, correct the bias for iteration t, and step Θ in place.
function (a::ADAMOptimizer)(g, t::Integer)
    c1  = 1 - a.β1          # weight of the new gradient in m
    c2  = 1 - a.β2          # weight of the new squared gradient in v
    bc1 = 1/(1 - a.β1 ^ t)  # bias correction for m
    bc2 = 1/(1 - a.β2 ^ t)  # bias correction for v
    α,β1,β2,ɛ,m,v,Θ = a.α,a.β1,a.β2,a.ɛ,a.m,a.v,a.Θ
    Base.Threads.@threads for i = 1:length(g)
        @inbounds m[i] = β1 * m[i] + c1 * g[i]
        @inbounds v[i] = β2 * v[i] + c2 * g[i]^2
        @inbounds Θ[i] -= α * m[i] * bc1 / (sqrt(v[i] * bc2) + ɛ)
    end
    Θ
end
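## Quick self-test of the optimizer (an illustrative addition, not part of the
## original script): minimize f(Θ) = ‖Θ - c‖² using its analytic gradient
## 2(Θ - c). The dimension, step size and iteration count are arbitrary choices.
Θ0   = randn(5)
c    = ones(5)
demo = ADAMOptimizer(Θ0; α = 0.05)
for t = 1:2000
    demo(2*(Θ0 - c), t)      # one ADAM step; Θ0 is updated in place
end
@show maximum(abs.(Θ0 - c))  # should be close to zero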
using ReverseDiff: GradientTape, gradient!
using DiffBase   # provides the GradientResult containers used below
using Juno       # provides the @progress macro
## Generate synthetic data: a straight line plus Gaussian noise of std σ_true
N = 1000
x_true = [1., 2.]
σ_true = 0.1
A = [linspace(0,1,N) ones(N)]  # design matrix with slope and intercept columns
y_true = A*x_true
y = y_true + σ_true*randn(N)
## Loss: negative log-likelihood of the data, where x[1:2] are the regression
## weights and x[3] is the noise variance σ², learned jointly with the weights
predict(x) = A*x[1:2]
lossfun(x) = sum((predict(x).-y).^2)/(2x[3]) + N/2*log(x[3])
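# Sanity check (an illustrative addition): for fixed weights the loss is
# minimized over the variance at x[3] = SSE/N, the mean squared residual, since
# d/dv [SSE/(2v) + N/2*log(v)] = -SSE/(2v^2) + N/(2v) = 0  ⇒  v = SSE/N.
v_opt = sum(abs2, y - y_true)/N
@assert lossfun([x_true; v_opt]) ≤ lossfun([x_true; 1.0])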
x = [randn(size(x_true)); 1.]  # initial guess: random weights, unit variance
inputs = (x,)
loss_tape = GradientTape(lossfun, inputs)        # record the computation once
results = similar.(inputs)
all_results = DiffBase.GradientResult.(results)  # hold value and gradient together
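# Optional check (an illustrative addition): compare one coordinate of the
# taped gradient against a central finite difference; δ is an arbitrary choice.
gradient!(all_results, loss_tape, inputs)
δ = 1e-6
fd = (lossfun(x .+ [0, 0, δ]) - lossfun(x .- [0, 0, δ]))/(2δ)
@show fd, all_results[1].derivs[1][3]  # the two numbers should agree closely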
opt = ADAMOptimizer(x; α = 0.002, β1 = 0.9, β2 = 0.999, ɛ = 1e-8)  # opt.Θ aliases x
@progress for iter = 1:1000
    gradient!(all_results, loss_tape, inputs)  # evaluate loss value and gradient
    iter % 100 == 0 && println("Cost: ", all_results[1].value)
    opt(all_results[1].derivs[1], iter)        # ADAM step; updates x in place
end
@show x
@show x_true
@show norm(x[1:2]-x_true)/norm(x_true)  # relative error of the weights
@show std(y-A*x[1:2])                   # empirical residual std
@show √(x[3])                           # learned noise std
## Part 2: the same regression with Flux on the mxnet backend; here the second
## network output models the precision λ = 1/σ² instead of the variance
N = 1000
x_true = [1.5, 2.5]
σ_true = 0.15
A = linspace(0,1,N)
y_true = A*x_true[1] .+ x_true[2]
y = y_true + σ_true*randn(N)
## Model: a single Affine layer mapping the scalar input to [mean, precision]
using Flux
m = Chain(Input(1), Affine(2))
data = zip(A,y)  # (input, target) pairs
# Negative log-likelihood with ŷ[1] the predicted mean and ŷ[2] the precision;
# compare lossfun above, which parameterized the variance instead
lossfun6(ŷ,y) = mse(ŷ[1],y)*(2ŷ[2]) - N/2*log(ŷ[2])
# Hand-written backward pass for lossfun6 (the backend cannot differentiate the
# custom loss itself): column 1 is the gradient wrt ŷ[1], column 2 wrt ŷ[2]
function Flux.back!(::typeof(lossfun6), Δ, ŷ, y)
  [(Flux.back!(Flux.mse, Δ, Float64[ŷ[1]], y).*(2ŷ[2])) (Flux.mse(ŷ[1],y).*2 .- Δ*N/2/ŷ[2])]
end
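# Derivation of those two columns (an illustrative note): with
# L(ŷ) = 2*mse(ŷ[1],y)*ŷ[2] - N/2*log(ŷ[2]),
#   ∂L/∂ŷ[1] = 2ŷ[2] * ∂mse/∂ŷ[1]           (first column above)
#   ∂L/∂ŷ[2] = 2*mse(ŷ[1],y) - N/(2ŷ[2])    (second column above)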
model = mxnet(m)  # compile the network for the MXNet backend
Flux.train!(model, data, loss=lossfun6, epoch=20, η = 0.0001)
x = [m[2].W.x; m[2].b.x]     # rows: weight and bias; columns: mean and precision outputs
x[:,2] = sqrt.(1 ./ x[:,2])  # convert the precision column into a noise std σ
@show x
@show x_true
@show norm(x[1:2]-x_true)/norm(x_true)  # relative error of the mean parameters
@show std(y - (A*x[1] .+ x[2]))         # residual std of the fitted line
@show x[:,2]                            # learned noise std from the precision column