shared_layers.jl
    using Plots
    gr()
    using DeterministicPolicyGradient
    using OpenAIGym
    using TensorFlow
    using ValueHistories
    issmall(x) = abs(x) < 1e-10
    r3(x) = round(x,3)
    default(size=(1500,1000), show=false, colorbar=false)
    
    ## Setup environment============================================================
    const env = OpenAIGym.GymEnv("MountainCarContinuous-v0")
    const T     = 500
    const time  = 1:T
    
    # Setup DPG ====================================================================
    const num_actions = 1
    const num_states = length(env.state)
    typealias ACtype Float32
    
    
    function get_apply_policy_gradient(μvars, q, 𝛍, a_)
      ∇aQ         = TensorFlow.gradients(q,[a_])[1] # batch_size x num_actions
      ∇μ          = TensorFlow.gradients(𝛍,μvars) # num_tensors (julia) (original tensor size)
      policy_grad = [reduce_sum(∇aQ'[i] .* reshape(g,(-1,1)),reduction_indices=[2]) for g in ∇μ, i in 1:num_actions] # Every element of the julia array is length_param x batch_size
      # policy_grad is now length(μvars)xnum_actions
      function sizetuple(x)
        dims = get_shape(x).dims
        ([dims[k].value for k = 1:length(dims)]...)
      end
      for i in 1:length(μvars), j in 1:num_actions
        policy_grad[i,j] = reshape(policy_grad[i,j], sizetuple(μvars[i]))
      end
      policy_grad = policy_grad[:,1] # + policy_grad[:,2] + policy_grad[:,3] # reducedim does not work
      grads_and_vars = zip(policy_grad, μvars) |> collect
      apply_policy_gradient = train.apply_gradients(μoptimizer,grads_and_vars)
    end
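    # The function above assembles the deterministic policy gradient
    #   ∇_Θ J ∝ Σᵢ ∇_Θ μ_Θ(sᵢ) ⋅ ∇_a Q(sᵢ,a)|_{a=μ_Θ(sᵢ)}
    # by weighting each parameter gradient of 𝛍 with ∇aQ and handing the result
    # to train.apply_gradients. Note that μoptimizer is a global that is only
    # created further down, so this function must not be called before the
    # optimizer exists.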
    
    σβ                      = 0.02 # 0.01
    αw                      = 0.0001 # 0.02
    αΘ                      = 0.00001 #0.005
    
    # I changed to a squared-exponential activation function in the last layer to keep the output bounded
    ## Initialize solver options ===================================================
    eval_interval           = 10
    dpgopts = DPGopts(num_actions,
    γ                       = 0.999,
    τ                       = 0.02,
    iters                   = 1000,
    stepmod_interval        = 10,
    stepmod_factor          = 0.999,
    hold_actor              = 300, # TODO: actor update held (deactivated) for the first iterations
    experience_replay       = 40T,
    experience_ratio        = 10,
    eval_interval           = eval_interval,
    divergence_threshold    = 1.8,
    mc_eval                 = false,
    mc_rollout              = false,
    mc_amplification        = 5,
    mc_threshold            = Inf,
    pessimistic             = false,
    pessimism               = -10.)
    
    # Initialize functions =========================================================
    
    crelu(x) = concat(2,[x,-x]) |> nn.relu
    se(x) = exp(-x.^2)
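    # crelu concatenates x and -x before the ReLU, doubling the feature
    # dimension; this is why the output weights W3/W3μ below have 2*30 rows.
    # se is a bounded squared-exponential activation kept as an alternative
    # for the output layer (the graph below currently uses tanh there).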
    
    # Define network
    session = Session(Graph())
    a_ = placeholder(ACtype, shape=[-1,num_actions])
    s_ = placeholder(ACtype, shape=[-1,num_states])
    y_ = placeholder(ACtype, shape=[-1,1])
    
    seed = rand(1:1000)
    srand(seed)
    W1  = Variable(0.002*rand(ACtype, num_states, 40)-0.001, name="weights1")
    W2  = Variable(0.002*rand(ACtype, 40, 30)-0.001, name="weights2")
    W2a = Variable(0.002*rand(ACtype, num_actions, 30)-0.001, name="weights2a")
    W2μ = Variable(0.002*rand(ACtype, 40, 30)-0.001, name="weights2mu")
    W3  = Variable(0.002*rand(ACtype, 2*30, 1)-0.001, name="weights3")
    W3μ = Variable(0.002*rand(ACtype, 2*30, num_actions)-0.001, name="weights3mu")
    
    B1  = Variable(0.002*rand(ACtype,40)-0.001, name="bias1")
    B2  = Variable(0.002*rand(ACtype,30)-0.001, name="bias2")
    B2μ = Variable(0.002*rand(ACtype,30)-0.001, name="bias2mu")
    B3  = Variable(0*ones(ACtype,1), name="bias3")
    B3μ = Variable(0*ones(ACtype,num_actions), name="bias3mu")
    
    l1  =          s_*W1 + B1  |> nn.tanh
    l2  = l1*W2 + a_*W2a + B2  |> crelu
    q   =         l2*W3  + B3
    l2μ =         l1*W2μ + B2μ |> crelu
    𝛍   =        l2μ*W3μ + B3μ |> nn.tanh # To get action ∈ [-1,1]
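    # Shared-layer architecture (hence the file name): the first layer l1 is
    # shared between critic and actor. The critic injects the action in its
    # second layer, q = crelu(l1*W2 + a*W2a + B2)*W3 + B3, while the actor
    # branches off l1 with its own head, 𝛍 = tanh(crelu(l1*W2μ + B2μ)*W3μ + B3μ),
    # so that actions stay in [-1, 1].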
    
    srand(seed)
    W1t  = Variable(0.002*rand(ACtype, num_states, 40)-0.001, name="weights1t", trainable=false)
    W2t  = Variable(0.002*rand(ACtype, 40, 30)-0.001, name="weights2t", trainable=false)
    W2at = Variable(0.002*rand(ACtype, num_actions, 30)-0.001, name="weights2at", trainable=false)
    W2μt = Variable(0.002*rand(ACtype, 40, 30)-0.001, name="weights2mut", trainable=false)
    W3t  = Variable(0.002*rand(ACtype, 2*30, 1)-0.001, name="weights3t", trainable=false)
    W3μt = Variable(0.002*rand(ACtype, 2*30, num_actions)-0.001, name="weights3mut", trainable=false)
    
    B1t  = Variable(0.002*rand(ACtype,40)-0.001, name="bias1t", trainable=false)
    B2t  = Variable(0.002*rand(ACtype,30)-0.001, name="bias2t", trainable=false)
    B2μt = Variable(0.002*rand(ACtype,30)-0.001, name="bias2mut", trainable=false)
    B3t  = Variable(0*ones(ACtype,1), name="bias3t", trainable=false)
    B3μt = Variable(0*ones(ACtype,num_actions), name="bias3mut", trainable=false)
    
    # Target (tracking) network: same layer structure as the live network above,
    # built from the non-trainable *t variables and updated slowly via assign_op below
    l1t  = s_*W1t           +  B1t |> nn.tanh
    l2t  = l1t*W2t + a_*W2at + B2t |> crelu # Include actions only in second layer
    qt   =         l2t*W3t  +  B3t
    l2μt =         l1t*W2μt + B2μt |> crelu
    𝛍t   =        l2μt*W3μt + B3μt |> nn.tanh
    
    Q(s1::Vector,a1::Vector,t)  = run(session, q,  Dict(s_ => s1', a_ => a1'))[1]
    Qt(s1::Vector,a1::Vector,t) = run(session, qt, Dict(s_ => s1', a_ => a1'))[1]
    
    μ(s,t)  = run(session, 𝛍, Dict(s_  => s'))[:]
    μt(s,t) = run(session, 𝛍t, Dict(s_ => s'))[:]
    mn = MarkovNoise([2,0,-2], [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9]',2)
    β(s)    = clamp(μ(s,0) + σβ*randn() + rand(mn), -1,1)
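    # Exploration: the behavior policy β perturbs μ with Gaussian noise of
    # standard deviation σβ plus a draw from a slowly switching three-state
    # Markov chain over the offsets {2, 0, -2}, clamped to the valid action
    # range [-1, 1]. A quick sanity check (only valid once the variables have
    # been initialized below):
    # s0 = [-0.5, 0.0]            # hypothetical state: position, velocity
    # @assert all(abs.(β(s0)) .<= 1)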
    
    # Thin policy wrappers for Reinforce.jl (the type names still say cartpole,
    # but the environment is MountainCarContinuous): `policy` acts greedily with
    # μ, `exppolicy` explores with β.
    type CartpolePolicy <: AbstractPolicy end
    Reinforce.action(policy::CartpolePolicy, r, s, A) = μ(s,0)
    type CartpoleExpPolicy <: AbstractPolicy end
    Reinforce.action(policy::CartpoleExpPolicy, r, s, A) = β(s)
    const policy = CartpolePolicy()
    const exppolicy = CartpoleExpPolicy()
    
    μvars       = [W2μ,W3μ,B2μ,B3μ] # TODO: the shared layers (W1, B1) are excluded from the policy update
    Qvars       = [W1,W2,W2a,W3,B1,B2,B3]
    
    μoptimizer  = train.GradientDescentOptimizer(αΘ, name="policy_optimizer")
    apply_policy_gradient = get_apply_policy_gradient(μvars, q, 𝛍, a_)
    # TODO: check sign of gradient
    
    sum_square = reduce_mean((y_ - q).^2)
    weight_decay = 0.001*(reduce_sum(W1.^2) + reduce_sum(W2.^2) + reduce_sum(W2a.^2) +  reduce_sum(W3.^2)) #+ reduce_sum(l1.^2)
    # reduce_sum(W2μ.^2) +   reduce_sum(W3μ.^2)
    loss = sum_square + weight_decay
    # train_step = train.minimize(train.GradientDescentOptimizer(αw), loss)
    Qoptimizer = train.GradientDescentOptimizer(αw, name="Q_optimizer")
    # gvs = train.compute_gradients(Qoptimizer,loss)
    # capped_gvs = [(clip_by_value(grad, ACtype(-0.001), ACtype(0.001)), var) for (grad, var) in gvs]
    # train_step = train.apply_gradients(Qoptimizer,capped_gvs)
    train_step = train.minimize(Qoptimizer, loss, var_list=Qvars)
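    # The critic loss is the mean squared TD error (reduce_mean, despite the
    # name sum_square) plus an L2 penalty on the critic weights; only the
    # critic parameters in Qvars are updated by train_step, the actor is
    # trained solely through apply_policy_gradient.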
    
    const value_history = QHistory(ACtype)
    # push!(value_history,ACtype(1.))
    
    const τ = dpgopts.τ
    assign_op = [
    assign(W1t,τ*W1   + (1-τ)*W1t),
    assign(W2t,τ*W2   + (1-τ)*W2t),
    assign(W2at,τ*W2a + (1-τ)*W2at),
    assign(W2μt,τ*W2μ + (1-τ)*W2μt),
    assign(W3t,τ*W3   + (1-τ)*W3t),
    assign(B1t,τ*B1   + (1-τ)*B1t),
    assign(B2t,τ*B2   + (1-τ)*B2t),
    assign(B3t,τ*B3   + (1-τ)*B3t),
    assign(W3μt,τ*W3μ + (1-τ)*W3μt),
    assign(B3μt,τ*B3μ + (1-τ)*B3μt)]
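    # Soft ("Polyak") target update: each call to update_tracking_networks
    # below moves every target parameter a fraction τ towards its live
    # counterpart, θt ← τ*θ + (1-τ)*θt, keeping the targets produced by
    # qt/𝛍t slowly varying.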
    
    function train_critic(s,a,targets)
      l = size(s,1)
      ss,_ = run(session, [sum_square, train_step], Dict(s_ => s, a_ => a, y_ => reshape(targets,l,1)))
      push!(value_history,ss)
    end
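    # train_critic above takes one gradient step on the TD loss with targets
    # supplied by dpg (presumably y = r + γ*Qt(s',μt(s')) computed from the
    # tracking networks) and logs the mean squared TD error to value_history.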
    
    function train_actor(s,a,targets)
      run(session, apply_policy_gradient, Dict(s_ => s, a_ => a))
    end
    
    function update_tracking_networks()
      # the soft-update factor τ is already baked into assign_op
      run(session, assign_op)
    end
    
    
    const big_states = Matrix{Float64}(T,num_states)
    const big_states2 = Matrix{Float64}(T,num_states)
    const big_actions = Matrix{ACtype}(T,num_actions)
    const big_rewards = Vector{Float64}(T)
    function simulate(noise = false)
      step = 0
      ep   = Episode(env, noise ? exppolicy : policy)
      for sars in ep
        # if !noise; OpenAIGym.render(env); end
        step += 1
        big_states[step,:]  = sars[1]::Vector{Float64}
        big_actions[step,:] = sars[2]::Vector{ACtype}
        big_rewards[step]   = sars[3]::Float64
        big_states2[step,:] = sars[4]::Vector{Float64}
        if step == T; break; end
      end
      return view(big_states,1:step,:), view(big_actions,1:step,:), view(big_rewards,1:step), view(big_states2,1:step,:)
    end
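    # simulate runs one episode of at most T steps against the gym environment,
    # with or without exploration noise, stores the transitions in the
    # preallocated buffers above and returns views of the filled rows, so the
    # trajectory buffers are not reallocated on every episode.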
    
    function fit_model(args...)
    end
    
    funs = DPGfuns(μ,μt,Q,Qt, simulate, train_critic, train_actor, update_tracking_networks, fit_model)
    
    pfig = plot(layout=7);
    push!(value_history,0,Float32(2))
    function progressplot(i,s,u, reward, rollout)
      L1,L2 = run(session,[l1,l2], Dict(s_ => collect(s), a_ => collect(u)))
      plot!([0.1s u], lab=["It: $i r: $(reward[i] |> r3)" "" ""], c=[:blue :red :orange], subplot=1)
      plot!(s[:,1],s[:,2], lab="", c=[:blue :red], subplot=2)
      Qplot!(rollout, Q, dpgopts.γ, subplot=3, c=:blue)
      plot!(value_history, subplot=4, c=:blue, markersize=1, yscale=:log10)
      plot!(reward[1:i], title="Reward", subplot=5, c=:blue, legend=false)
      heatmap!(L1, title="l1", subplot=6, legend=false, colorbar=true)
      heatmap!(L2, title="l2", subplot=7, legend=false, colorbar=true)
      update_plot!(pfig[1], max_history=3, attribute=:linecolor)
      update_plot!(pfig[2], max_history=5, attribute=:linecolor)
      update_plot!(pfig[3], max_history=6, attribute=:linecolor)
      update_plot!(pfig[4], max_history=1)
      update_plot!(pfig[5], max_history=1, attribute=:linecolor)
      update_plot!(pfig[6], max_history=1)
      gui(pfig)
    end
    
    ## Solve DPG ===================================================================
    run(session, initialize_all_variables())
    cost, mem = dpg(ACtype,dpgopts, funs, progressplot)
    
    
    
    pyplot()
    cp = linspace(-1.2,0.5,40)
    cv = linspace(-0.07,0.07,40)
    sgrid = meshgrid(cp,cv)
    qgrid = map(sgrid...) do p,v
        maximum(Q([p,v],[a],1) for a in linspace(-1,1,9))
    end;
    pgrid = map(sgrid...) do p,v
        Q([p,v], μ([p,v],0), 1) # value of the learned policy, no max over actions
    end;
    surface(sgrid...,qgrid, xlabel="Position", ylabel="Velocity", zlabel="Q^*")
    surface(sgrid...,pgrid, xlabel="Position", ylabel="Velocity", zlabel="Q(s,mu(s))")