root/library/bdm/estim/kalman.h @ 686

Revision 686, 14.0 kB (checked in by smidl, 15 years ago)

pmsm using new syntax for bayes

  • Property svn:eol-style set to native
RevLine 
[7]1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
[262]16
[384]17#include "../math/functions.h"
18#include "../stat/exp_family.h"
[37]19#include "../math/chmat.h"
[384]20#include "../base/user_info.h"
[7]21
[583]22namespace bdm
23{
[7]24
[477]25/*!
[583]26 * \brief Basic elements of linear state-space model
[32]27
[477]28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
[583]31 */
[477]32template<class sq_T>
[583]33class StateSpace
34{
35        protected:
36                //! Matrix A
37                mat A;
38                //! Matrix B
39                mat B;
40                //! Matrix C
41                mat C;
42                //! Matrix D
43                mat D;
44                //! Matrix Q in square-root form
45                sq_T Q;
46                //! Matrix R in square-root form
47                sq_T R;
48        public:
[679]49                StateSpace() :  A(), B(), C(), D(), Q(), R() {}
[660]50                //!copy constructor
[679]51                StateSpace(const StateSpace<sq_T> &S0) :  A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
[660]52                //! set all matrix parameters
[583]53                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
[660]54                //! validation
[583]55                void validate();
56                //! not virtual in this case
57                void from_setting (const Setting &set) {
58                        UI::get (A, set, "A", UI::compulsory);
59                        UI::get (B, set, "B", UI::compulsory);
60                        UI::get (C, set, "C", UI::compulsory);
61                        UI::get (D, set, "D", UI::compulsory);
62                        mat Qtm, Rtm;
63                        if(!UI::get(Qtm, set, "Q", UI::optional)){
64                                vec dq;
65                                UI::get(dq, set, "dQ", UI::compulsory);
66                                Qtm=diag(dq);
67                        }
68                        if(!UI::get(Rtm, set, "R", UI::optional)){
69                                vec dr;
70                                UI::get(dr, set, "dQ", UI::compulsory);
71                                Rtm=diag(dr);
72                        }
73                        R=Rtm; // automatic conversion
74                        Q=Qtm; 
75                       
76                        validate();
77                }               
[586]78                //! access function
[605]79                const mat& _A() const {return A;}
[586]80                //! access function
[605]81                const mat& _B()const {return B;}
[586]82                //! access function
[605]83                const mat& _C()const {return C;}
[586]84                //! access function
[605]85                const mat& _D()const {return D;}
[586]86                //! access function
[605]87                const sq_T& _Q()const {return Q;}
[586]88                //! access function
[605]89                const sq_T& _R()const {return R;}
[583]90};
[32]91
[583]92//! Common abstract base for Kalman filters
93template<class sq_T>
94class Kalman: public BM, public StateSpace<sq_T>
95{
96        protected:
97                //! id of output
98                RV yrv;
99                //! Kalman gain
100                mat  _K;
101                //!posterior
[679]102                enorm<sq_T> est;
[583]103                //!marginal on data f(y|y)
104                enorm<sq_T>  fy;
105        public:
[679]106                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est(){}
[660]107                //! Copy constructor
[679]108                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv), _K(K0._K),  est(K0.est), fy(K0.fy){}
[660]109                //!set statistics of the posterior
[679]110                void set_statistics (const vec &mu0, const mat &P0) {est.set_parameters (mu0, P0); };
[660]111                //!set statistics of the posterior
[679]112                void set_statistics (const vec &mu0, const sq_T &P0) {est.set_parameters (mu0, P0); };
[660]113                //! return correctly typed posterior (covariant return)
[679]114                const enorm<sq_T>& posterior() const {return est;}
[583]115                //! load basic elements of Kalman from structure
116                void from_setting (const Setting &set) {
117                        StateSpace<sq_T>::from_setting(set);
118                                               
119                        mat P0; vec mu0;
120                        UI::get(mu0, set, "mu0", UI::optional);
121                        UI::get(P0, set,  "P0", UI::optional);
122                        set_statistics(mu0,P0);
123                        // Initial values
124                        UI::get (yrv, set, "yrv", UI::optional);
[679]125                        UI::get (rvc, set, "urv", UI::optional);
126                        set_yrv(concat(yrv,rvc));
[583]127                       
128                        validate();
129                }
[660]130                //! validate object
[583]131                void validate() {
132                        StateSpace<sq_T>::validate();
[679]133                        dimy = this->C.rows();
134                        dimc = this->B.cols();
135                        set_dim(this->A.rows());
136                       
137                        bdm_assert(est.dimension(), "Statistics and model parameters mismatch");
[583]138                }
139};
140/*!
141* \brief Basic Kalman filter with full matrices
142*/
[32]143
[583]144class KalmanFull : public Kalman<fsqmat>
145{
146        public:
147                //! For EKFfull;
148                KalmanFull() :Kalman<fsqmat>(){};
149                //! Here dt = [yt;ut] of appropriate dimensions
[679]150                void bayes (const vec &yt, const vec &cond=empty_vec);
[653]151                BM* _copy_() const {
152                        KalmanFull* K = new KalmanFull;
153                        K->set_parameters (A, B, C, D, Q, R);
[679]154                        K->set_statistics (est._mu(), est._R());
[653]155                        return K;
156                }
[583]157};
[586]158UIREGISTER(KalmanFull);
[32]159
160
[477]161/*! \brief Kalman filter in square root form
[271]162
[477]163Trivial example:
164\include kalman_simple.cpp
165
[583]166Complete constructor:
[477]167*/
[583]168class KalmanCh : public Kalman<chmat>
169{
170        protected:
171                //! @{ \name Internal storage - needs initialize()
172                //! pre array (triangular matrix)
173                mat preA;
174                //! post array (triangular matrix)
175                mat postA;
176                //!@}
177        public:
178                //! copy constructor
179                BM* _copy_() const {
180                        KalmanCh* K = new KalmanCh;
181                        K->set_parameters (A, B, C, D, Q, R);
[679]182                        K->set_statistics (est._mu(), est._R());
[681]183                        K->validate();
[583]184                        return K;
185                }
186                //! set parameters for adapt from Kalman
187                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
188                //! initialize internal parametetrs
189                void initialize();
[37]190
[583]191                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
[283]192
[583]193                The following equality hold::\f[
194                \left[\begin{array}{cc}
195                R^{0.5}\\
196                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
197                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
198                R_{y}^{0.5} & KA'\\
199                & P_{t+1|t}^{0.5}\\
200                \\\end{array}\right]\f]
[283]201
[583]202                Thus this object evaluates only predictors! Not filtering densities.
203                */
[679]204                void bayes (const vec &yt, const vec &cond=empty_vec);
[675]205
[586]206                void from_setting(const Setting &set){
[583]207                        Kalman<chmat>::from_setting(set);
[679]208                        validate();
209                }
210                void validate() {
211                        Kalman<chmat>::validate();
[583]212                        initialize();
213                }
[477]214};
[586]215UIREGISTER(KalmanCh);
[37]216
[477]217/*!
218\brief Extended Kalman Filter in full matrices
[62]219
[477]220An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
221*/
[583]222class EKFfull : public KalmanFull
223{
224        protected:
225                //! Internal Model f(x,u)
226                shared_ptr<diffbifn> pfxu;
[527]227
[583]228                //! Observation Model h(x,u)
229                shared_ptr<diffbifn> phxu;
[283]230
[583]231        public:
232                //! Default constructor
233                EKFfull ();
[527]234
[583]235                //! Set nonlinear functions for mean values and covariance matrices.
236                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
[527]237
[583]238                //! Here dt = [yt;ut] of appropriate dimensions
[679]239                void bayes (const vec &yt, const vec &cond=empty_vec);
[583]240                //! set estimates
241                void set_statistics (const vec &mu0, const mat &P0) {
[679]242                        est.set_parameters (mu0, P0);
[583]243                };
[660]244                //! access function
[583]245                const mat _R() {
[679]246                        return est._R().to_mat();
[583]247                }
[686]248                void from_setting (const Setting &set) {                       
249                        BM::from_setting(set);
[653]250                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
251                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
252                       
253                        //statistics
254                        int dim = IM->dimension();
255                        vec mu0;
256                        if ( !UI::get ( mu0, set, "mu0" ) )
257                                mu0 = zeros ( dim );
258                       
259                        mat P0;
260                        vec dP0;
261                        if ( UI::get ( dP0, set, "dP0" ) )
262                                P0 = diag ( dP0 );
263                        else if ( !UI::get ( P0, set, "P0" ) )
264                                P0 = eye ( dim );
265                       
266                        set_statistics ( mu0, P0 );
267                       
268                        //parameters
269                        vec dQ, dR;
270                        UI::get ( dQ, set, "dQ", UI::compulsory );
271                        UI::get ( dR, set, "dR", UI::compulsory );
272                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
[686]273                                               
[653]274                        string options;
275                        if ( UI::get ( options, set, "options" ) )
276                                set_options ( options );
277//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
278//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
279//                     
280//                      mat R0;
281//                      UI::get(R0, set, "R",UI::compulsory);
282//                      mat Q0;
283//                      UI::get(Q0, set, "Q",UI::compulsory);
284//                     
285//                     
286//                      mat P0; vec mu0;
287//                      UI::get(mu0, set, "mu0", UI::optional);
288//                      UI::get(P0, set,  "P0", UI::optional);
289//                      set_statistics(mu0,P0);
290//                      // Initial values
291//                      UI::get (yrv, set, "yrv", UI::optional);
292//                      UI::get (urv, set, "urv", UI::optional);
293//                      set_drv(concat(yrv,urv));
294//
295//                      // setup StateSpace
296//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
297//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
298//                     
299                        validate();
300                }
301                void validate() {
302                        // check stats and IM and OM
303                }
[477]304};
[586]305UIREGISTER(EKFfull);
[62]306
[586]307
[477]308/*!
309\brief Extended Kalman Filter in Square root
[37]310
[477]311An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
312*/
[37]313
[583]314class EKFCh : public KalmanCh
315{
316        protected:
317                //! Internal Model f(x,u)
318                shared_ptr<diffbifn> pfxu;
[527]319
[583]320                //! Observation Model h(x,u)
321                shared_ptr<diffbifn> phxu;
322        public:
323                //! copy constructor duplicated - calls different set_parameters
324                BM* _copy_() const {
325                        EKFCh* E = new EKFCh;
326                        E->set_parameters (pfxu, phxu, Q, R);
[679]327                        E->set_statistics (est._mu(), est._R());
[583]328                        return E;
329                }
330                //! Set nonlinear functions for mean values and covariance matrices.
331                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
[527]332
[583]333                //! Here dt = [yt;ut] of appropriate dimensions
[679]334                void bayes (const vec &yt, const vec &cond=empty_vec);
[357]335
[583]336                void from_setting (const Setting &set);
[357]337
[681]338                void validate(){};
[583]339                // TODO dodelat void to_setting( Setting &set ) const;
[357]340
[477]341};
[37]342
[583]343UIREGISTER (EKFCh);
344SHAREDPTR (EKFCh);
[357]345
346
[7]347//////// INstance
348
[477]349/*! \brief (Switching) Multiple Model
350The model runs several models in parallel and evaluates thier weights (fittness).
[62]351
[477]352The statistics of the resulting density are merged using (geometric?) combination.
[283]353
[477]354The next step is performed with the new statistics for all models.
355*/
[583]356class MultiModel: public BM
357{
358        protected:
359                //! List of models between which we switch
360                Array<EKFCh*> Models;
361                //! vector of model weights
362                vec w;
363                //! cache of model lls
364                vec _lls;
365                //! type of switching policy [1=maximum,2=...]
366                int policy;
367                //! internal statistics
368                enorm<chmat> est;
369        public:
[660]370                //! set internal parameters
[583]371                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
372                        Models = A;//TODO: test if evalll is set
373                        w.set_length (A.length());
374                        _lls.set_length (A.length());
375                        policy = pol0;
[357]376
[583]377                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
378                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
[477]379                }
[679]380                void bayes (const vec &yt, const vec &cond=empty_vec) {
[583]381                        int n = Models.length();
382                        int i;
383                        for (i = 0; i < n; i++) {
[679]384                                Models (i)->bayes (yt);
[583]385                                _lls (i) = Models (i)->_ll();
386                        }
387                        double mlls = max (_lls);
388                        w = exp (_lls - mlls);
389                        w /= sum (w);   //normalization
390                        //set statistics
391                        switch (policy) {
392                                case 1: {
393                                        int mi = max_index (w);
394                                        const enorm<chmat> &st = Models (mi)->posterior() ;
395                                        est.set_parameters (st.mean(), st._R());
396                                }
397                                break;
398                                default:
399                                        bdm_error ("unknown policy");
400                        }
401                        // copy result to all models
402                        for (i = 0; i < n; i++) {
403                                Models (i)->set_statistics (est.mean(), est._R());
404                        }
[477]405                }
[660]406                //! return correctly typed posterior (covariant return)
[583]407                const enorm<chmat>& posterior() const {
408                        return est;
[477]409                }
[357]410
[583]411                void from_setting (const Setting &set);
[357]412
[583]413                // TODO dodelat void to_setting( Setting &set ) const;
[338]414
[477]415};
[338]416
[583]417UIREGISTER (MultiModel);
418SHAREDPTR (MultiModel);
[357]419
[586]420//! conversion of outer ARX model (mlnorm) to state space model
421/*!
422The model is constructed as:
423\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
424For example, for:
[605]425Using Frobenius form, see [].
426
427For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
[586]428*/
[605]429//template<class sq_T>
430class StateCanonical: public StateSpace<fsqmat>{
431        protected:
432                //! remember connection from theta ->A
433                datalink_part th2A;
434                //! remember connection from theta ->C
435                datalink_part th2C;
436                //! remember connection from theta ->D
437                datalink_part th2D;
438                //!cached first row of A
439                vec A1row;
440                //!cached first row of C
441                vec C1row;
442                //!cached first row of D
443                vec D1row;
444               
445        public:
446                //! set up this object to match given mlnorm
[625]447                void connect_mlnorm(const mlnorm<fsqmat> &ml){
448                        //get ids of yrv                               
[605]449                        const RV &yrv = ml._rv();
450                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
451                        RV rgr0 = ml._rvc().remove_time();
452                        RV urv = rgr0.subt(yrv); 
[586]453                       
454                        //We can do only 1d now... :(
[620]455                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
[605]456
457                        // create names for
[586]458                        RV xrv; //empty
459                        RV Crv; //empty
[605]460                        int td=ml._rvc().mint();
461                        // assuming strictly proper function!!!
[586]462                        for (int t=-1;t>=td;t--){
[605]463                                xrv.add(yrv.copy_t(t));
464                                Crv.add(urv.copy_t(t));
[586]465                        }
[679]466                                               
[605]467                        // get mapp
468                        th2A.set_connection(xrv, ml._rvc());
469                        th2C.set_connection(Crv, ml._rvc());
470                        th2D.set_connection(urv, ml._rvc());
471
472                        //set matrix sizes
[679]473                        this->A=zeros(xrv._dsize(),xrv._dsize());
474                        for (int j=1; j<xrv._dsize(); j++){A(j,j-1)=1.0;} // off diagonal
475                                this->B=zeros(xrv._dsize(),1);
[605]476                                this->B(0) = 1.0;
[679]477                                this->C=zeros(1,xrv._dsize());
[605]478                                this->D=zeros(1,urv._dsize());
[679]479                                this->Q = zeros(xrv._dsize(),xrv._dsize());
[605]480                        // R is set by update
[586]481                       
[605]482                        //set cache
483                        this->A1row = zeros(xrv._dsize());
484                        this->C1row = zeros(xrv._dsize());
485                        this->D1row = zeros(urv._dsize());
486                       
487                        update_from(ml);
488                        validate();
489                };
490                //! fast function to update parameters from ml - not checked for compatibility!!
491                void update_from(const mlnorm<fsqmat> &ml){
492                       
[586]493                        vec theta = ml._A().get_row(0); // this
[605]494                       
495                        th2A.filldown(theta,A1row);
496                        th2C.filldown(theta,C1row);
497                        th2D.filldown(theta,D1row);
[586]498
[605]499                        R = ml._R();
500
[586]501                        A.set_row(0,A1row);
[605]502                        C.set_row(0,C1row+D1row*A1row);
503                        D.set_row(0,D1row);
504                       
505                }
506};
[586]507
[583]508/////////// INSTANTIATION
[357]509
[477]510template<class sq_T>
[583]511void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
512{
[28]513
[583]514        A = A0;
515        B = B0;
516        C = C0;
517        D = D0;
[477]518        R = R0;
519        Q = Q0;
[583]520        validate();
[477]521}
[22]522
[477]523template<class sq_T>
[583]524void StateSpace<sq_T>::validate(){
[679]525        bdm_assert (A.cols() == A.rows(), "KalmanFull: A is not square");
526        bdm_assert (B.rows() == A.rows(), "KalmanFull: B is not compatible");
527        bdm_assert (C.cols() == A.rows(), "KalmanFull: C is not compatible");
528        bdm_assert ( (D.rows() == C.rows()) && (D.cols() == B.cols()), "KalmanFull: D is not compatible");
529        bdm_assert ( (Q.cols() == A.rows()) && (Q.rows() == A.rows()), "KalmanFull: Q is not compatible");
530        bdm_assert ( (R.cols() == C.rows()) && (R.rows() == C.rows()), "KalmanFull: R is not compatible");
[583]531}
[22]532
[254]533}
[7]534#endif // KF_H
535
Note: See TracBrowser for help on using the browser.