00001 
00013 #ifndef KF_H
00014 #define KF_H
00015 
00016 
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../math/chmat.h"
00020 #include "../base/user_info.h"
00021 
00022 namespace bdm
00023 {
00024 
00032 template<class sq_T>
00033 class StateSpace
00034 {
00035         protected:
00037                 int dimx;
00039                 int dimy;
00041                 int dimu;
00043                 mat A;
00045                 mat B;
00047                 mat C;
00049                 mat D;
00051                 sq_T Q;
00053                 sq_T R;
00054         public:
00055                 StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
00056                 void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
00057                 void validate();
00059                 void from_setting (const Setting &set) {
00060                         UI::get (A, set, "A", UI::compulsory);
00061                         UI::get (B, set, "B", UI::compulsory);
00062                         UI::get (C, set, "C", UI::compulsory);
00063                         UI::get (D, set, "D", UI::compulsory);
00064                         mat Qtm, Rtm;
00065                         if(!UI::get(Qtm, set, "Q", UI::optional)){
00066                                 vec dq;
00067                                 UI::get(dq, set, "dQ", UI::compulsory);
00068                                 Qtm=diag(dq);
00069                         }
00070                         if(!UI::get(Rtm, set, "R", UI::optional)){
00071                                 vec dr;
00072                                 UI::get(dr, set, "dQ", UI::compulsory);
00073                                 Rtm=diag(dr);
00074                         }
00075                         R=Rtm; 
00076                         Q=Qtm; 
00077                         
00078                         validate();
00079                 }               
00081                 int _dimx(){return dimx;}
00083                 int _dimy(){return dimy;}
00085                 int _dimu(){return dimu;}
00087                 mat _A(){return A;}
00089                 mat _B(){return B;}
00091                 mat _C(){return C;}
00093                 mat _D(){return D;}
00095                 sq_T _Q(){return Q;}
00097                 sq_T _R(){return R;}
00098 };
00099 
00101 template<class sq_T>
00102 class Kalman: public BM, public StateSpace<sq_T>
00103 {
00104         protected:
00106                 RV yrv;
00108                 RV urv;
00110                 mat  _K;
00112                 shared_ptr<enorm<sq_T> > est;
00114                 enorm<sq_T>  fy;
00115         public:
00116                 Kalman() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
00117                 void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
00118                 void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
00120                 const enorm<sq_T>& posterior() const {return *est.get();}
00122                 shared_ptr<epdf> shared_posterior() {return est;}
00124                 void from_setting (const Setting &set) {
00125                         StateSpace<sq_T>::from_setting(set);
00126                                                 
00127                         mat P0; vec mu0;
00128                         UI::get(mu0, set, "mu0", UI::optional);
00129                         UI::get(P0, set,  "P0", UI::optional);
00130                         set_statistics(mu0,P0);
00131                         
00132                         UI::get (yrv, set, "yrv", UI::optional);
00133                         UI::get (urv, set, "urv", UI::optional);
00134                         set_drv(concat(yrv,urv));
00135                         
00136                         validate();
00137                 }
00138                 void validate() {
00139                         StateSpace<sq_T>::validate();
00140                         bdm_assert_debug(est->dimension(), "Statistics and model parameters mismatch");
00141                 }
00142 };
00147 class KalmanFull : public Kalman<fsqmat>
00148 {
00149         public:
00151                 KalmanFull() :Kalman<fsqmat>(){};
00153                 void bayes (const vec &dt);
00154 };
00155 UIREGISTER(KalmanFull);
00156 
00157 
00165 class KalmanCh : public Kalman<chmat>
00166 {
00167         protected:
00170                 mat preA;
00172                 mat postA;
00174         public:
00176                 BM* _copy_() const {
00177                         KalmanCh* K = new KalmanCh;
00178                         K->set_parameters (A, B, C, D, Q, R);
00179                         K->set_statistics (est->_mu(), est->_R());
00180                         return K;
00181                 }
00183                 void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
00185                 void initialize();
00186 
00200                 void bayes (const vec &dt);
00201                 
00202                 void from_setting(const Setting &set){
00203                         Kalman<chmat>::from_setting(set);
00204                         initialize();
00205                 }
00206 };
00207 UIREGISTER(KalmanCh);
00208 
00214 class EKFfull : public KalmanFull
00215 {
00216         protected:
00218                 shared_ptr<diffbifn> pfxu;
00219 
00221                 shared_ptr<diffbifn> phxu;
00222 
00223         public:
00225                 EKFfull ();
00226 
00228                 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
00229 
00231                 void bayes (const vec &dt);
00233                 void set_statistics (const vec &mu0, const mat &P0) {
00234                         est->set_parameters (mu0, P0);
00235                 };
00236                 const mat _R() {
00237                         return est->_R().to_mat();
00238                 }
00239 };
00240 UIREGISTER(EKFfull);
00241 
00242 
00249 class EKFCh : public KalmanCh
00250 {
00251         protected:
00253                 shared_ptr<diffbifn> pfxu;
00254 
00256                 shared_ptr<diffbifn> phxu;
00257         public:
00259                 BM* _copy_() const {
00260                         EKFCh* E = new EKFCh;
00261                         E->set_parameters (pfxu, phxu, Q, R);
00262                         E->set_statistics (est->_mu(), est->_R());
00263                         return E;
00264                 }
00266                 void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
00267 
00269                 void bayes (const vec &dt);
00270 
00271                 void from_setting (const Setting &set);
00272 
00273                 
00274 
00275 };
00276 
00277 UIREGISTER (EKFCh);
00278 SHAREDPTR (EKFCh);
00279 
00280 
00282 
00290 class MultiModel: public BM
00291 {
00292         protected:
00294                 Array<EKFCh*> Models;
00296                 vec w;
00298                 vec _lls;
00300                 int policy;
00302                 enorm<chmat> est;
00303         public:
00304                 void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
00305                         Models = A;
00306                         w.set_length (A.length());
00307                         _lls.set_length (A.length());
00308                         policy = pol0;
00309 
00310                         est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
00311                         est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
00312                 }
00313                 void bayes (const vec &dt) {
00314                         int n = Models.length();
00315                         int i;
00316                         for (i = 0; i < n; i++) {
00317                                 Models (i)->bayes (dt);
00318                                 _lls (i) = Models (i)->_ll();
00319                         }
00320                         double mlls = max (_lls);
00321                         w = exp (_lls - mlls);
00322                         w /= sum (w);   
00323                         
00324                         switch (policy) {
00325                                 case 1: {
00326                                         int mi = max_index (w);
00327                                         const enorm<chmat> &st = Models (mi)->posterior() ;
00328                                         est.set_parameters (st.mean(), st._R());
00329                                 }
00330                                 break;
00331                                 default:
00332                                         bdm_error ("unknown policy");
00333                         }
00334                         
00335                         for (i = 0; i < n; i++) {
00336                                 Models (i)->set_statistics (est.mean(), est._R());
00337                         }
00338                 }
00340                 const enorm<chmat>& posterior() const {
00341                         return est;
00342                 }
00343 
00344                 void from_setting (const Setting &set);
00345 
00346                 
00347 
00348 };
00349 
00350 UIREGISTER (MultiModel);
00351 SHAREDPTR (MultiModel);
00352 
00354 
00361 template<class sq_T>
00362 shared_ptr<StateSpace<fsqmat> > to_state_space(const mlnorm<sq_T > &ml, ivec &theta_in_A, ivec &theta_in_C){
00363         
00364                         
00365                         ivec yids= unique(ml._rv()._ids()); 
00366                         ivec dids= unique(ml._rvc()._ids()); 
00367                         ivec uids=unique_complement(dids, yids);
00368         
00369                         RV yrv_uni = RV(yids,zeros_i(yids.length()));
00370                         RV urv_uni = RV(uids,zeros_i(yids.length()));
00371                         
00372                         bdm_assert_debug(yrv_uni._dsize()==urv_uni._dsize()==1, "Only for SISO so far..." );
00373                                         
00374                         RV xrv; 
00375                         RV Crv; 
00376                         int td=ml._rvc().mintd();
00377                         for (int t=-1;t>=td;t--){
00378                                 xrv.add(yrv_uni);
00379                                 Crv.add(urv_uni);
00380                         }
00381                         
00382                         int dimx = xrv._dsize();
00383                         
00384                         theta_in_A = ml._rv().dataind(xrv);
00385                         theta_in_C = ml._rvc().dataind(xrv);
00386                         
00387                         
00388                         vec A1row = zeros(xrv._dsize());
00389                         vec C1row = zeros(xrv._dsize());
00390                         vec theta = ml._A().get_row(0); 
00391                         set_subvector( A1row, theta_in_A, theta);
00392                         set_subvector( C1row, theta_in_C, theta);
00393 
00394                         StateSpace<fsqmat> stsp=new StateSpace<fsqmat>();
00395                         mat A=zeros(dimx,dimx);
00396                         A.set_row(0,A1row);
00397                         for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} 
00398                         mat B=zeros(dimx,1);
00399                         B(0) = 1.0;
00400                         mat C=zeros(1,dimx);
00401                         C.set_row(0,C1row);
00402                         stsp.set_parameters(A,B,C,zeros(1,1), zeros(dimx,dimx), ml._R());
00403                         return stsp;
00404 }
00405 
00407 
00408 template<class sq_T>
00409 void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
00410 {
00411         dimx = A0.rows();
00412         dimu = B0.cols();
00413         dimy = C0.rows();
00414 
00415         A = A0;
00416         B = B0;
00417         C = C0;
00418         D = D0;
00419         R = R0;
00420         Q = Q0;
00421         validate();
00422 }
00423 
00424 template<class sq_T>
00425 void StateSpace<sq_T>::validate(){
00426         bdm_assert_debug (A.cols() == dimx, "KalmanFull: A is not square");
00427         bdm_assert_debug (B.rows() == dimx, "KalmanFull: B is not compatible");
00428         bdm_assert_debug (C.cols() == dimx, "KalmanFull: C is not square");
00429         bdm_assert_debug ( (D.rows() == dimy) || (D.cols() == dimu), "KalmanFull: D is not compatible");
00430         bdm_assert_debug ( (Q.cols() == dimx) || (Q.rows() == dimx), "KalmanFull: Q is not compatible");
00431         bdm_assert_debug ( (R.cols() == dimy) || (R.rows() == dimy), "KalmanFull: R is not compatible");
00432 }
00433 
00434 }
00435 #endif // KF_H
00436 
00437