Skip to content
Snippets Groups Projects
Commit e5b5af41 authored by Fredrik Bagge Carlson's avatar Fredrik Bagge Carlson
Browse files

Merge branch 'master' of gitlab.control.lth.se:cont-frb/reinforcementlearning

parents b6f79058 dd85dd91
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ using OrdinaryDiffEq ...@@ -5,7 +5,7 @@ using OrdinaryDiffEq
using ValueHistories, IterTools, MLDataUtils, OrdinaryDiffEq, Parameters, InteractNext using ValueHistories, IterTools, MLDataUtils, OrdinaryDiffEq, Parameters, InteractNext
inoffice() = gethostname() ["billman", "Battostation"] inoffice() = gethostname() ["billman", "Battostation", "nichols"]
inoffice() && !isdefined(:time_derivative) && include(Pkg.dir("DynamicMovementPrimitives","src","two_link.jl")) inoffice() && !isdefined(:time_derivative) && include(Pkg.dir("DynamicMovementPrimitives","src","two_link.jl"))
if !isdefined(:AbstractEnsembleSystem) if !isdefined(:AbstractEnsembleSystem)
inoffice() || @everywhere include("/var/tmp/fredrikb/v0.6/DynamicMovementPrimitives/src/two_link.jl") inoffice() || @everywhere include("/var/tmp/fredrikb/v0.6/DynamicMovementPrimitives/src/two_link.jl")
......
using LinearAlgebra, Statistics, Random, Plots, Flux, DSP, Parameters, LTVModelsBase
using Flux: param, params
using Flux.Tracker
# Container for a random discrete-time linear system x⁺ = A x + B u,
# with keyword construction and derived defaults via Parameters.jl `@with_kw`.
@with_kw struct LinearSys
    A                              # state-transition matrix (nx × nx)
    B                              # input matrix (nx × nu)
    N = 1000                       # default trajectory length in time steps
    nx = size(A,1)                 # state dimension, derived from A
    nu = size(B,2)                 # input dimension, derived from B
    h = 0.02                       # sample time [s]
    σ0 = 0                         # measurement-noise standard deviation
    sind = 1:nx                    # indices of x in a stacked [x; u; x⁺] vector
    uind = nx+1:(nx+nu)            # indices of u in the stacked vector
    s1ind = (nx+nu+1):(nx+nu+nx)   # indices of the successor state x⁺
end
"""
    LinearSys(seed; nx = 10, nu = nx, h = 0.02, kwargs...)

Construct a random, marginally stable, discrete-time `LinearSys` with state
dimension `nx`, input dimension `nu` and sample time `h`. The RNG is seeded
with `seed` for reproducibility; remaining `kwargs` are forwarded to the
keyword constructor (e.g. `N`, `σ0`).
"""
function LinearSys(seed; nx = 10, nu = nx, h = 0.02, kwargs...)
    Random.seed!(seed)
    skew = randn(nx, nx)
    skew = skew - skew'              # skew-symmetric ⇒ purely imaginary eigenvalues
    Ad = exp(h * (skew - h * I))     # shift slightly left for mild stability, then discretize
    Bd = h * randn(nx, nu)
    return LinearSys(; A = Ad, B = Bd, nx = nx, nu = nu, h = h, kwargs...)
end
"""
    generate_data(sys::LinearSys, seed, validation=false)

Simulate `sys` for `N` steps driven by a low-pass-filtered random input.
Returns `(x, u)` where `x` is `nx × (N+1)` states and `u` is `nu × N` inputs.
Measurement noise with std `σ0` is added to `x` unless `validation` is true.
"""
function generate_data(sys::LinearSys, seed, validation=false)
    @unpack A,B,N, nx, nu, h, σ0 = sys
    Random.seed!(seed)
    # 5-tap moving-average filter smooths the white-noise input
    u = DSP.filt(ones(5),[5], 10randn(N+2,nu))'
    x0 = randn(nx)
    x = zeros(nx,N+1)
    x[:,1] = x0
    # BUG FIX: loop previously ran 1:N-1, leaving the allocated last column
    # x[:,N+1] permanently zero; u has N+2 columns so u[:,N] is available.
    for i = 1:N
        x[:,i+1] = A*x[:,i] + B*u[:,i]
    end
    validation || (x .+= σ0 * randn(size(x)))
    u = u[:,1:N]
    @assert all(isfinite, u)
    x,u
end
"""
    true_jacobian(sys::LinearSys, x, u)

Return the exact Jacobian `[A B]` of the linear dynamics; it is constant,
independent of the linearization point `(x, u)`.
"""
function true_jacobian(sys::LinearSys, x, u)
    return hcat(sys.A, sys.B)
