00001 
00013 #ifndef KF_H
00014 #define KF_H
00015 
00016 
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../math/chmat.h"
00020 #include "../base/user_info.h"
00021 
00022 namespace bdm
00023 {
00024 
00032 template<class sq_T>
00033 class StateSpace
00034 {
00035         protected:
00037                 int dimx;
00039                 int dimy;
00041                 int dimu;
00043                 mat A;
00045                 mat B;
00047                 mat C;
00049                 mat D;
00051                 sq_T Q;
00053                 sq_T R;
00054         public:
00055                 StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
00056                 StateSpace(const StateSpace<sq_T> &S0) : dimx (S0.dimx), dimy (S0.dimy), dimu (S0.dimu), A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
00057                 void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
00058                 void validate();
00060                 void from_setting (const Setting &set) {
00061                         UI::get (A, set, "A", UI::compulsory);
00062                         UI::get (B, set, "B", UI::compulsory);
00063                         UI::get (C, set, "C", UI::compulsory);
00064                         UI::get (D, set, "D", UI::compulsory);
00065                         mat Qtm, Rtm;
00066                         if(!UI::get(Qtm, set, "Q", UI::optional)){
00067                                 vec dq;
00068                                 UI::get(dq, set, "dQ", UI::compulsory);
00069                                 Qtm=diag(dq);
00070                         }
00071                         if(!UI::get(Rtm, set, "R", UI::optional)){
00072                                 vec dr;
00073                                 UI::get(dr, set, "dQ", UI::compulsory);
00074                                 Rtm=diag(dr);
00075                         }
00076                         R=Rtm; 
00077                         Q=Qtm; 
00078                         
00079                         validate();
00080                 }               
00082                 int _dimx(){return dimx;}
00084                 int _dimy(){return dimy;}
00086                 int _dimu(){return dimu;}
00088                 const mat& _A() const {return A;}
00090                 const mat& _B()const {return B;}
00092                 const mat& _C()const {return C;}
00094                 const mat& _D()const {return D;}
00096                 const sq_T& _Q()const {return Q;}
00098                 const sq_T& _R()const {return R;}
00099 };
00100 
00102 template<class sq_T>
00103 class Kalman: public BM, public StateSpace<sq_T>
00104 {
00105         protected:
00107                 RV yrv;
00109                 RV urv;
00111                 mat  _K;
00113                 shared_ptr<enorm<sq_T> > est;
00115                 enorm<sq_T>  fy;
00116         public:
00117                 Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
00118                 Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv),urv(K0.urv), _K(K0._K),  est(new enorm<sq_T>(*K0.est)), fy(K0.fy){}
00119                 void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
00120                 void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
00122                 const enorm<sq_T>& posterior() const {return *est.get();}
00124                 shared_ptr<epdf> shared_posterior() {return est;}
00126                 void from_setting (const Setting &set) {
00127                         StateSpace<sq_T>::from_setting(set);
00128                                                 
00129                         mat P0; vec mu0;
00130                         UI::get(mu0, set, "mu0", UI::optional);
00131                         UI::get(P0, set,  "P0", UI::optional);
00132                         set_statistics(mu0,P0);
00133                         
00134                         UI::get (yrv, set, "yrv", UI::optional);
00135                         UI::get (urv, set, "urv", UI::optional);
00136                         set_drv(concat(yrv,urv));
00137                         
00138                         validate();
00139                 }
00140                 void validate() {
00141                         StateSpace<sq_T>::validate();
00142                         bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
00143                 }
00144 };
00149 class KalmanFull : public Kalman<fsqmat>
00150 {
00151         public:
00153                 KalmanFull() :Kalman<fsqmat>(){};
00155                 void bayes (const vec &dt);
00156                 BM* _copy_() const {
00157                         KalmanFull* K = new KalmanFull;
00158                         K->set_parameters (A, B, C, D, Q, R);
00159                         K->set_statistics (est->_mu(), est->_R());
00160                         return K;
00161                 }
00162 };
00163 UIREGISTER(KalmanFull);
00164 
00165 
00173 class KalmanCh : public Kalman<chmat>
00174 {
00175         protected:
00178                 mat preA;
00180                 mat postA;
00182         public:
00184                 BM* _copy_() const {
00185                         KalmanCh* K = new KalmanCh;
00186                         K->set_parameters (A, B, C, D, Q, R);
00187                         K->set_statistics (est->_mu(), est->_R());
00188                         return K;
00189                 }
00191                 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
00193                 void initialize();
00194 
00208                 void bayes (const vec &dt);
00209                 
00210                 void from_setting(const Setting &set){
00211                         Kalman<chmat>::from_setting(set);
00212                         initialize();
00213                 }
00214 };
00215 UIREGISTER(KalmanCh);
00216 
00222 class EKFfull : public KalmanFull
00223 {
00224         protected:
00226                 shared_ptr<diffbifn> pfxu;
00227 
00229                 shared_ptr<diffbifn> phxu;
00230 
00231         public:
00233                 EKFfull ();
00234 
00236                 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
00237 
00239                 void bayes (const vec &dt);
00241                 void set_statistics (const vec &mu0, const mat &P0) {
00242                         est->set_parameters (mu0, P0);
00243                 };
00244                 const mat _R() {
00245                         return est->_R().to_mat();
00246                 }
00247                 void from_setting (const Setting &set) {
00248                         shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
00249                         shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
00250                         
00251                         
00252                         int dim = IM->dimension();
00253                         vec mu0;
00254                         if ( !UI::get ( mu0, set, "mu0" ) )
00255                                 mu0 = zeros ( dim );
00256                         
00257                         mat P0;
00258                         vec dP0;
00259                         if ( UI::get ( dP0, set, "dP0" ) )
00260                                 P0 = diag ( dP0 );
00261                         else if ( !UI::get ( P0, set, "P0" ) )
00262                                 P0 = eye ( dim );
00263                         
00264                         set_statistics ( mu0, P0 );
00265                         
00266                         
00267                         vec dQ, dR;
00268                         UI::get ( dQ, set, "dQ", UI::compulsory );
00269                         UI::get ( dR, set, "dR", UI::compulsory );
00270                         set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
00271                         
00272                         
00273                         shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
00274                         set_drv ( *drv );
00275                         shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
00276                         set_rv ( *rv );
00277                         
00278                         string options;
00279                         if ( UI::get ( options, set, "options" ) )
00280                                 set_options ( options );
00281 
00282 
00283 
00284 
00285 
00286 
00287 
00288 
00289 
00290 
00291 
00292 
00293 
00294 
00295 
00296 
00297 
00298 
00299 
00300 
00301 
00302 
00303                         validate();
00304                 }
00305                 void validate() {
00306                         
00307                 }
00308 };
00309 UIREGISTER(EKFfull);
00310 
00311 
00318 class EKFCh : public KalmanCh
00319 {
00320         protected:
00322                 shared_ptr<diffbifn> pfxu;
00323 
00325                 shared_ptr<diffbifn> phxu;
00326         public:
00328                 BM* _copy_() const {
00329                         EKFCh* E = new EKFCh;
00330                         E->set_parameters (pfxu, phxu, Q, R);
00331                         E->set_statistics (est->_mu(), est->_R());
00332                         return E;
00333                 }
00335                 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
00336 
00338                 void bayes (const vec &dt);
00339 
00340                 void from_setting (const Setting &set);
00341 
00342                 
00343 
00344 };
00345 
00346 UIREGISTER (EKFCh);
00347 SHAREDPTR (EKFCh);
00348 
00349 
00351 
00359 class MultiModel: public BM
00360 {
00361         protected:
00363                 Array<EKFCh*> Models;
00365                 vec w;
00367                 vec _lls;
00369                 int policy;
00371                 enorm<chmat> est;
00372         public:
00373                 void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
00374                         Models = A;
00375                         w.set_length (A.length());
00376                         _lls.set_length (A.length());
00377                         policy = pol0;
00378 
00379                         est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
00380                         est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
00381                 }
00382                 void bayes (const vec &dt) {
00383                         int n = Models.length();
00384                         int i;
00385                         for (i = 0; i < n; i++) {
00386                                 Models (i)->bayes (dt);
00387                                 _lls (i) = Models (i)->_ll();
00388                         }
00389                         double mlls = max (_lls);
00390                         w = exp (_lls - mlls);
00391                         w /= sum (w);   
00392                         
00393                         switch (policy) {
00394                                 case 1: {
00395                                         int mi = max_index (w);
00396                                         const enorm<chmat> &st = Models (mi)->posterior() ;
00397                                         est.set_parameters (st.mean(), st._R());
00398                                 }
00399                                 break;
00400                                 default:
00401                                         bdm_error ("unknown policy");
00402                         }
00403                         
00404                         for (i = 0; i < n; i++) {
00405                                 Models (i)->set_statistics (est.mean(), est._R());
00406                         }
00407                 }
00409                 const enorm<chmat>& posterior() const {
00410                         return est;
00411                 }
00412 
00413                 void from_setting (const Setting &set);
00414 
00415                 
00416 
00417 };
00418 
00419 UIREGISTER (MultiModel);
00420 SHAREDPTR (MultiModel);
00421 
00423 
00431 
00432 class StateCanonical: public StateSpace<fsqmat>{
00433         protected:
00435                 datalink_part th2A;
00437                 datalink_part th2C;
00439                 datalink_part th2D;
00441                 vec A1row;
00443                 vec C1row;
00445                 vec D1row;
00446                 
00447         public:
00449                 void connect_mlnorm(const mlnorm<fsqmat> &ml){
00450                         
00451                         const RV &yrv = ml._rv();
00452                         
00453                         RV rgr0 = ml._rvc().remove_time();
00454                         RV urv = rgr0.subt(yrv); 
00455                         
00456                         
00457                         bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
00458 
00459                         
00460                         RV xrv; 
00461                         RV Crv; 
00462                         int td=ml._rvc().mint();
00463                         
00464                         for (int t=-1;t>=td;t--){
00465                                 xrv.add(yrv.copy_t(t));
00466                                 Crv.add(urv.copy_t(t));
00467                         }
00468                         
00469                         this->dimx = xrv._dsize();
00470                         this->dimy = yrv._dsize();
00471                         this->dimu = urv._dsize();
00472                         
00473                         
00474                         th2A.set_connection(xrv, ml._rvc());
00475                         th2C.set_connection(Crv, ml._rvc());
00476                         th2D.set_connection(urv, ml._rvc());
00477 
00478                         
00479                         this->A=zeros(dimx,dimx);
00480                         for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} 
00481                                 this->B=zeros(dimx,1);
00482                                 this->B(0) = 1.0;
00483                                 this->C=zeros(1,dimx);
00484                                 this->D=zeros(1,urv._dsize());
00485                                 this->Q = zeros(dimx,dimx);
00486                         
00487                         
00488                         
00489                         this->A1row = zeros(xrv._dsize());
00490                         this->C1row = zeros(xrv._dsize());
00491                         this->D1row = zeros(urv._dsize());
00492                         
00493                         update_from(ml);
00494                         validate();
00495                 };
00497                 void update_from(const mlnorm<fsqmat> &ml){
00498                         
00499                         vec theta = ml._A().get_row(0); 
00500                         
00501                         th2A.filldown(theta,A1row);
00502                         th2C.filldown(theta,C1row);
00503                         th2D.filldown(theta,D1row);
00504 
00505                         R = ml._R();
00506 
00507                         A.set_row(0,A1row);
00508                         C.set_row(0,C1row+D1row*A1row);
00509                         D.set_row(0,D1row);
00510                         
00511                 }
00512 };
00513 
00515 
00516 template<class sq_T>
00517 void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
00518 {
00519 
00520         A = A0;
00521         B = B0;
00522         C = C0;
00523         D = D0;
00524         R = R0;
00525         Q = Q0;
00526         validate();
00527 }
00528 
00529 template<class sq_T>
00530 void StateSpace<sq_T>::validate(){
00531         dimx = A.rows();
00532         dimu = B.cols();
00533         dimy = C.rows();
00534         bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
00535         bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
00536         bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
00537         bdm_assert ( (D.rows() == dimy) && (D.cols() == dimu), "KalmanFull: D is not compatible");
00538         bdm_assert ( (Q.cols() == dimx) && (Q.rows() == dimx), "KalmanFull: Q is not compatible");
00539         bdm_assert ( (R.cols() == dimy) && (R.rows() == dimy), "KalmanFull: R is not compatible");
00540 }
00541 
00542 }
00543 #endif // KF_H
00544 
00545