# gridworld.jl
# Author: Fredrik Bagge Carlson
using Plots
"""
    gridworld(; grid_size=(10, 10), iters=10, running_cost=0.1, goal=(5, 5),
              goal_reward=5.0, plot_progress=true)

Run `iters` sweeps of undiscounted value iteration on a deterministic 2-D grid world
and return the action-value array `Q` of size `(4, grid_size...)`.

`Q[a, x, y]` is the value of taking action `a` in cell `(x, y)`. The actions are
`1` = up (`y + 1`), `2` = right (`x + 1`), `3` = down (`y - 1`) and `4` = left (`x - 1`).
A move that would leave the grid keeps the agent in place. Every step taken from a
non-goal cell costs `running_cost`. All actions in the `goal` cell are pinned to
`goal_reward`.

`Q` is updated in place. Cells visited earlier in a sweep (`y` outer, `x` inner)
therefore already see the values updated during that same sweep.

If `plot_progress` is true, a heatmap of the state values `maximum(Q[:, x, y])` is
drawn after every sweep, with `x` on the horizontal axis. This requires Plots to be
loaded. Prints `"done"` when finished.

# Throws
- `ArgumentError` if `goal` lies outside `grid_size`.
"""
function gridworld(; grid_size=(10, 10), iters=10, running_cost=0.1, goal=(5, 5),
                   goal_reward=5.0, plot_progress=true)
    all(1 .<= goal .<= grid_size) ||
        throw(ArgumentError("goal $goal lies outside the grid of size $grid_size"))

    actions = 1:4
    Q = zeros(length(actions), grid_size...)

    # Deterministic move; bumping into a wall leaves the agent where it is.
    function transition(x, y, a)
        if a == 1
            y = min(y + 1, grid_size[2])
        elseif a == 2
            x = min(x + 1, grid_size[1])
        elseif a == 3
            y = max(y - 1, 1)
        else
            x = max(x - 1, 1)
        end
        return x, y
    end

    for _ in 1:iters
        for y in 1:grid_size[2], x in 1:grid_size[1]
            if (x, y) == goal
                # Absorbing goal: its value is fixed rather than backed up.
                Q[:, x, y] .= goal_reward
                continue
            end
            for a in actions
                xn, yn = transition(x, y, a)
                # The reward is state-independent: a constant cost per step.
                Q[a, x, y] = -running_cost + maximum(view(Q, :, xn, yn))
            end
        end
        if plot_progress
            V = dropdims(maximum(Q, dims=1), dims=1)  # V[x, y]
            # heatmap(z) draws z[i, j] at (x = j, y = i); transpose so x is horizontal.
            heatmap(permutedims(V)); gui()
        end
    end
    println("done")
    return Q
end
# Solve the default grid world, then overlay each cell's greedy action as a
# triangle marker on the current (heatmap) plot.
@time Q = gridworld()
# Marker per action index: 1 = up, 2 = right, 3 = down, 4 = left.
let markers = (:utriangle, :rtriangle, :dtriangle, :ltriangle)
    # Iterate over Q's own axes; `grid_size` is local to `gridworld` and not visible here.
    for y in axes(Q, 3), x in axes(Q, 2)
        # Greedy action in cell (x, y); ties resolve to the lowest action index.
        a = argmax(view(Q, :, x, y))
        scatter!([x], [y], m=(markers[a], 5), lab="", legend=false)
    end
end
gui()