end
# --- Experiment setup: random linear system, data, and network ----------------
wdecay = 0       # weight decay (not referenced below; kept for experimentation)
stepsize = 0.02  # step size (not referenced below; optimizer uses its own rate)
sys = LinearSys(1, nx=5, N=400, h=0.02, σ0 = 0.01)  # 5-state system, 400 steps
true_jacobian(x,u) = true_jacobian(sys,x,u)  # convenience method bound to `sys`
const nu = sys.nu
const nx = sys.nx
# Three validation trajectories (validation=true ⇒ no measurement noise added)
trajs = [Trajectory(generate_data(sys, i, true)...) for i = 1:3]
const t = trajs[1]  # NOTE: shadows any `t` time vector; used as the training trajectory
const y = [t.x[:,i] for i = 1:length(t)]  # states as a vector of column vectors
const u = [t.u[:,i] for i = 1:length(t)]  # inputs as a vector of column vectors
# Repeat the final sample so segments indexed with i2inds1 (length+1) stay in bounds
push!(y,y[end])
push!(u,u[end])
const num_shooting = 4                            # number of multiple-shooting segments
const length_shooting = sys.N ÷ num_shooting      # samples per segment
@assert length_shooting*num_shooting == sys.N
num_params = 100  # hidden-layer width
# Residual dynamics network: maps [x; u] (nx+nu) to the next-state correction (nx)
const f = Chain(Dense(nx+nu, num_params, swish), Dense(num_params, num_params, swish), Dense(num_params, num_params, swish), Dense(num_params, num_params, swish), Dense(num_params, nx))
# One-step transition model: network output plus a 0.5·x skip connection.
# Returns the pair (new_state, output) as expected by Flux.Recur.
function model(x,u)
    nxt = f(vcat(x, u)) .+ 0.5 .* x
    return nxt, nxt
end
# Lagrange multipliers for the num_shooting-1 continuity constraints between segments
const λ = [zeros(nx) for _ = 1:num_shooting-1]
# Trainable initial state for each shooting segment, seeded from the measured trajectory
const initial_x = [param(t.x[:,(i-1)*length_shooting+1]) for i = 1:num_shooting]
# Roll the recurrent model out over the input sequence `u` from initial state `x0`,
# returning the sequence of simulated states.
# NOTE(review): Flux.Recur is called with three arguments (cell, state, init output);
# this matches the old Flux tracker-era API — confirm against the Flux version in use.
function simulate(x0, u)
    rnn = Flux.Recur(model, x0, u[1])
    return map(rnn, u)
end
# Index ranges for multiple-shooting segment `i` (1-based).
# Generalized: `len` defaults to the global `length_shooting` (backward compatible)
# but can be passed explicitly, decoupling these helpers from global state.
i2inds(i, len = length_shooting) = (i-1)*len+1:i*len        # segment i, exclusive of next segment
i2inds1(i, len = length_shooting) = (i-1)*len+1:i*len+1     # segment i plus first point of segment i+1
# Simulate every shooting segment independently, each starting from its own
# trainable initial state; returns a vector of simulated state sequences.
function simulate_shooting(u)
    return [simulate(initial_x[k], u[i2inds1(k)]) for k = 1:num_shooting]
end
# Continuity defects between consecutive shooting segments:
# the trainable start of segment i+1 minus the simulated end of segment i.
function constraints(s = simulate_shooting(u))
    return map(i -> initial_x[i+1] - s[i][end], 1:num_shooting-1)
end
# Augmented-Lagrangian objective for the multiple-shooting fit.
# Returns (l, c, cv): scalar loss, continuity-constraint vector, and the total
# squared constraint violation used to decide multiplier updates in `train`.
function augmented_lagrangean(μ)
    s = simulate_shooting(u)
    c = constraints(s) # nx × num_shooting
    # Multiplier term; the comprehension-local (λ,c) shadow the globals of the same names
    l = -sum(sum(λ.*c) for (λ,c) in zip(λ,c)) # Lagrange multiplier
    l += 10sum(norm, params(f)) # norm penalty on all network parameters (regularization)
    cv = sum(sum.(abs2,c)) # Augmentation
    l += μ/2*cv # quadratic penalty scaled by the current penalty parameter
    for (i,sim) in enumerate(s)
        l += sum(sum(x->abs2.(x),sim.-y[i2inds1(i)]))/length_shooting # cost
    end
    l,c,cv
end
μ = 1.  # initial penalty parameter of the augmented Lagrangian
augmented_lagrangean(μ)  # warm-up evaluation (compiles the loss; sanity check)
# opt = Momentum([params(f); initial_x], 0.000001, decay = 0.001, ρ = 0.8)
# Optimize both the network parameters and the trainable segment initial states
opt = ADAM([params(f); initial_x], 0.001, decay = 0.001, ϵ = 0.1)
# Method-of-multipliers training loop: minimize the augmented Lagrangian; when
# the constraint violation `cv` drops below `cvtol`, update the multipliers λ
# and halve the tolerance, otherwise grow the penalty μ. Returns the final μ.
function train(opt,μ)
    cvtol = 5.  # initial acceptable constraint violation
    for iter = 1:1000
        loss,c,cv = augmented_lagrangean(μ)
        back!(loss)  # reverse-mode pass through the Flux tracker
        opt()        # apply one optimizer step (old Flux callable-optimizer API)
        if iter % 10 == 0
            # Progress print and a plot of the full (single-shot) simulation vs. data
            println(iter, " ", Flux.data(loss))
            plot(hcat(Flux.data.(simulate(t.x[:,1], u))...)', reuse=true)
            plot!(t.x') |> gui
        end
        if Flux.data(cv) < cvtol
            println("Updating λ")
            cvtol /= 2  # demand tighter constraint satisfaction next time
            for i in eachindex(λ)
                λ[i] .-= μ.*Flux.data(c[i])  # first-order multiplier update
            end
        else
            μ *= 1.005  # constraints not yet satisfied: slowly increase the penalty
        end
    end
    μ
end
# Run the training loop; keep the final penalty parameter for possible continuation
μ = train(opt,μ)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment