Skip to content
Snippets Groups Projects
Commit e1d13d3f authored by Fredrik Bagge Carlson's avatar Fredrik Bagge Carlson
Browse files

add multiple shooting

parent 8178dd53
Branches
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ using OrdinaryDiffEq ...@@ -5,7 +5,7 @@ using OrdinaryDiffEq
using ValueHistories, IterTools, MLDataUtils, OrdinaryDiffEq, Parameters, InteractNext using ValueHistories, IterTools, MLDataUtils, OrdinaryDiffEq, Parameters, InteractNext
inoffice() = gethostname() ["billman", "Battostation"] inoffice() = gethostname() ["billman", "Battostation", "nichols"]
inoffice() && !isdefined(:time_derivative) && include(Pkg.dir("DynamicMovementPrimitives","src","two_link.jl")) inoffice() && !isdefined(:time_derivative) && include(Pkg.dir("DynamicMovementPrimitives","src","two_link.jl"))
if !isdefined(:AbstractEnsembleSystem) if !isdefined(:AbstractEnsembleSystem)
inoffice() || @everywhere include("/var/tmp/fredrikb/v0.6/DynamicMovementPrimitives/src/two_link.jl") inoffice() || @everywhere include("/var/tmp/fredrikb/v0.6/DynamicMovementPrimitives/src/two_link.jl")
......
using LinearAlgebra, Statistics, Random, Plots, Flux, DSP, Parameters, LTVModelsBase
using Flux: param, params
using Flux.Tracker
@with_kw struct LinearSys
A
B
N = 1000
nx = size(A,1)
nu = size(B,2)
h = 0.02
σ0 = 0
sind = 1:nx
uind = nx+1:(nx+nu)
s1ind = (nx+nu+1):(nx+nu+nx)
end
function LinearSys(seed; nx = 10, nu = nx, h=0.02, kwargs...)
Random.seed!(seed)
A = randn(nx,nx)
A = A-A' # skew-symmetric = pure imaginary eigenvalues
A = A - h*I # Make 'slightly' stable
A = exp(h*A) # discrete time
B = h*randn(nx,nu)
LinearSys(;A=A, B=B, nx=nx, nu=nu, h=h, kwargs...)
end
function generate_data(sys::LinearSys, seed, validation=false)
@unpack A,B,N, nx, nu, h, σ0 = sys
Random.seed!(seed)
u = DSP.filt(ones(5),[5], 10randn(N+2,nu))'
t = h:h:N*h+h
x0 = randn(nx)
x = zeros(nx,N+1)
x[:,1] = x0
for i = 1:N-1
x[:,i+1] = A*x[:,i] + B*u[:,i]
end
validation || (x .+= σ0 * randn(size(x)))
u = u[:,1:N]
@assert all(isfinite, u)
x,u
end
function true_jacobian(sys::LinearSys, x, u)
[sys.A sys.B]
end
wdecay = 0
stepsize = 0.02
sys = LinearSys(1, nx=5, N=400, h=0.02, σ0 = 0.01)
true_jacobian(x,u) = true_jacobian(sys,x,u)
const nu = sys.nu
const nx = sys.nx
trajs = [Trajectory(generate_data(sys, i, true)...) for i = 1:3]
const t = trajs[1]
const y = [t.x[:,i] for i = 1:length(t)]
const u = [t.u[:,i] for i = 1:length(t)]
push!(y,y[end])
push!(u,u[end])
const num_shooting = 4
const length_shooting = sys.N ÷ num_shooting
@assert length_shooting*num_shooting == sys.N
num_params = 100
const f = Chain(Dense(nx+nu, num_params, swish), Dense(num_params, num_params, swish), Dense(num_params, num_params, swish), Dense(num_params, num_params, swish), Dense(num_params, nx))
function model(x,u)
xp = f([x;u]) .+ 0.5x
xp,xp
end
const λ = [zeros(nx) for _ = 1:num_shooting-1]
const initial_x = [param(t.x[:,(i-1)*length_shooting+1]) for i = 1:num_shooting]
function simulate(x0, u)
rnn = Flux.Recur(model,x0,u[1])
rnn.(u)
end
i2inds(i) = (i-1)*length_shooting+1:i*length_shooting
i2inds1(i) = (i-1)*length_shooting+1:i*length_shooting+1
function simulate_shooting(u)
map(1:num_shooting) do i
simulate(initial_x[i],u[i2inds1(i)])
end
end
function constraints(s=simulate_shooting(u))
[initial_x[i+1] - s[i][end] for i = 1:num_shooting-1]
end
function augmented_lagrangean(μ)
s = simulate_shooting(u)
c = constraints(s) # nx × num_shooting
l = -sum(sum(λ.*c) for (λ,c) in zip(λ,c)) # Lagrange multiplier
l += 10sum(norm, params(f))
cv = sum(sum.(abs2,c)) # Augmentation
l += μ/2*cv
for (i,sim) in enumerate(s)
l += sum(sum(x->abs2.(x),sim.-y[i2inds1(i)]))/length_shooting # cost
end
l,c,cv
end
μ = 1.
augmented_lagrangean(μ)
# opt = Momentum([params(f); initial_x], 0.000001, decay = 0.001, ρ = 0.8)
opt = ADAM([params(f); initial_x], 0.001, decay = 0.001, ϵ = 0.1)
function train(opt,μ)
cvtol = 5.
for iter = 1:1000
loss,c,cv = augmented_lagrangean(μ)
back!(loss)
opt()
if iter % 10 == 0
println(iter, " ", Flux.data(loss))
plot(hcat(Flux.data.(simulate(t.x[:,1], u))...)', reuse=true)
plot!(t.x') |> gui
end
if Flux.data(cv) < cvtol
println("Updating λ")
cvtol /= 2
for i in eachindex(λ)
λ[i] .-= μ.*Flux.data(c[i])
end
else
μ *= 1.005
end
end
μ
end
μ = train(opt,μ)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment