diff --git a/flux/nn_prior/linear_sys.jl b/flux/nn_prior/linear_sys.jl index b8fda731f813f2f64dc5ed421a689d058ef83ec8..367a3a97c75155ce0451a992984043ef9cd061fa 100644 --- a/flux/nn_prior/linear_sys.jl +++ b/flux/nn_prior/linear_sys.jl @@ -2,36 +2,27 @@ import Parameters -struct LinearSys <: AbstractSystem +@with_kw struct LinearSys <: AbstractSystem A B - N - n - ns - h - σ0 - sind - uind - s1ind + N = 1000 + n = 10 + ns = n + h = 0.02 + σ0 = 0 + sind = 1:ns + uind = ns+1:(ns+n) + s1ind = (ns+n+1):(ns+n+ns) end -function LinearSys(seed; - N = 1000, - n = 10, - ns = n, - h = 0.02, - σ0 = 0, - sind = 1:ns, - uind = ns+1:(ns+n), - s1ind = (ns+n+1):(ns+n+ns)) - +function LinearSys(seed; kwargs...) srand(seed) A = randn(n,n) A = A-A' # skew-symmetric = pure imaginary eigenvalues A = A - h*I # Make 'slightly' stable A = expm(h*A) # discrete time B = h*randn(n,n) - LinearSys(A,B,N, n, ns, h, σ0, sind, uind, s1ind) + LinearSys(A=A, B=B, kwargs...) end function generate_data(sys::LinearSys, systype, seed, validation=false) diff --git a/gridworld.jl b/gridworld.jl index d4344259b36900c41f5225483c59f3ca41b808ea..243e1c4c159112db0bf1b5a9d177370d3e29a03b 100644 --- a/gridworld.jl +++ b/gridworld.jl @@ -1,12 +1,13 @@ using Plots -function gridworld() - grid_size = (10,10) - iters = 10 - A = [1,2,3,4] - running_cost = 0.1 - grid = zeros(grid_size) - Q = 0*ones(length(A),grid_size...) - goal = (5,5) +function gridworld(; + grid_size = (10,10), + iters = 100, + A = [1,2,3,4], + running_cost = 0.1, + grid = zeros(grid_size), + Q = 0*ones(length(A),grid_size...), + goal = (5,5)) + grid[goal...] = 5 function transition(x,y,a) @@ -28,7 +29,7 @@ function gridworld() for y = 1:grid_size[2] for x = 1:grid_size[1] if (x,y) == goal - Q[:,x,y] = grid[goal...] + Q[:,x,y] .= grid[goal...] continue end for a = A @@ -38,17 +39,17 @@ function gridworld() end end end - heatmap(squeeze(maximum(Q,1),1)); gui() + heatmap(dropdims(maximum(Q,dims=1),dims=1)); gui() # println(i) # sleep(0.01) end println("done") Q end -@time Q = gridworld() +@time Q = gridworld(); -for y = 1:grid_size[2] - for x = 1:grid_size[1] +for y = 1:10#grid_size[2] + for x = 1:10#grid_size[1] a = indmax(Q[:,x,y]) sym = a == 1 ? :utriangle : a == 2 ? :rtriangle : a == 3 ? :dtriangle : :ltriangle scatter!([x],[y],m=(sym,5), lab="", legend=false)