00001
00013 #ifndef KF_H
00014 #define KF_H
00015
00016
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../math/chmat.h"
00020 #include "../base/user_info.h"
00021
00022 namespace bdm
00023 {
00024
00032 template<class sq_T>
00033 class StateSpace
00034 {
00035 protected:
00037 int dimx;
00039 int dimy;
00041 int dimu;
00043 mat A;
00045 mat B;
00047 mat C;
00049 mat D;
00051 sq_T Q;
00053 sq_T R;
00054 public:
00055 StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
00057 StateSpace(const StateSpace<sq_T> &S0) : dimx (S0.dimx), dimy (S0.dimy), dimu (S0.dimu), A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
00059 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0);
00061 void validate();
00063 void from_setting (const Setting &set) {
00064 UI::get (A, set, "A", UI::compulsory);
00065 UI::get (B, set, "B", UI::compulsory);
00066 UI::get (C, set, "C", UI::compulsory);
00067 UI::get (D, set, "D", UI::compulsory);
00068 mat Qtm, Rtm;
00069 if(!UI::get(Qtm, set, "Q", UI::optional)){
00070 vec dq;
00071 UI::get(dq, set, "dQ", UI::compulsory);
00072 Qtm=diag(dq);
00073 }
00074 if(!UI::get(Rtm, set, "R", UI::optional)){
00075 vec dr;
00076 UI::get(dr, set, "dQ", UI::compulsory);
00077 Rtm=diag(dr);
00078 }
00079 R=Rtm;
00080 Q=Qtm;
00081
00082 validate();
00083 }
00085 int _dimx(){return dimx;}
00087 int _dimy(){return dimy;}
00089 int _dimu(){return dimu;}
00091 const mat& _A() const {return A;}
00093 const mat& _B()const {return B;}
00095 const mat& _C()const {return C;}
00097 const mat& _D()const {return D;}
00099 const sq_T& _Q()const {return Q;}
00101 const sq_T& _R()const {return R;}
00102 };
00103
00105 template<class sq_T>
00106 class Kalman: public BM, public StateSpace<sq_T>
00107 {
00108 protected:
00110 RV yrv;
00112 RV urv;
00114 mat _K;
00116 shared_ptr<enorm<sq_T> > est;
00118 enorm<sq_T> fy;
00119 public:
00120 Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(), est(new enorm<sq_T>){}
00122 Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv),urv(K0.urv), _K(K0._K), est(new enorm<sq_T>(*K0.est)), fy(K0.fy){}
00124 void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
00126 void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
00128 const enorm<sq_T>& posterior() const {return *est.get();}
00130 shared_ptr<epdf> shared_posterior() {return est;}
00132 void from_setting (const Setting &set) {
00133 StateSpace<sq_T>::from_setting(set);
00134
00135 mat P0; vec mu0;
00136 UI::get(mu0, set, "mu0", UI::optional);
00137 UI::get(P0, set, "P0", UI::optional);
00138 set_statistics(mu0,P0);
00139
00140 UI::get (yrv, set, "yrv", UI::optional);
00141 UI::get (urv, set, "urv", UI::optional);
00142 set_drv(concat(yrv,urv));
00143
00144 validate();
00145 }
00147 void validate() {
00148 StateSpace<sq_T>::validate();
00149 bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
00150 }
00151 };
00156 class KalmanFull : public Kalman<fsqmat>
00157 {
00158 public:
00160 KalmanFull() :Kalman<fsqmat>(){};
00162 void bayes (const vec &dt);
00163 BM* _copy_() const {
00164 KalmanFull* K = new KalmanFull;
00165 K->set_parameters (A, B, C, D, Q, R);
00166 K->set_statistics (est->_mu(), est->_R());
00167 return K;
00168 }
00169 };
00170 UIREGISTER(KalmanFull);
00171
00172
00180 class KalmanCh : public Kalman<chmat>
00181 {
00182 protected:
00185 mat preA;
00187 mat postA;
00189 public:
00191 BM* _copy_() const {
00192 KalmanCh* K = new KalmanCh;
00193 K->set_parameters (A, B, C, D, Q, R);
00194 K->set_statistics (est->_mu(), est->_R());
00195 return K;
00196 }
00198 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
00200 void initialize();
00201
00215 void bayes (const vec &dt);
00216
00217 void from_setting(const Setting &set){
00218 Kalman<chmat>::from_setting(set);
00219 initialize();
00220 }
00221 };
00222 UIREGISTER(KalmanCh);
00223
00229 class EKFfull : public KalmanFull
00230 {
00231 protected:
00233 shared_ptr<diffbifn> pfxu;
00234
00236 shared_ptr<diffbifn> phxu;
00237
00238 public:
00240 EKFfull ();
00241
00243 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
00244
00246 void bayes (const vec &dt);
00248 void set_statistics (const vec &mu0, const mat &P0) {
00249 est->set_parameters (mu0, P0);
00250 };
00252 const mat _R() {
00253 return est->_R().to_mat();
00254 }
00255 void from_setting (const Setting &set) {
00256 shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
00257 shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
00258
00259
00260 int dim = IM->dimension();
00261 vec mu0;
00262 if ( !UI::get ( mu0, set, "mu0" ) )
00263 mu0 = zeros ( dim );
00264
00265 mat P0;
00266 vec dP0;
00267 if ( UI::get ( dP0, set, "dP0" ) )
00268 P0 = diag ( dP0 );
00269 else if ( !UI::get ( P0, set, "P0" ) )
00270 P0 = eye ( dim );
00271
00272 set_statistics ( mu0, P0 );
00273
00274
00275 vec dQ, dR;
00276 UI::get ( dQ, set, "dQ", UI::compulsory );
00277 UI::get ( dR, set, "dR", UI::compulsory );
00278 set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
00279
00280
00281 shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
00282 set_drv ( *drv );
00283 shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
00284 set_rv ( *rv );
00285
00286 string options;
00287 if ( UI::get ( options, set, "options" ) )
00288 set_options ( options );
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311 validate();
00312 }
00313 void validate() {
00314
00315 }
00316 };
00317 UIREGISTER(EKFfull);
00318
00319
00326 class EKFCh : public KalmanCh
00327 {
00328 protected:
00330 shared_ptr<diffbifn> pfxu;
00331
00333 shared_ptr<diffbifn> phxu;
00334 public:
00336 BM* _copy_() const {
00337 EKFCh* E = new EKFCh;
00338 E->set_parameters (pfxu, phxu, Q, R);
00339 E->set_statistics (est->_mu(), est->_R());
00340 return E;
00341 }
00343 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
00344
00346 void bayes (const vec &dt);
00347
00348 void from_setting (const Setting &set);
00349
00350
00351
00352 };
00353
00354 UIREGISTER (EKFCh);
00355 SHAREDPTR (EKFCh);
00356
00357
00359
00367 class MultiModel: public BM
00368 {
00369 protected:
00371 Array<EKFCh*> Models;
00373 vec w;
00375 vec _lls;
00377 int policy;
00379 enorm<chmat> est;
00380 public:
00382 void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
00383 Models = A;
00384 w.set_length (A.length());
00385 _lls.set_length (A.length());
00386 policy = pol0;
00387
00388 est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
00389 est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
00390 }
00391 void bayes (const vec &dt) {
00392 int n = Models.length();
00393 int i;
00394 for (i = 0; i < n; i++) {
00395 Models (i)->bayes (dt);
00396 _lls (i) = Models (i)->_ll();
00397 }
00398 double mlls = max (_lls);
00399 w = exp (_lls - mlls);
00400 w /= sum (w);
00401
00402 switch (policy) {
00403 case 1: {
00404 int mi = max_index (w);
00405 const enorm<chmat> &st = Models (mi)->posterior() ;
00406 est.set_parameters (st.mean(), st._R());
00407 }
00408 break;
00409 default:
00410 bdm_error ("unknown policy");
00411 }
00412
00413 for (i = 0; i < n; i++) {
00414 Models (i)->set_statistics (est.mean(), est._R());
00415 }
00416 }
00418 const enorm<chmat>& posterior() const {
00419 return est;
00420 }
00421
00422 void from_setting (const Setting &set);
00423
00424
00425
00426 };
00427
00428 UIREGISTER (MultiModel);
00429 SHAREDPTR (MultiModel);
00430
00432
00440
00441 class StateCanonical: public StateSpace<fsqmat>{
00442 protected:
00444 datalink_part th2A;
00446 datalink_part th2C;
00448 datalink_part th2D;
00450 vec A1row;
00452 vec C1row;
00454 vec D1row;
00455
00456 public:
00458 void connect_mlnorm(const mlnorm<fsqmat> &ml){
00459
00460 const RV &yrv = ml._rv();
00461
00462 RV rgr0 = ml._rvc().remove_time();
00463 RV urv = rgr0.subt(yrv);
00464
00465
00466 bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
00467
00468
00469 RV xrv;
00470 RV Crv;
00471 int td=ml._rvc().mint();
00472
00473 for (int t=-1;t>=td;t--){
00474 xrv.add(yrv.copy_t(t));
00475 Crv.add(urv.copy_t(t));
00476 }
00477
00478 this->dimx = xrv._dsize();
00479 this->dimy = yrv._dsize();
00480 this->dimu = urv._dsize();
00481
00482
00483 th2A.set_connection(xrv, ml._rvc());
00484 th2C.set_connection(Crv, ml._rvc());
00485 th2D.set_connection(urv, ml._rvc());
00486
00487
00488 this->A=zeros(dimx,dimx);
00489 for (int j=1; j<dimx; j++){A(j,j-1)=1.0;}
00490 this->B=zeros(dimx,1);
00491 this->B(0) = 1.0;
00492 this->C=zeros(1,dimx);
00493 this->D=zeros(1,urv._dsize());
00494 this->Q = zeros(dimx,dimx);
00495
00496
00497
00498 this->A1row = zeros(xrv._dsize());
00499 this->C1row = zeros(xrv._dsize());
00500 this->D1row = zeros(urv._dsize());
00501
00502 update_from(ml);
00503 validate();
00504 };
00506 void update_from(const mlnorm<fsqmat> &ml){
00507
00508 vec theta = ml._A().get_row(0);
00509
00510 th2A.filldown(theta,A1row);
00511 th2C.filldown(theta,C1row);
00512 th2D.filldown(theta,D1row);
00513
00514 R = ml._R();
00515
00516 A.set_row(0,A1row);
00517 C.set_row(0,C1row+D1row*A1row);
00518 D.set_row(0,D1row);
00519
00520 }
00521 };
00522
00524
00525 template<class sq_T>
00526 void StateSpace<sq_T>::set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0)
00527 {
00528
00529 A = A0;
00530 B = B0;
00531 C = C0;
00532 D = D0;
00533 R = R0;
00534 Q = Q0;
00535 validate();
00536 }
00537
00538 template<class sq_T>
00539 void StateSpace<sq_T>::validate(){
00540 dimx = A.rows();
00541 dimu = B.cols();
00542 dimy = C.rows();
00543 bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
00544 bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
00545 bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
00546 bdm_assert ( (D.rows() == dimy) && (D.cols() == dimu), "KalmanFull: D is not compatible");
00547 bdm_assert ( (Q.cols() == dimx) && (Q.rows() == dimx), "KalmanFull: Q is not compatible");
00548 bdm_assert ( (R.cols() == dimy) && (R.rows() == dimy), "KalmanFull: R is not compatible");
00549 }
00550
00551 }
00552 #endif // KF_H
00553
00554