00001
00013 #ifndef KF_H
00014 #define KF_H
00015
00016
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../math/chmat.h"
00020 #include "../base/user_info.h"
00021
00022 namespace bdm
00023 {
00024
00032 template<class sq_T>
00033 class StateSpace
00034 {
00035 protected:
00037 int dimx;
00039 int dimy;
00041 int dimu;
00043 mat A;
00045 mat B;
00047 mat C;
00049 mat D;
00051 sq_T Q;
00053 sq_T R;
00054 public:
00055 StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
00056 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0);
00057 void validate();
00059 void from_setting (const Setting &set) {
00060 UI::get (A, set, "A", UI::compulsory);
00061 UI::get (B, set, "B", UI::compulsory);
00062 UI::get (C, set, "C", UI::compulsory);
00063 UI::get (D, set, "D", UI::compulsory);
00064 mat Qtm, Rtm;
00065 if(!UI::get(Qtm, set, "Q", UI::optional)){
00066 vec dq;
00067 UI::get(dq, set, "dQ", UI::compulsory);
00068 Qtm=diag(dq);
00069 }
00070 if(!UI::get(Rtm, set, "R", UI::optional)){
00071 vec dr;
00072 UI::get(dr, set, "dQ", UI::compulsory);
00073 Rtm=diag(dr);
00074 }
00075 R=Rtm;
00076 Q=Qtm;
00077
00078 validate();
00079 }
00081 int _dimx(){return dimx;}
00083 int _dimy(){return dimy;}
00085 int _dimu(){return dimu;}
00087 const mat& _A() const {return A;}
00089 const mat& _B()const {return B;}
00091 const mat& _C()const {return C;}
00093 const mat& _D()const {return D;}
00095 const sq_T& _Q()const {return Q;}
00097 const sq_T& _R()const {return R;}
00098 };
00099
00101 template<class sq_T>
00102 class Kalman: public BM, public StateSpace<sq_T>
00103 {
00104 protected:
00106 RV yrv;
00108 RV urv;
00110 mat _K;
00112 shared_ptr<enorm<sq_T> > est;
00114 enorm<sq_T> fy;
00115 public:
00116 Kalman() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(), est(new enorm<sq_T>){}
00117 void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
00118 void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
00120 const enorm<sq_T>& posterior() const {return *est.get();}
00122 shared_ptr<epdf> shared_posterior() {return est;}
00124 void from_setting (const Setting &set) {
00125 StateSpace<sq_T>::from_setting(set);
00126
00127 mat P0; vec mu0;
00128 UI::get(mu0, set, "mu0", UI::optional);
00129 UI::get(P0, set, "P0", UI::optional);
00130 set_statistics(mu0,P0);
00131
00132 UI::get (yrv, set, "yrv", UI::optional);
00133 UI::get (urv, set, "urv", UI::optional);
00134 set_drv(concat(yrv,urv));
00135
00136 validate();
00137 }
00138 void validate() {
00139 StateSpace<sq_T>::validate();
00140 bdm_assert_debug(est->dimension(), "Statistics and model parameters mismatch");
00141 }
00142 };
00147 class KalmanFull : public Kalman<fsqmat>
00148 {
00149 public:
00151 KalmanFull() :Kalman<fsqmat>(){};
00153 void bayes (const vec &dt);
00154 };
00155 UIREGISTER(KalmanFull);
00156
00157
00165 class KalmanCh : public Kalman<chmat>
00166 {
00167 protected:
00170 mat preA;
00172 mat postA;
00174 public:
00176 BM* _copy_() const {
00177 KalmanCh* K = new KalmanCh;
00178 K->set_parameters (A, B, C, D, Q, R);
00179 K->set_statistics (est->_mu(), est->_R());
00180 return K;
00181 }
00183 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
00185 void initialize();
00186
00200 void bayes (const vec &dt);
00201
00202 void from_setting(const Setting &set){
00203 Kalman<chmat>::from_setting(set);
00204 initialize();
00205 }
00206 };
00207 UIREGISTER(KalmanCh);
00208
00214 class EKFfull : public KalmanFull
00215 {
00216 protected:
00218 shared_ptr<diffbifn> pfxu;
00219
00221 shared_ptr<diffbifn> phxu;
00222
00223 public:
00225 EKFfull ();
00226
00228 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
00229
00231 void bayes (const vec &dt);
00233 void set_statistics (const vec &mu0, const mat &P0) {
00234 est->set_parameters (mu0, P0);
00235 };
00236 const mat _R() {
00237 return est->_R().to_mat();
00238 }
00239 };
00240 UIREGISTER(EKFfull);
00241
00242
00249 class EKFCh : public KalmanCh
00250 {
00251 protected:
00253 shared_ptr<diffbifn> pfxu;
00254
00256 shared_ptr<diffbifn> phxu;
00257 public:
00259 BM* _copy_() const {
00260 EKFCh* E = new EKFCh;
00261 E->set_parameters (pfxu, phxu, Q, R);
00262 E->set_statistics (est->_mu(), est->_R());
00263 return E;
00264 }
00266 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
00267
00269 void bayes (const vec &dt);
00270
00271 void from_setting (const Setting &set);
00272
00273
00274
00275 };
00276
00277 UIREGISTER (EKFCh);
00278 SHAREDPTR (EKFCh);
00279
00280
00282
00290 class MultiModel: public BM
00291 {
00292 protected:
00294 Array<EKFCh*> Models;
00296 vec w;
00298 vec _lls;
00300 int policy;
00302 enorm<chmat> est;
00303 public:
00304 void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
00305 Models = A;
00306 w.set_length (A.length());
00307 _lls.set_length (A.length());
00308 policy = pol0;
00309
00310 est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
00311 est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
00312 }
00313 void bayes (const vec &dt) {
00314 int n = Models.length();
00315 int i;
00316 for (i = 0; i < n; i++) {
00317 Models (i)->bayes (dt);
00318 _lls (i) = Models (i)->_ll();
00319 }
00320 double mlls = max (_lls);
00321 w = exp (_lls - mlls);
00322 w /= sum (w);
00323
00324 switch (policy) {
00325 case 1: {
00326 int mi = max_index (w);
00327 const enorm<chmat> &st = Models (mi)->posterior() ;
00328 est.set_parameters (st.mean(), st._R());
00329 }
00330 break;
00331 default:
00332 bdm_error ("unknown policy");
00333 }
00334
00335 for (i = 0; i < n; i++) {
00336 Models (i)->set_statistics (est.mean(), est._R());
00337 }
00338 }
00340 const enorm<chmat>& posterior() const {
00341 return est;
00342 }
00343
00344 void from_setting (const Setting &set);
00345
00346
00347
00348 };
00349
00350 UIREGISTER (MultiModel);
00351 SHAREDPTR (MultiModel);
00352
00354
00362
00363 class StateCanonical: public StateSpace<fsqmat>{
00364 protected:
00366 datalink_part th2A;
00368 datalink_part th2C;
00370 datalink_part th2D;
00372 vec A1row;
00374 vec C1row;
00376 vec D1row;
00377
00378 public:
00380 void connect_mlnorm(const mlnorm<fsqmat > &ml){
00381
00382 const RV &yrv = ml._rv();
00383
00384 RV rgr0 = ml._rvc().remove_time();
00385 RV urv = rgr0.subt(yrv);
00386
00387
00388 bdm_assert_debug(yrv._dsize()==1, "Only for SISO so far..." );
00389
00390
00391 RV xrv;
00392 RV Crv;
00393 int td=ml._rvc().mint();
00394
00395 for (int t=-1;t>=td;t--){
00396 xrv.add(yrv.copy_t(t));
00397 Crv.add(urv.copy_t(t));
00398 }
00399
00400 this->dimx = xrv._dsize();
00401 this->dimy = yrv._dsize();
00402 this->dimu = urv._dsize();
00403
00404
00405 th2A.set_connection(xrv, ml._rvc());
00406 th2C.set_connection(Crv, ml._rvc());
00407 th2D.set_connection(urv, ml._rvc());
00408
00409
00410 this->A=zeros(dimx,dimx);
00411 for (int j=1; j<dimx; j++){A(j,j-1)=1.0;}
00412 this->B=zeros(dimx,1);
00413 this->B(0) = 1.0;
00414 this->C=zeros(1,dimx);
00415 this->D=zeros(1,urv._dsize());
00416 this->Q = zeros(dimx,dimx);
00417
00418
00419
00420 this->A1row = zeros(xrv._dsize());
00421 this->C1row = zeros(xrv._dsize());
00422 this->D1row = zeros(urv._dsize());
00423
00424 update_from(ml);
00425 validate();
00426 };
00428 void update_from(const mlnorm<fsqmat> &ml){
00429
00430 vec theta = ml._A().get_row(0);
00431
00432 th2A.filldown(theta,A1row);
00433 th2C.filldown(theta,C1row);
00434 th2D.filldown(theta,D1row);
00435
00436 R = ml._R();
00437
00438 A.set_row(0,A1row);
00439 C.set_row(0,C1row+D1row*A1row);
00440 D.set_row(0,D1row);
00441
00442 }
00443 };
00444
00446
00447 template<class sq_T>
00448 void StateSpace<sq_T>::set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0)
00449 {
00450 dimx = A0.rows();
00451 dimu = B0.cols();
00452 dimy = C0.rows();
00453
00454 A = A0;
00455 B = B0;
00456 C = C0;
00457 D = D0;
00458 R = R0;
00459 Q = Q0;
00460 validate();
00461 }
00462
00463 template<class sq_T>
00464 void StateSpace<sq_T>::validate(){
00465 bdm_assert_debug (A.cols() == dimx, "KalmanFull: A is not square");
00466 bdm_assert_debug (B.rows() == dimx, "KalmanFull: B is not compatible");
00467 bdm_assert_debug (C.cols() == dimx, "KalmanFull: C is not square");
00468 bdm_assert_debug ( (D.rows() == dimy) || (D.cols() == dimu), "KalmanFull: D is not compatible");
00469 bdm_assert_debug ( (Q.cols() == dimx) || (Q.rows() == dimx), "KalmanFull: Q is not compatible");
00470 bdm_assert_debug ( (R.cols() == dimy) || (R.rows() == dimy), "KalmanFull: R is not compatible");
00471 }
00472
00473 }
00474 #endif // KF_H
00475
00476