# stochastic_gradient_descent.jl
#
# `rms` and `saturatePrecision` are assumed to be defined elsewhere in the package.

using Optim
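
# stochastic_gradient_descent(f, g, x, N, n_state; ...)
#
# Minimizes an objective by mini-batch stochastic gradient descent.
#   f(x)       -> residual vector over all N samples (its RMS is reported as the objective value)
#   g(x, idxs) -> gradient evaluated on the sample indices `idxs`
#   x          -> initial parameter vector
#   N          -> number of samples
#   n_state    -> forwarded to saturatePrecision after every update
# Keyword arguments: lambda (initial step size, decayed as lambda/(1+t)), maxIter (epochs),
# tolG (gradient-norm convergence tolerance), tolX (currently unused), show_trace,
# timeout (seconds), batch_size.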
function stochastic_gradient_descent(f, g, x, N, n_state; lambda = 0.1, maxIter = 10, tolG = 1e-7, tolX = 1e-10, show_trace=true, timeout = 1000, batch_size=10)

    converged = false
    x_converged = false
    g_converged = false
    need_jacobian = true
    iterCt = 0

    f_calls = 0
    g_calls = 0

    x0 = copy(x)             # remember the starting point for the results object
    fcur = f(x)
    f_calls += 1
    residual = rms(fcur)
    lambdai = lambda
    t0 = time()

    # Maintain a trace of the system.
    tr = Optim.OptimizationTrace()
    if show_trace
        os = Optim.OptimizationState(0, residual, NaN)
        push!(tr, os)
        println("Iter:0, f:$(round(os.value,5)), ||g||:$(0)")
    end

    # Each outer iteration is one epoch: sweep the data in mini-batches with the
    # current step size, then decay it and record the full-data residual.
    for t = 1:maxIter
        iterCt = t
        for i = 1:batch_size:N
            batch = i:min(i + batch_size - 1, N)   # last mini-batch may be shorter
            grad = g(x, batch)
            g_calls += 1
            x -= lambdai*grad
            x = saturatePrecision(x, n_state)
        end
        grad = g(x, 1:N)              # full gradient, used for the trace and the convergence test
        g_calls += 1
        lambdai = lambda / (1 + t)    # decay the step size after each epoch
        fcur = f(x)
        f_calls += 1
        residual = rms(fcur)

        if show_trace && (t < 5 || t % 5 == 0)
            gradnorm = norm(grad)
            os = Optim.OptimizationState(t, residual, gradnorm)
            push!(tr, os)
            println("Iter:$t, f:$(round(os.value,5)), ||g||:$(round(gradnorm,5))")
        end

        # Declare convergence once the full-gradient norm drops below tolG.
        if norm(grad) < tolG
            g_converged = true
            converged = true
            break
        end

        if time()-t0 > timeout
            display("stochastic_gradient_descent: timeout $(timeout)s reached ($(time()-t0)s)")
            break
        end

    end
    Optim.MultivariateOptimizationResults("stochastic_gradient_descent", x0, x, residual, iterCt, !converged, x_converged, 0.0, false, 0.0, g_converged, tolG, tr, f_calls, g_calls)
end
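
# A minimal usage sketch (not part of the original file), kept commented out because
# `rms` and `saturatePrecision` come from elsewhere in the package. It fits a linear
# least-squares model b ≈ A*x: `f` returns the per-sample residual vector, `g(x, idxs)`
# the gradient of the mean squared error over the index range `idxs`, and the value 1
# passed for `n_state` is only a placeholder forwarded to saturatePrecision.
#=
N, d = 1000, 5
A = randn(N, d)
x_true = randn(d)
b = A * x_true + 0.01 * randn(N)

f(x) = A * x - b                                                   # residual vector over all samples
g(x, idxs) = 2 * A[idxs, :]' * (A[idxs, :] * x - b[idxs]) / length(idxs)

res = stochastic_gradient_descent(f, g, zeros(d), N, 1;
                                  lambda = 0.1, maxIter = 20, batch_size = 50)
=#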