Skip to content
Snippets Groups Projects
Select Git revision
  • 608645fdefec7f43cf9c594f26d14459cba9339d
  • master default protected
  • mattias
  • sobolfast
4 results

trainRBF.jl

  • trainRBF.jl 7.01 KiB
    using Optim
    using Devectorize
    import ASCIIPlots.scatterplot
    # trainRBF: fit a radial-basis-function (RBF) network to training data by
    # ridge-regularized least squares on the basis activations Ψ, optionally
    # followed by alternating nonlinear refinement (Levenberg–Marquardt, with a
    # cuckoo-search fallback to escape plateaus) of the centers and widths.
    #
    # Arguments
    #   b       — training targets, one column per output dimension
    #   Nbasis  — number of basis functions per input dimension
    #   σ       — kernel widths, one entry per input series (length(σ) == size(ϕtrain,2))
    #   ϕtrain  — training inputs, one column per input series
    # Keywords
    #   nonlin      — refine centers/widths by nonlinear optimization when true
    #   liniters    — outer iterations alternating refinement and weight re-solve
    #   nonliniters — Levenberg–Marquardt iterations per outer iteration
    #   λ           — Tikhonov regularization for the linear weight solve
    #   ND          — use a full N-D tensor grid of Nbasis^Nseries basis functions;
    #                 forced off when there is only one input series
    #   normalized  — forwarded to the Ψ-builders (normalized RBF network)
    # Returns
    #   ND == true:  (w, params, Ψ)      — weights, packed center/width params, regressors
    #   ND == false: (w, σ, centers, Ψ)  — weights, widths, centers, regressors
    #
    # NOTE(review): getΨVec!, getΨVecNd!, getCenters and cuckoo_search are defined
    # elsewhere in the project; comments on their calls below are inferred from
    # usage here and should be confirmed against their definitions.
    function trainRBF(b, Nbasis, σ, ϕtrain; nonlin=false, liniters=20,nonliniters=5, λ=1e-8, ND=false, normalized=false)
        # Per-series input range: column 1 = minimum, column 2 = maximum.
        limits = [minimum(ϕtrain,1)' maximum(ϕtrain,1)']
        Nseries = size(ϕtrain,2);
        NBASIS::Int64 = Nbasis^Nseries
        if Nseries == 1
            ND = false;          # an N-D grid degenerates to 1-D for a single series
        end
        Ψ = Array{Float64,2}
        # Regressor matrix: one row per sample, one column per basis function.
        if ND
            Ψ = zeros(size(ϕtrain,1),NBASIS)
        else
            Ψ = zeros(size(ϕtrain,1),Nbasis);
        end
        assert(length(σ) == Nseries)
        σvec = zeros(Nbasis,Nseries)
        centers = zeros(Nbasis,Nseries)
        # Place Nbasis centers uniformly over each series' range, offset by half a
        # cell so they sit at cell midpoints; width is constant per series.
        for s = 1:Nseries
            l = limits[s,2]-limits[s,1];
            d = l/Nbasis;
            lc = [(d/2):d:l]+limits[s,1]
            centers[:,s] = lc;
            σvec[:,s] = σ[s].*ones(Nbasis);
        end

        # Evaluate the basis activations at the training points (fills Ψ in place).
        if ND
            params = getCenters(centers,σvec)
            getΨVecNd!(Ψ, ϕtrain , params, normalized=normalized)
        else
            getΨVec!(Ψ,σvec,ϕtrain,centers, normalized=normalized)
        end

        figure(); pp= imagesc(Ψ); display(pp);
        Nparams = size(Ψ,2)

        # Ridge-regularized least-squares solve for the output weights:
        # w = (ΨᵀΨ + λI)⁻¹ Ψᵀ b.
        w = ((Ψ'*Ψ + λ*eye(Nparams))\Ψ'*b);

        if nonlin
            # Residual functions for Levenberg–Marquardt: return the residual
            # vector f = Ψw - b after rebuilding Ψ from the candidate parameters.
            if ND
                function costfunc(Ψ,w,ϕtrain,b , params)
                    getΨVecNd!(Ψ, ϕtrain , params, w, normalized=normalized)
                    f = Ψ*w - b
                end
                cost(X) = costfunc(Ψ,w,ϕtrain,b,X);
                # Jacobian callback for LM; presumably getΨVecNd! returns the
                # Jacobian when given w — confirm against its definition.
                function g(xparams)
                    return getΨVecNd!(Ψ, ϕtrain , xparams, w, normalized=normalized)
                end
            else
                function costfunc(Ψ,w,ϕtrain,centers,σvec,b)
                    getΨVec!(Ψ,σvec,ϕtrain,centers,w, normalized=normalized)
                    f = Ψ*w - b
                end
                # X packs [centers(:); σvec(:)]; unpack back into matrices.
                cost(X) = costfunc(Ψ,w,ϕtrain,reshape(X[1:(Nbasis*Nseries)],Nbasis,Nseries),reshape(X[(Nbasis*Nseries)+1:end],Nbasis,Nseries),b);
                # NOTE(review): g ignores xparams and uses the captured
                # centers/σvec, so the Jacobian may be evaluated at a stale
                # point — verify against the LM solver's expectations.
                function g(xparams)
                    return getΨVec!(Ψ,σvec,ϕtrain,centers,w, normalized=normalized)
                end
            end

            cold = 1e20
            fvec = zeros(liniters+1)
            # Initial sum-of-squares cost at the grid-initialized parameters.
            fvec[1] = ND ? sum(cost(params).^2) : sum(cost([centers[:],σvec[:]]).^2)
            finalIter = 1
            useCuckoo = false
            for i = 1:liniters
                X0 = ND ? params : [centers[:],σvec[:]]
                try
                    # The cuckoo branch is currently disabled by the `true`
                    # short-circuit; LM is always used.
                    if true#!useCuckoo #i % 2 == 1
                        @time res = Optim.levenberg_marquardt(cost,g, X0,
                                                              maxIter = nonliniters,
                                                              tolG = 1e-5,
                                                              tolX = 1e-8,
                                                              show_trace=false)
                        X = res.minimum
                        fvec[i+1] = res.f_minimum
                    else
                        display("Using cuckoo search to escape local minimum")
                        @time (bestnest,fmin) = cuckoo_search(x -> sum(cost(x).^2),X0;Lb=-inf(Float64),Ub=inf(Float64),n=30,pa=0.25, Tol=1.0e-5, max_iter = 10*nonliniters+100)
                        X = bestnest
                        fvec[i+1] = fmin
                    end
                    # Accept the new parameters and rebuild Ψ.
                    if ND
                        params = X
                        # NOTE(review): this call passes b where the other
                        # getΨVecNd! calls do not — confirm the signature.
                        getΨVecNd!(Ψ, ϕtrain , params,b, w, normalized=normalized)
                    else
                        centers = reshape(X[1:(Nbasis*Nseries)],Nbasis,Nseries)
                        σvec = reshape(X[(Nbasis*Nseries)+1:end],Nbasis,Nseries)
                        getΨVec!(Ψ,σvec,ϕtrain,centers,w, normalized=normalized)
                    end
                    # Re-solve the linear weights for the refined basis.
                    w = ((Ψ'*Ψ + λ*eye(Nparams))\Ψ'*b)
                    display("Non-linear optimization, iteration $i")
                    f = Ψ*w - b    # NOTE(review): residual computed but unused here
                    c = fvec[i+1]
                    # Stagnation logic: first stall arms cuckoo search; a stall
                    # while armed terminates the outer loop.
                    if abs(cold-c) < 1e-10
                        if useCuckoo
                            finalIter += 1
                            display("No significant change in function value")
                            break
                        end
                        useCuckoo = true
                    elseif useCuckoo
                        useCuckoo = false
                    end
                    cold = c
                catch ex
                    # Optimization blew up: restore Ψ and w from the last
                    # accepted parameters and stop refining.
                    display("Optimization failed, using current best point")
                    display(ex)
                    if ND
                        getΨVecNd!(Ψ, ϕtrain , params, w, normalized=normalized)
                    else
                        getΨVec!(Ψ,σvec,ϕtrain,centers,w, normalized=normalized)
                    end
                    w = ((Ψ'*Ψ + λ*eye(Nparams))\Ψ'*b)
                    break
                end

                finalIter += 1
                # ASCII progress plot; best-effort only.
                try
                    display(scatterplot(fvec[1:finalIter],sym='*'))
                catch
                    display("Plot failed")
                end
            end
            figure();pp = plot(fvec[1:finalIter],"o"); xlabel("Iteration"); title("Best function value");  display(pp)
            diffFvec = diff(fvec[1:finalIter])
            figure();pp = plot(diffFvec[1:2:end],"o");hold(true);plot(diffFvec[2:2:end],"og"); xlabel("Iteration"); title("Best function value decrease");  display(pp)
        end
        # Final model output on the training set and diagnostic plots.
        f = Ψ*w;
        figure();pp = plot(ϕtrain[:,1],".b");hold(true)
        plot(ϕtrain,f[:,1],"--b"); xlabel("ϕ"); ylabel("output");  display(pp)

        if size(b,2) > 1
            hold(true)
            plot(ϕtrain[:,2],"g.")
            plot(ϕtrain,f[:,2],"g--")
            display(pp)
            figure(); pp= plot(f[:,1],f[:,2],"--"); hold(true)
            # Fix: was plot(b[:,1][:,2]) — indexing the vector b[:,1] with a
            # second column is a BoundsError; the intent (per the axis labels)
            # is to plot target dim 1 against target dim 2.
            plot(b[:,1],b[:,2]); xlabel("output dim 1");ylabel("output dim 2")
            display(pp)
        end
        if ND
            return w,params, Ψ
        else
            return w,σ, centers, Ψ
        end
    end