diff --git a/flux/nn_prior/linear_sys.jl b/flux/nn_prior/linear_sys.jl
index b8fda731f813f2f64dc5ed421a689d058ef83ec8..367a3a97c75155ce0451a992984043ef9cd061fa 100644
--- a/flux/nn_prior/linear_sys.jl
+++ b/flux/nn_prior/linear_sys.jl
@@ -2,36 +2,27 @@
 
 import Parameters
 
-struct LinearSys <: AbstractSystem
+@with_kw struct LinearSys <: AbstractSystem
     A
     B
-    N
-    n
-    ns
-    h
-    σ0
-    sind
-    uind
-    s1ind
+    N = 1000
+    n = 10
+    ns = n
+    h = 0.02
+    σ0 = 0
+    sind = 1:ns
+    uind = ns+1:(ns+n)
+    s1ind = (ns+n+1):(ns+n+ns)
 end
 
-function LinearSys(seed;
-    N     = 1000,
-    n     = 10,
-    ns    = n,
-    h     = 0.02,
-    σ0    = 0,
-    sind  = 1:ns,
-    uind  = ns+1:(ns+n),
-    s1ind = (ns+n+1):(ns+n+ns))
-
+function LinearSys(seed; kwargs...)
     srand(seed)
     A = randn(n,n)
     A = A-A'        # skew-symmetric = pure imaginary eigenvalues
     A = A - h*I     # Make 'slightly' stable
     A = expm(h*A)   # discrete time
     B = h*randn(n,n)
-    LinearSys(A,B,N, n, ns, h, σ0, sind, uind, s1ind)
+    LinearSys(A=A, B=B, kwargs...)
 end
 
 function generate_data(sys::LinearSys, systype, seed, validation=false)
diff --git a/gridworld.jl b/gridworld.jl
index d4344259b36900c41f5225483c59f3ca41b808ea..243e1c4c159112db0bf1b5a9d177370d3e29a03b 100644
--- a/gridworld.jl
+++ b/gridworld.jl
@@ -1,12 +1,13 @@
 using Plots
-function gridworld()
-    grid_size = (10,10)
-    iters = 10
-    A = [1,2,3,4]
-    running_cost = 0.1
-    grid = zeros(grid_size)
-    Q = 0*ones(length(A),grid_size...)
-    goal = (5,5)
+function gridworld(;
+    grid_size     = (10,10),
+    iters         = 100,
+    A             = [1,2,3,4],
+    running_cost  = 0.1,
+    grid          = zeros(grid_size),
+    Q             = 0*ones(length(A),grid_size...),
+    goal          = (5,5))
+    
     grid[goal...] = 5
 
     function transition(x,y,a)
@@ -28,7 +29,7 @@ function gridworld()
         for y = 1:grid_size[2]
             for x = 1:grid_size[1]
                 if (x,y) == goal
-                    Q[:,x,y] = grid[goal...]
+                    Q[:,x,y] .= grid[goal...]
                     continue
                 end
                 for a = A
@@ -38,17 +39,17 @@ function gridworld()
                 end
             end
         end
-        heatmap(squeeze(maximum(Q,1),1)); gui()
+        heatmap(dropdims(maximum(Q,dims=1),dims=1)); gui()
         # println(i)
         # sleep(0.01)
     end
     println("done")
     Q
 end
-@time Q = gridworld()
+@time Q = gridworld();
 
-for y = 1:grid_size[2]
-    for x = 1:grid_size[1]
+for y = 1:10#grid_size[2]
+    for x = 1:10#grid_size[1]
         a = indmax(Q[:,x,y])
         sym = a == 1 ? :utriangle : a == 2 ? :rtriangle : a == 3 ? :dtriangle : :ltriangle
         scatter!([x],[y],m=(sym,5), lab="", legend=false)