00001
00013 #ifndef KF_H
00014 #define KF_H
00015
00016
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../math/chmat.h"
00020 #include "../base/user_info.h"
00021
00022 namespace bdm
00023 {
00024
00032 template<class sq_T>
00033 class StateSpace
00034 {
00035 protected:
00037 int dimx;
00039 int dimy;
00041 int dimu;
00043 mat A;
00045 mat B;
00047 mat C;
00049 mat D;
00051 sq_T Q;
00053 sq_T R;
00054 public:
00055 StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
00056 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0);
00057 void validate();
00059 void from_setting (const Setting &set) {
00060 UI::get (A, set, "A", UI::compulsory);
00061 UI::get (B, set, "B", UI::compulsory);
00062 UI::get (C, set, "C", UI::compulsory);
00063 UI::get (D, set, "D", UI::compulsory);
00064 mat Qtm, Rtm;
00065 if(!UI::get(Qtm, set, "Q", UI::optional)){
00066 vec dq;
00067 UI::get(dq, set, "dQ", UI::compulsory);
00068 Qtm=diag(dq);
00069 }
00070 if(!UI::get(Rtm, set, "R", UI::optional)){
00071 vec dr;
00072 UI::get(dr, set, "dQ", UI::compulsory);
00073 Rtm=diag(dr);
00074 }
00075 R=Rtm;
00076 Q=Qtm;
00077
00078 validate();
00079 }
00081 int _dimx(){return dimx;}
00083 int _dimy(){return dimy;}
00085 int _dimu(){return dimu;}
00087 mat _A(){return A;}
00089 mat _B(){return B;}
00091 mat _C(){return C;}
00093 mat _D(){return D;}
00095 sq_T _Q(){return Q;}
00097 sq_T _R(){return R;}
00098 };
00099
00101 template<class sq_T>
00102 class Kalman: public BM, public StateSpace<sq_T>
00103 {
00104 protected:
00106 RV yrv;
00108 RV urv;
00110 mat _K;
00112 shared_ptr<enorm<sq_T> > est;
00114 enorm<sq_T> fy;
00115 public:
00116 Kalman() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(), est(new enorm<sq_T>){}
00117 void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
00118 void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
00120 const enorm<sq_T>& posterior() const {return *est.get();}
00122 shared_ptr<epdf> shared_posterior() {return est;}
00124 void from_setting (const Setting &set) {
00125 StateSpace<sq_T>::from_setting(set);
00126
00127 mat P0; vec mu0;
00128 UI::get(mu0, set, "mu0", UI::optional);
00129 UI::get(P0, set, "P0", UI::optional);
00130 set_statistics(mu0,P0);
00131
00132 UI::get (yrv, set, "yrv", UI::optional);
00133 UI::get (urv, set, "urv", UI::optional);
00134 set_drv(concat(yrv,urv));
00135
00136 validate();
00137 }
00138 void validate() {
00139 StateSpace<sq_T>::validate();
00140 bdm_assert_debug(est->dimension(), "Statistics and model parameters mismatch");
00141 }
00142 };
00147 class KalmanFull : public Kalman<fsqmat>
00148 {
00149 public:
00151 KalmanFull() :Kalman<fsqmat>(){};
00153 void bayes (const vec &dt);
00154 };
00155 UIREGISTER(KalmanFull);
00156
00157
00165 class KalmanCh : public Kalman<chmat>
00166 {
00167 protected:
00170 mat preA;
00172 mat postA;
00174 public:
00176 BM* _copy_() const {
00177 KalmanCh* K = new KalmanCh;
00178 K->set_parameters (A, B, C, D, Q, R);
00179 K->set_statistics (est->_mu(), est->_R());
00180 return K;
00181 }
00183 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
00185 void initialize();
00186
00200 void bayes (const vec &dt);
00201
00202 void from_setting(const Setting &set){
00203 Kalman<chmat>::from_setting(set);
00204 initialize();
00205 }
00206 };
00207 UIREGISTER(KalmanCh);
00208
00214 class EKFfull : public KalmanFull
00215 {
00216 protected:
00218 shared_ptr<diffbifn> pfxu;
00219
00221 shared_ptr<diffbifn> phxu;
00222
00223 public:
00225 EKFfull ();
00226
00228 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
00229
00231 void bayes (const vec &dt);
00233 void set_statistics (const vec &mu0, const mat &P0) {
00234 est->set_parameters (mu0, P0);
00235 };
00236 const mat _R() {
00237 return est->_R().to_mat();
00238 }
00239 };
00240 UIREGISTER(EKFfull);
00241
00242
00249 class EKFCh : public KalmanCh
00250 {
00251 protected:
00253 shared_ptr<diffbifn> pfxu;
00254
00256 shared_ptr<diffbifn> phxu;
00257 public:
00259 BM* _copy_() const {
00260 EKFCh* E = new EKFCh;
00261 E->set_parameters (pfxu, phxu, Q, R);
00262 E->set_statistics (est->_mu(), est->_R());
00263 return E;
00264 }
00266 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
00267
00269 void bayes (const vec &dt);
00270
00271 void from_setting (const Setting &set);
00272
00273
00274
00275 };
00276
00277 UIREGISTER (EKFCh);
00278 SHAREDPTR (EKFCh);
00279
00280
00282
00290 class MultiModel: public BM
00291 {
00292 protected:
00294 Array<EKFCh*> Models;
00296 vec w;
00298 vec _lls;
00300 int policy;
00302 enorm<chmat> est;
00303 public:
00304 void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
00305 Models = A;
00306 w.set_length (A.length());
00307 _lls.set_length (A.length());
00308 policy = pol0;
00309
00310 est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
00311 est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
00312 }
00313 void bayes (const vec &dt) {
00314 int n = Models.length();
00315 int i;
00316 for (i = 0; i < n; i++) {
00317 Models (i)->bayes (dt);
00318 _lls (i) = Models (i)->_ll();
00319 }
00320 double mlls = max (_lls);
00321 w = exp (_lls - mlls);
00322 w /= sum (w);
00323
00324 switch (policy) {
00325 case 1: {
00326 int mi = max_index (w);
00327 const enorm<chmat> &st = Models (mi)->posterior() ;
00328 est.set_parameters (st.mean(), st._R());
00329 }
00330 break;
00331 default:
00332 bdm_error ("unknown policy");
00333 }
00334
00335 for (i = 0; i < n; i++) {
00336 Models (i)->set_statistics (est.mean(), est._R());
00337 }
00338 }
00340 const enorm<chmat>& posterior() const {
00341 return est;
00342 }
00343
00344 void from_setting (const Setting &set);
00345
00346
00347
00348 };
00349
00350 UIREGISTER (MultiModel);
00351 SHAREDPTR (MultiModel);
00352
00354
00361 template<class sq_T>
00362 shared_ptr<StateSpace<fsqmat> > to_state_space(const mlnorm<sq_T > &ml, ivec &theta_in_A, ivec &theta_in_C){
00363
00364
00365 ivec yids= unique(ml._rv()._ids());
00366 ivec dids= unique(ml._rvc()._ids());
00367 ivec uids=unique_complement(dids, yids);
00368
00369 RV yrv_uni = RV(yids,zeros_i(yids.length()));
00370 RV urv_uni = RV(uids,zeros_i(yids.length()));
00371
00372 bdm_assert_debug(yrv_uni._dsize()==urv_uni._dsize()==1, "Only for SISO so far..." );
00373
00374 RV xrv;
00375 RV Crv;
00376 int td=ml._rvc().mintd();
00377 for (int t=-1;t>=td;t--){
00378 xrv.add(yrv_uni);
00379 Crv.add(urv_uni);
00380 }
00381
00382 int dimx = xrv._dsize();
00383
00384 theta_in_A = ml._rv().dataind(xrv);
00385 theta_in_C = ml._rvc().dataind(xrv);
00386
00387
00388 vec A1row = zeros(xrv._dsize());
00389 vec C1row = zeros(xrv._dsize());
00390 vec theta = ml._A().get_row(0);
00391 set_subvector( A1row, theta_in_A, theta);
00392 set_subvector( C1row, theta_in_C, theta);
00393
00394 StateSpace<fsqmat> stsp=new StateSpace<fsqmat>();
00395 mat A=zeros(dimx,dimx);
00396 A.set_row(0,A1row);
00397 for (int j=1; j<dimx; j++){A(j,j-1)=1.0;}
00398 mat B=zeros(dimx,1);
00399 B(0) = 1.0;
00400 mat C=zeros(1,dimx);
00401 C.set_row(0,C1row);
00402 stsp.set_parameters(A,B,C,zeros(1,1), zeros(dimx,dimx), ml._R());
00403 return stsp;
00404 }
00405
00407
00408 template<class sq_T>
00409 void StateSpace<sq_T>::set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0)
00410 {
00411 dimx = A0.rows();
00412 dimu = B0.cols();
00413 dimy = C0.rows();
00414
00415 A = A0;
00416 B = B0;
00417 C = C0;
00418 D = D0;
00419 R = R0;
00420 Q = Q0;
00421 validate();
00422 }
00423
00424 template<class sq_T>
00425 void StateSpace<sq_T>::validate(){
00426 bdm_assert_debug (A.cols() == dimx, "KalmanFull: A is not square");
00427 bdm_assert_debug (B.rows() == dimx, "KalmanFull: B is not compatible");
00428 bdm_assert_debug (C.cols() == dimx, "KalmanFull: C is not square");
00429 bdm_assert_debug ( (D.rows() == dimy) || (D.cols() == dimu), "KalmanFull: D is not compatible");
00430 bdm_assert_debug ( (Q.cols() == dimx) || (Q.rows() == dimx), "KalmanFull: Q is not compatible");
00431 bdm_assert_debug ( (R.cols() == dimy) || (R.rows() == dimy), "KalmanFull: R is not compatible");
00432 }
00433
00434 }
00435 #endif // KF_H
00436
00437