Skip to content
Snippets Groups Projects
Select Git revision
  • f75f403ec01f73479ac451919fe61e17f2b6cb66
  • master default protected
  • sub1
3 results

gridworld.jl

Blame
  • gridworld.jl 1.33 KiB
    using Plots
    """
        gridworld(; grid_size=(10, 10), iters=10, running_cost=0.1, goal=(5, 5),
                  goal_value=5.0, animate=true)

    Run `iters` sweeps of in-place value iteration on a deterministic grid world and
    return the action-value array `Q`.

    Action 1 moves to `y + 1`, 2 to `x + 1`, 3 to `y - 1` and 4 to `x - 1`; a move that
    would leave the grid keeps the agent in place. Every move costs `running_cost`, and
    all actions at `goal` are pinned to `goal_value`.

    # Keywords
    - `grid_size`: `(nx, ny)`, number of cells along x and y.
    - `iters`: number of full sweeps over the grid.
    - `running_cost`: cost subtracted for each move.
    - `goal`: `(x, y)` cell whose action values are fixed to `goal_value`.
    - `goal_value`: value assigned to every action at `goal`.
    - `animate`: when `true`, draw a heatmap of `maximum(Q[:, x, y])` after each
      sweep (needs Plots loaded; pass `false` to run without plotting).

    # Returns
    `Q` of size `(4, nx, ny)`; `Q[a, x, y]` is the value of taking action `a` in
    cell `(x, y)`.

    # Throws
    `ArgumentError` if `goal` lies outside the grid.
    """
    function gridworld(; grid_size=(10, 10), iters=10, running_cost=0.1, goal=(5, 5),
                       goal_value=5.0, animate=true)
        nx, ny = grid_size
        all(1 <= goal[d] <= grid_size[d] for d in 1:2) ||
            throw(ArgumentError("goal $goal lies outside a grid of size $grid_size"))

        actions = 1:4
        Q = zeros(length(actions), nx, ny)

        # Deterministic move; bumping into a wall leaves the agent where it is.
        function transition(x, y, a)
            if a == 1
                y = min(y + 1, ny)
            elseif a == 2
                x = min(x + 1, nx)
            elseif a == 3
                y = max(y - 1, 1)
            else
                x = max(x - 1, 1)
            end
            return x, y
        end

        for sweep in 1:iters
            for y in 1:ny, x in 1:nx
                if (x, y) == goal
                    # `.=` broadcasts the scalar; plain `=` onto a slice errors on
                    # Julia >= 0.7.
                    Q[:, x, y] .= goal_value
                    continue
                end
                for a in actions
                    xn, yn = transition(x, y, a)
                    # Q is updated in place, so cells visited later in this sweep
                    # already see the fresh values.
                    Q[a, x, y] = -running_cost + maximum(@view Q[:, xn, yn])
                end
            end
            if animate
                # Plots.jl draws matrix rows along the vertical axis, so build V[y, x]
                # to keep the heatmap aligned with (x, y) cell coordinates.
                V = [maximum(@view Q[:, x, y]) for y in 1:ny, x in 1:nx]
                heatmap(V); gui()
            end
        end
        println("done")
        return Q
    end
    @time Q = gridworld()

    """
        plot_policy!(Q)

    Overlay the greedy policy encoded in `Q` on the current plot: one triangle per cell
    `(x, y)`, pointing in the direction of the best action in `Q[:, x, y]`
    (1 = up/`y + 1`, 2 = right/`x + 1`, 3 = down/`y - 1`, 4 = left/`x - 1`).
    """
    function plot_policy!(Q)
        markers = (:utriangle, :rtriangle, :dtriangle, :ltriangle)
        # Take the grid size from Q itself: `grid_size` is local to `gridworld` and is
        # not defined at top level.
        nx, ny = size(Q, 2), size(Q, 3)
        for y in 1:ny, x in 1:nx
            # Index of the best action; the first one wins on ties (like `indmax`).
            a = findmax(Q[:, x, y])[2]
            scatter!([x], [y], m=(markers[a], 5), lab="", legend=false)
        end
        return nothing
    end

    plot_policy!(Q)
    gui()