00001 
00013 #ifndef KF_H
00014 #define KF_H
00015 
00016 
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../math/chmat.h"
00020 #include "../base/user_info.h"
00021 
00022 namespace bdm
00023 {
00024 
00032 template<class sq_T>
00033 class StateSpace
00034 {
00035         protected:
00037                 int dimx;
00039                 int dimy;
00041                 int dimu;
00043                 mat A;
00045                 mat B;
00047                 mat C;
00049                 mat D;
00051                 sq_T Q;
00053                 sq_T R;
00054         public:
00055                 StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
00056                 void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
00057                 void validate();
00059                 void from_setting (const Setting &set) {
00060                         UI::get (A, set, "A", UI::compulsory);
00061                         UI::get (B, set, "B", UI::compulsory);
00062                         UI::get (C, set, "C", UI::compulsory);
00063                         UI::get (D, set, "D", UI::compulsory);
00064                         mat Qtm, Rtm;
00065                         if(!UI::get(Qtm, set, "Q", UI::optional)){
00066                                 vec dq;
00067                                 UI::get(dq, set, "dQ", UI::compulsory);
00068                                 Qtm=diag(dq);
00069                         }
00070                         if(!UI::get(Rtm, set, "R", UI::optional)){
00071                                 vec dr;
00072                                 UI::get(dr, set, "dQ", UI::compulsory);
00073                                 Rtm=diag(dr);
00074                         }
00075                         R=Rtm; 
00076                         Q=Qtm; 
00077                         
00078                         validate();
00079                 }               
00081                 int _dimx(){return dimx;}
00083                 int _dimy(){return dimy;}
00085                 int _dimu(){return dimu;}
00087                 const mat& _A() const {return A;}
00089                 const mat& _B()const {return B;}
00091                 const mat& _C()const {return C;}
00093                 const mat& _D()const {return D;}
00095                 const sq_T& _Q()const {return Q;}
00097                 const sq_T& _R()const {return R;}
00098 };
00099 
00101 template<class sq_T>
00102 class Kalman: public BM, public StateSpace<sq_T>
00103 {
00104         protected:
00106                 RV yrv;
00108                 RV urv;
00110                 mat  _K;
00112                 shared_ptr<enorm<sq_T> > est;
00114                 enorm<sq_T>  fy;
00115         public:
00116                 Kalman() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
00117                 void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
00118                 void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
00120                 const enorm<sq_T>& posterior() const {return *est.get();}
00122                 shared_ptr<epdf> shared_posterior() {return est;}
00124                 void from_setting (const Setting &set) {
00125                         StateSpace<sq_T>::from_setting(set);
00126                                                 
00127                         mat P0; vec mu0;
00128                         UI::get(mu0, set, "mu0", UI::optional);
00129                         UI::get(P0, set,  "P0", UI::optional);
00130                         set_statistics(mu0,P0);
00131                         
00132                         UI::get (yrv, set, "yrv", UI::optional);
00133                         UI::get (urv, set, "urv", UI::optional);
00134                         set_drv(concat(yrv,urv));
00135                         
00136                         validate();
00137                 }
00138                 void validate() {
00139                         StateSpace<sq_T>::validate();
00140                         bdm_assert_debug(est->dimension(), "Statistics and model parameters mismatch");
00141                 }
00142 };
00147 class KalmanFull : public Kalman<fsqmat>
00148 {
00149         public:
00151                 KalmanFull() :Kalman<fsqmat>(){};
00153                 void bayes (const vec &dt);
00154 };
00155 UIREGISTER(KalmanFull);
00156 
00157 
00165 class KalmanCh : public Kalman<chmat>
00166 {
00167         protected:
00170                 mat preA;
00172                 mat postA;
00174         public:
00176                 BM* _copy_() const {
00177                         KalmanCh* K = new KalmanCh;
00178                         K->set_parameters (A, B, C, D, Q, R);
00179                         K->set_statistics (est->_mu(), est->_R());
00180                         return K;
00181                 }
00183                 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
00185                 void initialize();
00186 
00200                 void bayes (const vec &dt);
00201                 
00202                 void from_setting(const Setting &set){
00203                         Kalman<chmat>::from_setting(set);
00204                         initialize();
00205                 }
00206 };
00207 UIREGISTER(KalmanCh);
00208 
00214 class EKFfull : public KalmanFull
00215 {
00216         protected:
00218                 shared_ptr<diffbifn> pfxu;
00219 
00221                 shared_ptr<diffbifn> phxu;
00222 
00223         public:
00225                 EKFfull ();
00226 
00228                 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
00229 
00231                 void bayes (const vec &dt);
00233                 void set_statistics (const vec &mu0, const mat &P0) {
00234                         est->set_parameters (mu0, P0);
00235                 };
00236                 const mat _R() {
00237                         return est->_R().to_mat();
00238                 }
00239 };
00240 UIREGISTER(EKFfull);
00241 
00242 
00249 class EKFCh : public KalmanCh
00250 {
00251         protected:
00253                 shared_ptr<diffbifn> pfxu;
00254 
00256                 shared_ptr<diffbifn> phxu;
00257         public:
00259                 BM* _copy_() const {
00260                         EKFCh* E = new EKFCh;
00261                         E->set_parameters (pfxu, phxu, Q, R);
00262                         E->set_statistics (est->_mu(), est->_R());
00263                         return E;
00264                 }
00266                 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
00267 
00269                 void bayes (const vec &dt);
00270 
00271                 void from_setting (const Setting &set);
00272 
00273                 
00274 
00275 };
00276 
00277 UIREGISTER (EKFCh);
00278 SHAREDPTR (EKFCh);
00279 
00280 
00282 
00290 class MultiModel: public BM
00291 {
00292         protected:
00294                 Array<EKFCh*> Models;
00296                 vec w;
00298                 vec _lls;
00300                 int policy;
00302                 enorm<chmat> est;
00303         public:
00304                 void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
00305                         Models = A;
00306                         w.set_length (A.length());
00307                         _lls.set_length (A.length());
00308                         policy = pol0;
00309 
00310                         est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
00311                         est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
00312                 }
00313                 void bayes (const vec &dt) {
00314                         int n = Models.length();
00315                         int i;
00316                         for (i = 0; i < n; i++) {
00317                                 Models (i)->bayes (dt);
00318                                 _lls (i) = Models (i)->_ll();
00319                         }
00320                         double mlls = max (_lls);
00321                         w = exp (_lls - mlls);
00322                         w /= sum (w);   
00323                         
00324                         switch (policy) {
00325                                 case 1: {
00326                                         int mi = max_index (w);
00327                                         const enorm<chmat> &st = Models (mi)->posterior() ;
00328                                         est.set_parameters (st.mean(), st._R());
00329                                 }
00330                                 break;
00331                                 default:
00332                                         bdm_error ("unknown policy");
00333                         }
00334                         
00335                         for (i = 0; i < n; i++) {
00336                                 Models (i)->set_statistics (est.mean(), est._R());
00337                         }
00338                 }
00340                 const enorm<chmat>& posterior() const {
00341                         return est;
00342                 }
00343 
00344                 void from_setting (const Setting &set);
00345 
00346                 
00347 
00348 };
00349 
00350 UIREGISTER (MultiModel);
00351 SHAREDPTR (MultiModel);
00352 
00354 
00362 
00363 class StateCanonical: public StateSpace<fsqmat>{
00364         protected:
00366                 datalink_part th2A;
00368                 datalink_part th2C;
00370                 datalink_part th2D;
00372                 vec A1row;
00374                 vec C1row;
00376                 vec D1row;
00377                 
00378         public:
00380                 void connect_mlnorm(const mlnorm<fsqmat > &ml){
00381         
00382                         const RV &yrv = ml._rv();
00383                         
00384                         RV rgr0 = ml._rvc().remove_time();
00385                         RV urv = rgr0.subt(yrv); 
00386                         
00387                         
00388                         bdm_assert_debug(yrv._dsize()==1, "Only for SISO so far..." );
00389 
00390                         
00391                         RV xrv; 
00392                         RV Crv; 
00393                         int td=ml._rvc().mint();
00394                         
00395                         for (int t=-1;t>=td;t--){
00396                                 xrv.add(yrv.copy_t(t));
00397                                 Crv.add(urv.copy_t(t));
00398                         }
00399                         
00400                         this->dimx = xrv._dsize();
00401                         this->dimy = yrv._dsize();
00402                         this->dimu = urv._dsize();
00403                         
00404                         
00405                         th2A.set_connection(xrv, ml._rvc());
00406                         th2C.set_connection(Crv, ml._rvc());
00407                         th2D.set_connection(urv, ml._rvc());
00408 
00409                         
00410                         this->A=zeros(dimx,dimx);
00411                         for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} 
00412                                 this->B=zeros(dimx,1);
00413                                 this->B(0) = 1.0;
00414                                 this->C=zeros(1,dimx);
00415                                 this->D=zeros(1,urv._dsize());
00416                                 this->Q = zeros(dimx,dimx);
00417                         
00418                         
00419                         
00420                         this->A1row = zeros(xrv._dsize());
00421                         this->C1row = zeros(xrv._dsize());
00422                         this->D1row = zeros(urv._dsize());
00423                         
00424                         update_from(ml);
00425                         validate();
00426                 };
00428                 void update_from(const mlnorm<fsqmat> &ml){
00429                         
00430                         vec theta = ml._A().get_row(0); 
00431                         
00432                         th2A.filldown(theta,A1row);
00433                         th2C.filldown(theta,C1row);
00434                         th2D.filldown(theta,D1row);
00435 
00436                         R = ml._R();
00437 
00438                         A.set_row(0,A1row);
00439                         C.set_row(0,C1row+D1row*A1row);
00440                         D.set_row(0,D1row);
00441                         
00442                 }
00443 };
00444 
00446 
00447 template<class sq_T>
00448 void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
00449 {
00450         dimx = A0.rows();
00451         dimu = B0.cols();
00452         dimy = C0.rows();
00453 
00454         A = A0;
00455         B = B0;
00456         C = C0;
00457         D = D0;
00458         R = R0;
00459         Q = Q0;
00460         validate();
00461 }
00462 
00463 template<class sq_T>
00464 void StateSpace<sq_T>::validate(){
00465         bdm_assert_debug (A.cols() == dimx, "KalmanFull: A is not square");
00466         bdm_assert_debug (B.rows() == dimx, "KalmanFull: B is not compatible");
00467         bdm_assert_debug (C.cols() == dimx, "KalmanFull: C is not square");
00468         bdm_assert_debug ( (D.rows() == dimy) || (D.cols() == dimu), "KalmanFull: D is not compatible");
00469         bdm_assert_debug ( (Q.cols() == dimx) || (Q.rows() == dimx), "KalmanFull: Q is not compatible");
00470         bdm_assert_debug ( (R.cols() == dimy) || (R.rows() == dimy), "KalmanFull: R is not compatible");
00471 }
00472 
00473 }
00474 #endif // KF_H
00475 
00476