kalman.h
Go to the documentation of this file.00001 00013 #ifndef KF_H 00014 #define KF_H 00015 00016 00017 #include "../math/functions.h" 00018 #include "../stat/exp_family.h" 00019 #include "../math/chmat.h" 00020 #include "../base/user_info.h" 00021 //#include <../applications/pmsm/simulator_zdenek/ekf_example/pmsm_mod.h> 00022 00023 namespace bdm { 00024 00032 template<class sq_T> 00033 class StateSpace { 00034 protected: 00036 mat A; 00038 mat B; 00040 mat C; 00042 mat D; 00044 sq_T Q; 00046 sq_T R; 00047 public: 00048 StateSpace() : A(), B(), C(), D(), Q(), R() {} 00050 StateSpace ( const StateSpace<sq_T> &S0 ) : A ( S0.A ), B ( S0.B ), C ( S0.C ), D ( S0.D ), Q ( S0.Q ), R ( S0.R ) {} 00052 void set_parameters ( const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0 ); 00054 void validate(); 00056 void from_setting ( const Setting &set ) { 00057 UI::get ( A, set, "A", UI::compulsory ); 00058 UI::get ( B, set, "B", UI::compulsory ); 00059 UI::get ( C, set, "C", UI::compulsory ); 00060 UI::get ( D, set, "D", UI::compulsory ); 00061 mat Qtm, Rtm; // full matrices 00062 if ( !UI::get ( Qtm, set, "Q", UI::optional ) ) { 00063 vec dq; 00064 UI::get ( dq, set, "dQ", UI::compulsory ); 00065 Qtm = diag ( dq ); 00066 } 00067 if ( !UI::get ( Rtm, set, "R", UI::optional ) ) { 00068 vec dr; 00069 UI::get ( dr, set, "dQ", UI::compulsory ); 00070 Rtm = diag ( dr ); 00071 } 00072 R = Rtm; // automatic conversion to square-root form 00073 Q = Qtm; 00074 00075 validate(); 00076 } 00077 void to_setting ( Setting &set ) const { 00078 UI::save( A, set, "A" ); 00079 UI::save( B, set, "B" ); 00080 UI::save( C, set, "C" ); 00081 UI::save( D, set, "D" ); 00082 UI::save( Q.to_mat(), set, "Q" ); 00083 UI::save( R.to_mat(), set, "R" ); 00084 } 00086 const mat& _A() const { 00087 return A; 00088 } 00090 const mat& _B() const { 00091 return B; 00092 } 00094 const mat& _C() const { 00095 return C; 00096 } 00098 const mat& _D() const { 00099 return D; 00100 } 
00102 const sq_T& _Q() const { 00103 return Q; 00104 } 00106 const sq_T& _R() const { 00107 return R; 00108 } 00109 }; 00110 00112 template<class sq_T> 00113 class Kalman: public BM, public StateSpace<sq_T> { 00114 protected: 00116 RV yrv; 00118 mat _K; 00120 enorm<sq_T> est; 00122 enorm<sq_T> fy; 00123 public: 00124 Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(), est() {} 00126 Kalman<sq_T> ( const Kalman<sq_T> &K0 ) : BM ( K0 ), StateSpace<sq_T> ( K0 ), yrv ( K0.yrv ), _K ( K0._K ), est ( K0.est ), fy ( K0.fy ) {} 00128 void set_statistics ( const vec &mu0, const mat &P0 ) { 00129 est.set_parameters ( mu0, P0 ); 00130 }; 00132 void set_statistics ( const vec &mu0, const sq_T &P0 ) { 00133 est.set_parameters ( mu0, P0 ); 00134 }; 00136 const enorm<sq_T>& posterior() const { 00137 return est; 00138 } 00139 00150 void from_setting ( const Setting &set ) { 00151 StateSpace<sq_T>::from_setting ( set ); 00152 BM::from_setting(set); 00153 00154 shared_ptr<epdf> pri=UI::build<epdf>(set,"prior",UI::compulsory); 00155 //bdm_assert(pri->dimension()==); 00156 set_statistics ( pri->mean(), pri->covariance() ); 00157 } 00158 void to_setting ( Setting &set ) const { 00159 StateSpace<sq_T>::to_setting ( set ); 00160 BM::to_setting(set); 00161 00162 UI::save(est, set, "prior"); 00163 } 00165 void validate() { 00166 StateSpace<sq_T>::validate(); 00167 dimy = this->C.rows(); 00168 dimc = this->B.cols(); 00169 set_dim ( this->A.rows() ); 00170 00171 bdm_assert ( est.dimension(), "Statistics and model parameters mismatch" ); 00172 } 00173 00174 }; 00179 class KalmanFull : public Kalman<fsqmat> { 00180 public: 00182 KalmanFull() : Kalman<fsqmat>() {}; 00184 void bayes ( const vec &yt, const vec &cond = empty_vec ); 00185 00186 virtual KalmanFull* _copy() const { 00187 KalmanFull* K = new KalmanFull; 00188 K->set_parameters ( A, B, C, D, Q, R ); 00189 K->set_statistics ( est._mu(), est._R() ); 00190 return K; 00191 } 00192 }; 00193 UIREGISTER ( KalmanFull ); 00194 00195 00203 
class KalmanCh : public Kalman<chmat> { 00204 protected: 00207 mat preA; 00209 mat postA; 00211 public: 00213 virtual KalmanCh* _copy() const { 00214 KalmanCh* K = new KalmanCh; 00215 K->set_parameters ( A, B, C, D, Q, R ); 00216 K->set_statistics ( est._mu(), est._R() ); 00217 K->validate(); 00218 return K; 00219 } 00221 void set_parameters ( const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0 ); 00223 void initialize(); 00224 00238 void bayes ( const vec &yt, const vec &cond = empty_vec ); 00239 00248 void from_setting ( const Setting &set ) { 00249 Kalman<chmat>::from_setting ( set ); 00250 } 00251 00252 void validate() { 00253 Kalman<chmat>::validate(); 00254 initialize(); 00255 } 00256 }; 00257 UIREGISTER ( KalmanCh ); 00258 00264 class EKFfull : public KalmanFull { 00265 protected: 00267 shared_ptr<diffbifn> pfxu; 00268 00270 shared_ptr<diffbifn> phxu; 00271 00272 public: 00274 EKFfull (); 00275 00277 void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0 ); 00278 00280 void bayes ( const vec &yt, const vec &cond = empty_vec ); 00282 void set_statistics ( const vec &mu0, const mat &P0 ) { 00283 est.set_parameters ( mu0, P0 ); 00284 }; 00286 const mat _R() { 00287 return est._R().to_mat(); 00288 } 00289 00313 void from_setting ( const Setting &set ) { 00314 BM::from_setting ( set ); 00315 shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory ); 00316 shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory ); 00317 00318 //statistics 00319 int dim = IM->dimension(); 00320 vec mu0; 00321 if ( !UI::get ( mu0, set, "mu0" ) ) 00322 mu0 = zeros ( dim ); 00323 00324 mat P0; 00325 vec dP0; 00326 if ( UI::get ( dP0, set, "dP0" ) ) 00327 P0 = diag ( dP0 ); 00328 else if ( !UI::get ( P0, set, "P0" ) ) 00329 P0 = eye ( dim ); 00330 00331 //parameters 00332 vec dQ, dR; 00333 UI::get ( dQ, set, "dQ", UI::compulsory ); 00334 UI::get ( 
dR, set, "dR", UI::compulsory ); 00335 set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) ); 00336 00337 set_statistics ( mu0, P0 ); 00338 00339 // pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory); 00340 // phxu = UI::build<diffbifn>(set, "OM", UI::compulsory); 00341 // 00342 // mat R0; 00343 // UI::get(R0, set, "R",UI::compulsory); 00344 // mat Q0; 00345 // UI::get(Q0, set, "Q",UI::compulsory); 00346 // 00347 // 00348 // mat P0; vec mu0; 00349 // UI::get(mu0, set, "mu0", UI::optional); 00350 // UI::get(P0, set, "P0", UI::optional); 00351 // set_statistics(mu0,P0); 00352 // // Initial values 00353 // UI::get (yrv, set, "yrv", UI::optional); 00354 // UI::get (urv, set, "urv", UI::optional); 00355 // set_drv(concat(yrv,urv)); 00356 // 00357 // // setup StateSpace 00358 // pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true); 00359 // phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true); 00360 // 00361 } 00362 00363 void validate() { 00364 // KalmanFull::validate(); 00365 00366 // check stats and IM and OM 00367 } 00368 }; 00369 UIREGISTER ( EKFfull ); 00370 00371 00378 class EKFCh : public KalmanCh { 00379 LOG_LEVEL(EKFCh,logCh); 00380 protected: 00382 shared_ptr<diffbifn> pfxu; 00383 00385 shared_ptr<diffbifn> phxu; 00386 public: 00388 EKFCh* _copy() const { 00389 return new EKFCh(*this); 00390 } 00392 void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0 ); 00393 00395 void bayes ( const vec &yt, const vec &cond = empty_vec ); 00396 00419 void from_setting ( const Setting &set ); 00420 00421 void log_register(logger &L, const string &prefix){ 00422 BM::log_register ( L, prefix ); 00423 00424 L.add_vector ( log_level, logCh, RV ("Ch", dimension()*dimension() ), prefix ); 00425 }; 00426 00427 void log_write() const{ 00428 BM::log_write(); 00429 if ( log_level[logCh] ) { 00430 vec v(est._R()._Ch()._data(), dimension()*dimension()); 00431 if (v(0)<0) 00432 v= -v; 00433 log_level.store( logCh, v); 00434 
} 00435 00436 }; 00437 00438 void validate() {}; 00439 // TODO dodelat void to_setting( Setting &set ) const; 00440 00441 }; 00442 00443 UIREGISTER ( EKFCh ); 00444 SHAREDPTR ( EKFCh ); 00445 00447 class EKF_UD : public BM { 00448 protected: 00450 LOG_LEVEL(EKF_UD,logU, logG, logD,logA,logP,logC); 00452 shared_ptr<diffbifn> pfxu; 00453 00455 shared_ptr<diffbifn> phxu; 00456 00458 mat U; 00460 vec D; 00461 00462 mat A; 00463 mat C; 00464 mat Q; 00465 vec R; 00466 00467 enorm<ldmat> est; 00468 public: 00469 00471 EKF_UD* _copy() const { 00472 return new EKF_UD(*this); 00473 } 00474 00475 const enorm<ldmat>& posterior()const{return est;}; 00476 00477 enorm<ldmat>& prior() { 00478 return const_cast<enorm<ldmat>&>(posterior()); 00479 } 00480 00481 EKF_UD(){} 00482 00483 00484 EKF_UD(const EKF_UD &E0): pfxu(E0.pfxu),phxu(E0.phxu), U(E0.U), D(E0.D){} 00485 00487 void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const vec R0 ); 00488 00490 void bayes ( const vec &yt, const vec &cond = empty_vec ); 00491 00492 void log_register ( bdm::logger& L, const string& prefix ){ 00493 BM::log_register ( L, prefix ); 00494 00495 if ( log_level[logU] ) 00496 L.add_vector ( log_level, logU, RV ( dimension()*dimension() ), prefix ); 00497 if ( log_level[logG] ) 00498 L.add_vector ( log_level, logG, RV ( dimension()*dimension() ), prefix ); 00499 if ( log_level[logD] ) 00500 L.add_vector ( log_level, logD, RV ( dimension()), prefix ); 00501 if ( log_level[logC] ) 00502 L.add_vector ( log_level, logC, RV ( dimension()*dimensiony()), prefix ); 00503 00504 L.add_vector ( log_level, logA, RV ( dimension()*dimension()), prefix ); 00505 L.add_vector ( log_level, logP, RV ( dimension()*dimension()), prefix ); 00506 00507 } 00530 void from_setting ( const Setting &set ); 00531 00532 void validate() {}; 00533 // TODO dodelat void to_setting( Setting &set ) const; 00534 00535 }; 00536 UIREGISTER(EKF_UD); 00537 00538 class UKFCh : public EKFCh{ 
00539 public: 00540 double kappa; 00541 00542 void bayes ( const vec &yt, const vec &cond = empty_vec ){ 00543 00544 vec &_mu = est._mu(); 00545 chmat &_P = est._R(); 00546 chmat &_Ry = fy._R(); 00547 vec &_yp = fy._mu(); 00548 00549 int dim = dimension(); 00550 int dim2 = 1+dim+dim; 00551 00552 double npk =dim+kappa;//n+kappa 00553 mat Xi(dim,dim2); 00554 vec w=ones(dim2)* 0.5/npk; 00555 w(0) = (npk-dim)/npk; // mean is special 00556 00557 //step 1. 00558 int i; 00559 Xi.set_col(0,_mu); 00560 00561 for ( i=0;i<dim; i++){ 00562 vec tmp=sqrt(npk)*_P._Ch().get_col(i); 00563 Xi.set_col(i+1, _mu+tmp); 00564 Xi.set_col(i+1+dim, _mu-tmp); 00565 } 00566 00567 // step 2. 00568 mat Xik(dim,dim2); 00569 for (i=0; i<dim2; i++){ 00570 Xik.set_col(i, pfxu->eval(Xi.get_col(i), cond)); 00571 } 00572 00573 //step 3 00574 vec xp=zeros(dim); 00575 for (i=0;i<dim2;i++){ 00576 xp += w(i) * Xik.get_col(i); 00577 } 00578 00579 //step 4 00580 mat P4=Q.to_mat(); 00581 vec tmp; 00582 for (i=0;i<dim2;i++){ 00583 tmp = Xik.get_col(i)-xp; 00584 P4+=w(i)*outer_product(tmp,tmp); 00585 } 00586 00587 //step 5 00588 mat Yi(dimy,dim2); 00589 for (i=0; i<dim2; i++){ 00590 Yi.set_col(i, phxu->eval(Xik.get_col(i), cond)); 00591 } 00592 //step 6 00593 _yp.clear(); 00594 for (i=0;i<dim2;i++){ 00595 _yp += w(i) * Yi.get_col(i); 00596 } 00597 //step 7 00598 mat Pvv=R.to_mat(); 00599 for (i=0;i<dim2;i++){ 00600 tmp = Yi.get_col(i)-_yp; 00601 Pvv+=w(i)*outer_product(tmp,tmp); 00602 } 00603 _Ry._Ch() = chol(Pvv); 00604 00605 // step 8 00606 mat Pxy=zeros(dim,dimy); 00607 for (i=0;i<dim2;i++){ 00608 Pxy+=w(i)*outer_product(Xi.get_col(i)-xp, Yi.get_col(i)-_yp); 00609 } 00610 mat iRy=inv(_Ry._Ch()); 00611 00612 //filtering????? 
-- correction 00613 mat K=Pxy*iRy*iRy.T(); 00614 mat K2=Pxy*inv(_Ry.to_mat()); 00615 00617 _mu = xp + K*(yt - _yp); 00618 00619 if ( _mu ( 3 ) >pi ) _mu ( 3 )-=2*pi; 00620 if ( _mu ( 3 ) <-pi ) _mu ( 3 ) +=2*pi; 00621 // fill the space in Ppred; 00622 _P._Ch()=chol(P4-K*_Ry.to_mat()*K.T()); 00623 } 00624 void from_setting(const Setting &set){ 00625 EKFCh::from_setting(set); 00626 kappa = 1.0; 00627 UI::get(kappa,set,"kappa"); 00628 } 00629 }; 00630 UIREGISTER(UKFCh); 00631 00633 00641 class MultiModel: public BM { 00642 protected: 00644 Array<EKFCh*> Models; 00646 vec w; 00648 vec _lls; 00650 int policy; 00652 enorm<chmat> est; 00653 public: 00655 void set_parameters ( Array<EKFCh*> A, int pol0 = 1 ) { 00656 Models = A;//TODO: test if evalll is set 00657 w.set_length ( A.length() ); 00658 _lls.set_length ( A.length() ); 00659 policy = pol0; 00660 00661 est.set_rv ( RV ( "MM", A ( 0 )->posterior().dimension(), 0 ) ); 00662 est.set_parameters ( A ( 0 )->posterior().mean(), A ( 0 )->posterior()._R() ); 00663 } 00664 void bayes ( const vec &yt, const vec &cond = empty_vec ) { 00665 int n = Models.length(); 00666 int i; 00667 for ( i = 0; i < n; i++ ) { 00668 Models ( i )->bayes ( yt ); 00669 _lls ( i ) = Models ( i )->_ll(); 00670 } 00671 double mlls = max ( _lls ); 00672 w = exp ( _lls - mlls ); 00673 w /= sum ( w ); //normalization 00674 //set statistics 00675 switch ( policy ) { 00676 case 1: { 00677 int mi = max_index ( w ); 00678 const enorm<chmat> &st = Models ( mi )->posterior() ; 00679 est.set_parameters ( st.mean(), st._R() ); 00680 } 00681 break; 00682 default: 00683 bdm_error ( "unknown policy" ); 00684 } 00685 // copy result to all models 00686 for ( i = 0; i < n; i++ ) { 00687 Models ( i )->set_statistics ( est.mean(), est._R() ); 00688 } 00689 } 00691 const enorm<chmat>& posterior() const { 00692 return est; 00693 } 00694 00695 void from_setting ( const Setting &set ); 00696 00697 }; 00698 UIREGISTER ( MultiModel ); 00699 SHAREDPTR ( MultiModel ); 00700 
00702 00710 //template<class sq_T> 00711 class StateCanonical: public StateSpace<fsqmat> { 00712 protected: 00714 datalink_part th2A; 00716 datalink_part th2C; 00718 datalink_part th2D; 00720 vec A1row; 00722 vec C1row; 00724 vec D1row; 00725 00726 public: 00728 void connect_mlnorm ( const mlnorm<fsqmat> &ml ); 00729 00731 void update_from ( const mlnorm<fsqmat> &ml ); 00732 }; 00746 class StateFromARX: public StateSpace<chmat> { 00747 protected: 00749 datalink_part th2A; 00751 datalink_part th2B; 00753 void diagonal_part ( mat &A, int r, int c, int n ) { 00754 for ( int i = 0; i < n; i++ ) { 00755 A ( r, c ) = 1.0; 00756 r++; 00757 c++; 00758 } 00759 }; 00761 bool have_constant; 00762 public: 00767 void connect_mlnorm ( const mlnorm<chmat> &ml, RV &xrv, RV &urv ); 00768 00770 void update_from ( const mlnorm<chmat> &ml ); 00771 00773 bool _have_constant() const { 00774 return have_constant; 00775 } 00776 }; 00777 00779 00780 template<class sq_T> 00781 void StateSpace<sq_T>::set_parameters ( const mat &A0, const mat &B0, const mat &C0, const mat &D0, const sq_T &Q0, const sq_T &R0 ) { 00782 00783 A = A0; 00784 B = B0; 00785 C = C0; 00786 D = D0; 00787 R = R0; 00788 Q = Q0; 00789 validate(); 00790 } 00791 00792 template<class sq_T> 00793 void StateSpace<sq_T>::validate() { 00794 bdm_assert ( A.cols() == A.rows(), "KalmanFull: A is not square" ); 00795 bdm_assert ( B.rows() == A.rows(), "KalmanFull: B is not compatible" ); 00796 bdm_assert ( C.cols() == A.rows(), "KalmanFull: C is not compatible" ); 00797 bdm_assert ( ( D.rows() == C.rows() ) && ( D.cols() == B.cols() ), "KalmanFull: D is not compatible" ); 00798 bdm_assert ( ( Q.cols() == A.rows() ) && ( Q.rows() == A.rows() ), "KalmanFull: Q is not compatible" ); 00799 bdm_assert ( ( R.cols() == C.rows() ) && ( R.rows() == C.rows() ), "KalmanFull: R is not compatible" ); 00800 } 00801 00802 } 00803 #endif // KF_H 00804 00805
Generated on 2 Dec 2013 for mixpp by Doxygen 1.4.7