From 62456129d6514937ee5863613c103ead8c1208f0 Mon Sep 17 00:00:00 2001 From: baggepinnen <cont-frb@ulund.org> Date: Wed, 11 Oct 2017 09:58:34 +0200 Subject: [PATCH] update --- src/LTVModelsBase.jl | 44 +++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/src/LTVModelsBase.jl b/src/LTVModelsBase.jl index 8cfcf69..6e2d3e3 100644 --- a/src/LTVModelsBase.jl +++ b/src/LTVModelsBase.jl @@ -1,10 +1,12 @@ module LTVModelsBase # Interface exports -export AbstractModel, AbstractCost, ModelAndCost, cost, -cost_final, dc,calculate_cost,calculate_final_cost, -fit_model, predict, df,costfun, covariance, LTVStateSpaceModel, -SimpleLTVModel, KalmanModel, GMMModel +export AbstractModel, AbstractCost, ModelAndCost,f, +dc,calculate_cost,calculate_final_cost, +fit_model, predict, df,costfun, LTVStateSpaceModel, +SimpleLTVModel, covariance + +export rms, sse, nrmse rms(x) = sqrt(mean(x.^2)) @@ -46,28 +48,6 @@ end SimpleLTVModel(At,Bt,extend::Bool) = SimpleLTVModel{eltype(At)}(At,Bt,extend) -mutable struct KalmanModel{T} <: LTVStateSpaceModel - At::Array{T,3} - Bt::Array{T,3} - Pt::Array{T,3} - extended::Bool - function KalmanModel{T}(At::Array{T,3},Bt::Array{T,3},Pt::Array{T,3},extend::Bool) - if extend - At = cat(3,At,At[:,:,end]) - Bt = cat(3,Bt,Bt[:,:,end]) - Pt = cat(3,Pt,Pt[:,:,end]) - end - return new(At,Bt,Pt,extend) - end -end - -KalmanModel(At,Bt,Pt,extend::Bool=false) = KalmanModel{eltype(At)}(At,Bt,Pt,extend) - -mutable struct GMMModel <: AbstractModel - M - dynamics - T -end """ model = fit_model(::Type{AbstractModel}, x,u)::AbstractModel @@ -172,7 +152,7 @@ function f(modelcost::ModelAndCost, x, u, i) predict(modelcost.model, x, u, i) end -function constfun(modelcost::ModelAndCost, x, u) +function costfun(modelcost::ModelAndCost, x, u) calculate_cost(modelcost.cost, x, u) end @@ -188,4 +168,14 @@ function df(modelcost::ModelAndCost, x, u) end + +function covariance(model::AbstractModel, x, u) + xhat = similar(x) + xhat[:,1] = x[:,1] + for i = 1:size(x,2)-1 + xhat[:,i+1] = model.At[:,:,i]*x[:,i] + model.Bt[:,:,i]*u[:,i] + end + return cov(x-xhat, 2) +end + end # module -- GitLab