00001
00013 #ifndef KF_H
00014 #define KF_H
00015
00016
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../math/chmat.h"
00020 #include "../base/user_info.h"
00021
00022 namespace bdm
00023 {
00024
00032 template<class sq_T>
00033 class StateSpace
00034 {
00035 protected:
00037 int dimx;
00039 int dimy;
00041 int dimu;
00043 mat A;
00045 mat B;
00047 mat C;
00049 mat D;
00051 sq_T Q;
00053 sq_T R;
00054 public:
00055 StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
00056 StateSpace(const StateSpace<sq_T> &S0) : dimx (S0.dimx), dimy (S0.dimy), dimu (S0.dimu), A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
00057 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0);
00058 void validate();
00060 void from_setting (const Setting &set) {
00061 UI::get (A, set, "A", UI::compulsory);
00062 UI::get (B, set, "B", UI::compulsory);
00063 UI::get (C, set, "C", UI::compulsory);
00064 UI::get (D, set, "D", UI::compulsory);
00065 mat Qtm, Rtm;
00066 if(!UI::get(Qtm, set, "Q", UI::optional)){
00067 vec dq;
00068 UI::get(dq, set, "dQ", UI::compulsory);
00069 Qtm=diag(dq);
00070 }
00071 if(!UI::get(Rtm, set, "R", UI::optional)){
00072 vec dr;
00073 UI::get(dr, set, "dQ", UI::compulsory);
00074 Rtm=diag(dr);
00075 }
00076 R=Rtm;
00077 Q=Qtm;
00078
00079 validate();
00080 }
00082 int _dimx(){return dimx;}
00084 int _dimy(){return dimy;}
00086 int _dimu(){return dimu;}
00088 const mat& _A() const {return A;}
00090 const mat& _B()const {return B;}
00092 const mat& _C()const {return C;}
00094 const mat& _D()const {return D;}
00096 const sq_T& _Q()const {return Q;}
00098 const sq_T& _R()const {return R;}
00099 };
00100
00102 template<class sq_T>
00103 class Kalman: public BM, public StateSpace<sq_T>
00104 {
00105 protected:
00107 RV yrv;
00109 RV urv;
00111 mat _K;
00113 shared_ptr<enorm<sq_T> > est;
00115 enorm<sq_T> fy;
00116 public:
00117 Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(), est(new enorm<sq_T>){}
00118 Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv),urv(K0.urv), _K(K0._K), est(new enorm<sq_T>(*K0.est)), fy(K0.fy){}
00119 void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
00120 void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
00122 const enorm<sq_T>& posterior() const {return *est.get();}
00124 shared_ptr<epdf> shared_posterior() {return est;}
00126 void from_setting (const Setting &set) {
00127 StateSpace<sq_T>::from_setting(set);
00128
00129 mat P0; vec mu0;
00130 UI::get(mu0, set, "mu0", UI::optional);
00131 UI::get(P0, set, "P0", UI::optional);
00132 set_statistics(mu0,P0);
00133
00134 UI::get (yrv, set, "yrv", UI::optional);
00135 UI::get (urv, set, "urv", UI::optional);
00136 set_drv(concat(yrv,urv));
00137
00138 validate();
00139 }
00140 void validate() {
00141 StateSpace<sq_T>::validate();
00142 bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
00143 }
00144 };
00149 class KalmanFull : public Kalman<fsqmat>
00150 {
00151 public:
00153 KalmanFull() :Kalman<fsqmat>(){};
00155 void bayes (const vec &dt);
00156 BM* _copy_() const {
00157 KalmanFull* K = new KalmanFull;
00158 K->set_parameters (A, B, C, D, Q, R);
00159 K->set_statistics (est->_mu(), est->_R());
00160 return K;
00161 }
00162 };
00163 UIREGISTER(KalmanFull);
00164
00165
00173 class KalmanCh : public Kalman<chmat>
00174 {
00175 protected:
00178 mat preA;
00180 mat postA;
00182 public:
00184 BM* _copy_() const {
00185 KalmanCh* K = new KalmanCh;
00186 K->set_parameters (A, B, C, D, Q, R);
00187 K->set_statistics (est->_mu(), est->_R());
00188 return K;
00189 }
00191 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
00193 void initialize();
00194
00208 void bayes (const vec &dt);
00209
00210 void from_setting(const Setting &set){
00211 Kalman<chmat>::from_setting(set);
00212 initialize();
00213 }
00214 };
00215 UIREGISTER(KalmanCh);
00216
00222 class EKFfull : public KalmanFull
00223 {
00224 protected:
00226 shared_ptr<diffbifn> pfxu;
00227
00229 shared_ptr<diffbifn> phxu;
00230
00231 public:
00233 EKFfull ();
00234
00236 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
00237
00239 void bayes (const vec &dt);
00241 void set_statistics (const vec &mu0, const mat &P0) {
00242 est->set_parameters (mu0, P0);
00243 };
00244 const mat _R() {
00245 return est->_R().to_mat();
00246 }
00247 void from_setting (const Setting &set) {
00248 shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
00249 shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
00250
00251
00252 int dim = IM->dimension();
00253 vec mu0;
00254 if ( !UI::get ( mu0, set, "mu0" ) )
00255 mu0 = zeros ( dim );
00256
00257 mat P0;
00258 vec dP0;
00259 if ( UI::get ( dP0, set, "dP0" ) )
00260 P0 = diag ( dP0 );
00261 else if ( !UI::get ( P0, set, "P0" ) )
00262 P0 = eye ( dim );
00263
00264 set_statistics ( mu0, P0 );
00265
00266
00267 vec dQ, dR;
00268 UI::get ( dQ, set, "dQ", UI::compulsory );
00269 UI::get ( dR, set, "dR", UI::compulsory );
00270 set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
00271
00272
00273 shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
00274 set_drv ( *drv );
00275 shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
00276 set_rv ( *rv );
00277
00278 string options;
00279 if ( UI::get ( options, set, "options" ) )
00280 set_options ( options );
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303 validate();
00304 }
00305 void validate() {
00306
00307 }
00308 };
00309 UIREGISTER(EKFfull);
00310
00311
00318 class EKFCh : public KalmanCh
00319 {
00320 protected:
00322 shared_ptr<diffbifn> pfxu;
00323
00325 shared_ptr<diffbifn> phxu;
00326 public:
00328 BM* _copy_() const {
00329 EKFCh* E = new EKFCh;
00330 E->set_parameters (pfxu, phxu, Q, R);
00331 E->set_statistics (est->_mu(), est->_R());
00332 return E;
00333 }
00335 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
00336
00338 void bayes (const vec &dt);
00339
00340 void from_setting (const Setting &set);
00341
00342
00343
00344 };
00345
00346 UIREGISTER (EKFCh);
00347 SHAREDPTR (EKFCh);
00348
00349
00351
00359 class MultiModel: public BM
00360 {
00361 protected:
00363 Array<EKFCh*> Models;
00365 vec w;
00367 vec _lls;
00369 int policy;
00371 enorm<chmat> est;
00372 public:
00373 void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
00374 Models = A;
00375 w.set_length (A.length());
00376 _lls.set_length (A.length());
00377 policy = pol0;
00378
00379 est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
00380 est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
00381 }
00382 void bayes (const vec &dt) {
00383 int n = Models.length();
00384 int i;
00385 for (i = 0; i < n; i++) {
00386 Models (i)->bayes (dt);
00387 _lls (i) = Models (i)->_ll();
00388 }
00389 double mlls = max (_lls);
00390 w = exp (_lls - mlls);
00391 w /= sum (w);
00392
00393 switch (policy) {
00394 case 1: {
00395 int mi = max_index (w);
00396 const enorm<chmat> &st = Models (mi)->posterior() ;
00397 est.set_parameters (st.mean(), st._R());
00398 }
00399 break;
00400 default:
00401 bdm_error ("unknown policy");
00402 }
00403
00404 for (i = 0; i < n; i++) {
00405 Models (i)->set_statistics (est.mean(), est._R());
00406 }
00407 }
00409 const enorm<chmat>& posterior() const {
00410 return est;
00411 }
00412
00413 void from_setting (const Setting &set);
00414
00415
00416
00417 };
00418
00419 UIREGISTER (MultiModel);
00420 SHAREDPTR (MultiModel);
00421
00423
00431
00432 class StateCanonical: public StateSpace<fsqmat>{
00433 protected:
00435 datalink_part th2A;
00437 datalink_part th2C;
00439 datalink_part th2D;
00441 vec A1row;
00443 vec C1row;
00445 vec D1row;
00446
00447 public:
00449 void connect_mlnorm(const mlnorm<fsqmat> &ml){
00450
00451 const RV &yrv = ml._rv();
00452
00453 RV rgr0 = ml._rvc().remove_time();
00454 RV urv = rgr0.subt(yrv);
00455
00456
00457 bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
00458
00459
00460 RV xrv;
00461 RV Crv;
00462 int td=ml._rvc().mint();
00463
00464 for (int t=-1;t>=td;t--){
00465 xrv.add(yrv.copy_t(t));
00466 Crv.add(urv.copy_t(t));
00467 }
00468
00469 this->dimx = xrv._dsize();
00470 this->dimy = yrv._dsize();
00471 this->dimu = urv._dsize();
00472
00473
00474 th2A.set_connection(xrv, ml._rvc());
00475 th2C.set_connection(Crv, ml._rvc());
00476 th2D.set_connection(urv, ml._rvc());
00477
00478
00479 this->A=zeros(dimx,dimx);
00480 for (int j=1; j<dimx; j++){A(j,j-1)=1.0;}
00481 this->B=zeros(dimx,1);
00482 this->B(0) = 1.0;
00483 this->C=zeros(1,dimx);
00484 this->D=zeros(1,urv._dsize());
00485 this->Q = zeros(dimx,dimx);
00486
00487
00488
00489 this->A1row = zeros(xrv._dsize());
00490 this->C1row = zeros(xrv._dsize());
00491 this->D1row = zeros(urv._dsize());
00492
00493 update_from(ml);
00494 validate();
00495 };
00497 void update_from(const mlnorm<fsqmat> &ml){
00498
00499 vec theta = ml._A().get_row(0);
00500
00501 th2A.filldown(theta,A1row);
00502 th2C.filldown(theta,C1row);
00503 th2D.filldown(theta,D1row);
00504
00505 R = ml._R();
00506
00507 A.set_row(0,A1row);
00508 C.set_row(0,C1row+D1row*A1row);
00509 D.set_row(0,D1row);
00510
00511 }
00512 };
00513
00515
00516 template<class sq_T>
00517 void StateSpace<sq_T>::set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0)
00518 {
00519
00520 A = A0;
00521 B = B0;
00522 C = C0;
00523 D = D0;
00524 R = R0;
00525 Q = Q0;
00526 validate();
00527 }
00528
00529 template<class sq_T>
00530 void StateSpace<sq_T>::validate(){
00531 dimx = A.rows();
00532 dimu = B.cols();
00533 dimy = C.rows();
00534 bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
00535 bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
00536 bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
00537 bdm_assert ( (D.rows() == dimy) && (D.cols() == dimu), "KalmanFull: D is not compatible");
00538 bdm_assert ( (Q.cols() == dimx) && (Q.rows() == dimx), "KalmanFull: Q is not compatible");
00539 bdm_assert ( (R.cols() == dimy) && (R.rows() == dimy), "KalmanFull: R is not compatible");
00540 }
00541
00542 }
00543 #endif // KF_H
00544
00545