# trainRBF.jl (7.01 KiB)
using Optim
using Devectorize
import ASCIIPlots.scatterplot
# trainRBF(b, Nbasis, σ, ϕtrain; kwargs...)
#
# Fit the weights `w` of a radial-basis-function expansion so that `Ψ*w ≈ b`, where `Ψ` is
# the basis-activation matrix filled in by `getΨVec!` / `getΨVecNd!` for the rows of `ϕtrain`.
#
# Arguments
#   b        targets; must be compatible with `Ψ'*b` and `Ψ*w - b`
#   Nbasis   number of basis functions per column of `ϕtrain`
#   σ        one width per column of `ϕtrain` (length is asserted to equal `size(ϕtrain,2)`)
#   ϕtrain   matrix whose rows give the rows of `Ψ` and whose columns are the "series"
#
# Keywords
#   nonlin       also optimise centers/widths with Levenberg-Marquardt, alternating with
#                a ridge re-solve for `w`
#   liniters     maximum number of outer (alternation) iterations
#   nonliniters  `maxIter` handed to each Levenberg-Marquardt call
#   λ            ridge term added to the diagonal of `Ψ'*Ψ`
#   ND           use the `getΨVecNd!` parametrisation with `Nbasis^Nseries` columns;
#                forced to `false` when `ϕtrain` has a single column
#   normalized   forwarded unchanged to `getΨVec!` / `getΨVecNd!`
#
# Returns `(w, params, Ψ)` when `ND` is true, otherwise `(w, σ, centers, Ψ)`.
# Side effects: opens several figures and prints progress via `display`.
function trainRBF(b, Nbasis, σ, ϕtrain; nonlin=false, liniters=20,nonliniters=5, λ=1e-8, ND=false, normalized=false)
    # Row s holds [min max] of column s of ϕtrain.
    limits = [minimum(ϕtrain,1)' maximum(ϕtrain,1)']
    Nseries = size(ϕtrain,2)
    NBASIS::Int64 = Nbasis^Nseries
    if Nseries == 1
        ND = false
    end
    Ψ = zeros(size(ϕtrain,1), ND ? NBASIS : Nbasis)
    assert(length(σ) == Nseries)

    # Initial guess: per column, Nbasis centers at the midpoints of Nbasis equal-width bins
    # spanning [min, max], all sharing the width σ[s].
    σvec = zeros(Nbasis,Nseries)
    centers = zeros(Nbasis,Nseries)
    for s = 1:Nseries
        l = limits[s,2]-limits[s,1]
        d = l/Nbasis
        centers[:,s] = collect((d/2):d:l) .+ limits[s,1]
        σvec[:,s] = σ[s].*ones(Nbasis)
    end

    if ND
        params = getCenters(centers,σvec)
        getΨVecNd!(Ψ, ϕtrain , params, normalized=normalized)
    else
        getΨVec!(Ψ,σvec,ϕtrain,centers, normalized=normalized)
    end
    figure(); pp= imagesc(Ψ); display(pp);

    Nparams = size(Ψ,2)
    # Ridge-regularised least squares for the weights given the current Ψ.
    ridgefit = () -> ((Ψ'*Ψ + λ*eye(Nparams))\Ψ'*b)
    w = ridgefit()

    if nonlin
        Ncs = Nbasis*Nseries
        # Non-ND parameter vector layout: [centers[:]; σvec[:]].
        unpack = x -> (reshape(x[1:Ncs],Nbasis,Nseries), reshape(x[Ncs+1:end],Nbasis,Nseries))

        # `cost` returns the residual Ψ*w - b at the given parameters; `g` returns whatever
        # the Ψ-builder returns when it is also given `w` (used as the Jacobian by the
        # optimiser). Anonymous functions are used so that only the selected branch is defined.
        if ND
            cost = x -> begin
                getΨVecNd!(Ψ, ϕtrain , x, w, normalized=normalized)
                Ψ*w - b
            end
            g = x -> getΨVecNd!(Ψ, ϕtrain , x, w, normalized=normalized)
        else
            cost = x -> begin
                (cx, σx) = unpack(x)
                getΨVec!(Ψ,σx,ϕtrain,cx,w, normalized=normalized)
                Ψ*w - b
            end
            # Evaluate at the point requested by the optimiser (previously the argument was
            # ignored and the captured centers/σvec were used instead).
            g = x -> begin
                (cx, σx) = unpack(x)
                getΨVec!(Ψ,σx,ϕtrain,cx,w, normalized=normalized)
            end
        end

        cold = 1e20
        fvec = zeros(liniters+1)   # fvec[1] is the initial cost, fvec[i+1] the cost after iteration i
        fvec[1] = sum(cost(ND ? params : [centers[:]; σvec[:]]).^2)
        finalIter = 1              # number of valid entries in fvec
        useCuckoo = false          # set after one stalled iteration; a second stall in a row stops
        for i = 1:liniters
            X0 = ND ? params : [centers[:]; σvec[:]]
            try
                if true#!useCuckoo  (cuckoo-search branch below is deliberately switched off)
                    @time res = Optim.levenberg_marquardt(cost,g, X0,
                        maxIter = nonliniters,
                        tolG = 1e-5,
                        tolX = 1e-8,
                        show_trace=false)
                    X = res.minimum
                    fvec[i+1] = res.f_minimum
                else
                    display("Using cuckoo search to escape local minimum")
                    @time (bestnest,fmin) = cuckoo_search(x -> sum(cost(x).^2),X0;Lb=-inf(Float64),Ub=inf(Float64),n=30,pa=0.25, Tol=1.0e-5, max_iter = 10*nonliniters+100)
                    X = bestnest
                    fvec[i+1] = fmin
                end

                # Accept the new parameters, rebuild Ψ and re-solve for the weights.
                if ND
                    params = X
                    # NOTE(review): this call passes `b` as an extra positional argument,
                    # unlike every other getΨVecNd! call in this function. Left unchanged
                    # because the callee is not visible here — confirm such a method exists.
                    getΨVecNd!(Ψ, ϕtrain , params,b, w, normalized=normalized)
                else
                    (centers, σvec) = unpack(X)
                    getΨVec!(Ψ,σvec,ϕtrain,centers,w, normalized=normalized)
                end
                w = ridgefit()
                display("Non-linear optimization, iteration $i")

                c = fvec[i+1]
                if abs(cold-c) < 1e-10
                    if useCuckoo
                        finalIter += 1
                        display("No significant change in function value")
                        break
                    end
                    useCuckoo = true
                else
                    useCuckoo = false
                end
                cold = c
            catch ex
                # Best effort: fall back to the last accepted parameters and stop iterating.
                display("Optimization failed, using current best point")
                display(ex)
                if ND
                    getΨVecNd!(Ψ, ϕtrain , params, w, normalized=normalized)
                else
                    getΨVec!(Ψ,σvec,ϕtrain,centers,w, normalized=normalized)
                end
                w = ridgefit()
                break
            end
            finalIter += 1
            try
                display(scatterplot(fvec[1:finalIter],sym='*'))
            catch
                display("Plot failed")
            end
        end

        figure()
        pp = plot(fvec[1:finalIter],"o")
        xlabel("Iteration")
        title("Best function value")
        display(pp)

        diffFvec = diff(fvec[1:finalIter])
        figure()
        pp = plot(diffFvec[1:2:end],"o")
        hold(true)
        plot(diffFvec[2:2:end],"og")
        xlabel("Iteration")
        title("Best function value decrease")
        display(pp)
    end

    # Plot the regressors together with the fitted output.
    f = Ψ*w
    figure()
    pp = plot(ϕtrain[:,1],".b")
    hold(true)
    plot(ϕtrain,f[:,1],"--b")
    xlabel("ϕ")
    ylabel("output")
    display(pp)
    if size(b,2) > 1
        hold(true)
        plot(ϕtrain[:,2],"g.")
        plot(ϕtrain,f[:,2],"g--")
        display(pp)
        figure()
        pp = plot(f[:,1],f[:,2],"--")
        hold(true)
        # Target trajectory: first output dimension against the second
        # (was `b[:,1][:,2]`, which indexes a vector out of bounds).
        plot(b[:,1], b[:,2])
        xlabel("output dim 1")
        ylabel("output dim 2")
        display(pp)
    end

    if ND
        return w,params, Ψ
    else
        # NOTE(review): this returns the caller-supplied σ, not σvec, so widths changed by
        # the nonlinear optimisation are not reported. Kept as-is to preserve the return
        # shape callers rely on — confirm whether σvec should be exposed.
        return w,σ, centers, Ψ
    end
end