root/library/bdm/estim/kalman.h @ 653

Revision 653, 14.1 kB (checked in by smidl, 15 years ago)

corrections in Kalman and particles

  • Property svn:eol-style set to native
RevLine 
[7]1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
[262]16
[384]17#include "../math/functions.h"
18#include "../stat/exp_family.h"
[37]19#include "../math/chmat.h"
[384]20#include "../base/user_info.h"
[7]21
[583]22namespace bdm
23{
[7]24
[477]25/*!
[583]26 * \brief Basic elements of linear state-space model
[32]27
[477]28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
[583]31 */
[477]32template<class sq_T>
[583]33class StateSpace
34{
35        protected:
36                //! cache of rv.count()
37                int dimx;
38                //! cache of rvy.count()
39                int dimy;
40                //! cache of rvu.count()
41                int dimu;
42                //! Matrix A
43                mat A;
44                //! Matrix B
45                mat B;
46                //! Matrix C
47                mat C;
48                //! Matrix D
49                mat D;
50                //! Matrix Q in square-root form
51                sq_T Q;
52                //! Matrix R in square-root form
53                sq_T R;
54        public:
55                StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
[653]56                StateSpace(const StateSpace<sq_T> &S0) : dimx (S0.dimx), dimy (S0.dimy), dimu (S0.dimu), A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
[583]57                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
58                void validate();
59                //! not virtual in this case
60                void from_setting (const Setting &set) {
61                        UI::get (A, set, "A", UI::compulsory);
62                        UI::get (B, set, "B", UI::compulsory);
63                        UI::get (C, set, "C", UI::compulsory);
64                        UI::get (D, set, "D", UI::compulsory);
65                        mat Qtm, Rtm;
66                        if(!UI::get(Qtm, set, "Q", UI::optional)){
67                                vec dq;
68                                UI::get(dq, set, "dQ", UI::compulsory);
69                                Qtm=diag(dq);
70                        }
71                        if(!UI::get(Rtm, set, "R", UI::optional)){
72                                vec dr;
73                                UI::get(dr, set, "dQ", UI::compulsory);
74                                Rtm=diag(dr);
75                        }
76                        R=Rtm; // automatic conversion
77                        Q=Qtm; 
78                       
79                        validate();
80                }               
[586]81                //! access function
82                int _dimx(){return dimx;}
83                //! access function
84                int _dimy(){return dimy;}
85                //! access function
86                int _dimu(){return dimu;}
87                //! access function
[605]88                const mat& _A() const {return A;}
[586]89                //! access function
[605]90                const mat& _B()const {return B;}
[586]91                //! access function
[605]92                const mat& _C()const {return C;}
[586]93                //! access function
[605]94                const mat& _D()const {return D;}
[586]95                //! access function
[605]96                const sq_T& _Q()const {return Q;}
[586]97                //! access function
[605]98                const sq_T& _R()const {return R;}
[583]99};
[32]100
[583]101//! Common abstract base for Kalman filters
102template<class sq_T>
103class Kalman: public BM, public StateSpace<sq_T>
104{
105        protected:
106                //! id of output
107                RV yrv;
108                //! id of input
109                RV urv;
110                //! Kalman gain
111                mat  _K;
112                //!posterior
113                shared_ptr<enorm<sq_T> > est;
114                //!marginal on data f(y|y)
115                enorm<sq_T>  fy;
116        public:
[653]117                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
118                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv),urv(K0.urv), _K(K0._K),  est(new enorm<sq_T>(*K0.est)), fy(K0.fy){}
[583]119                void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
120                void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
121                //! posterior
122                const enorm<sq_T>& posterior() const {return *est.get();}
123                //! shared posterior
124                shared_ptr<epdf> shared_posterior() {return est;}
125                //! load basic elements of Kalman from structure
126                void from_setting (const Setting &set) {
127                        StateSpace<sq_T>::from_setting(set);
128                                               
129                        mat P0; vec mu0;
130                        UI::get(mu0, set, "mu0", UI::optional);
131                        UI::get(P0, set,  "P0", UI::optional);
132                        set_statistics(mu0,P0);
133                        // Initial values
134                        UI::get (yrv, set, "yrv", UI::optional);
135                        UI::get (urv, set, "urv", UI::optional);
136                        set_drv(concat(yrv,urv));
137                       
138                        validate();
139                }
140                void validate() {
141                        StateSpace<sq_T>::validate();
[620]142                        bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
[583]143                }
144};
145/*!
146* \brief Basic Kalman filter with full matrices
147*/
[32]148
[583]149class KalmanFull : public Kalman<fsqmat>
150{
151        public:
152                //! For EKFfull;
153                KalmanFull() :Kalman<fsqmat>(){};
154                //! Here dt = [yt;ut] of appropriate dimensions
155                void bayes (const vec &dt);
[653]156                BM* _copy_() const {
157                        KalmanFull* K = new KalmanFull;
158                        K->set_parameters (A, B, C, D, Q, R);
159                        K->set_statistics (est->_mu(), est->_R());
160                        return K;
161                }
[583]162};
[586]163UIREGISTER(KalmanFull);
[32]164
165
[477]166/*! \brief Kalman filter in square root form
[271]167
[477]168Trivial example:
169\include kalman_simple.cpp
170
[583]171Complete constructor:
[477]172*/
[583]173class KalmanCh : public Kalman<chmat>
174{
175        protected:
176                //! @{ \name Internal storage - needs initialize()
177                //! pre array (triangular matrix)
178                mat preA;
179                //! post array (triangular matrix)
180                mat postA;
181                //!@}
182        public:
183                //! copy constructor
184                BM* _copy_() const {
185                        KalmanCh* K = new KalmanCh;
186                        K->set_parameters (A, B, C, D, Q, R);
187                        K->set_statistics (est->_mu(), est->_R());
188                        return K;
189                }
190                //! set parameters for adapt from Kalman
191                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
192                //! initialize internal parametetrs
193                void initialize();
[37]194
[583]195                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
[283]196
[583]197                The following equality hold::\f[
198                \left[\begin{array}{cc}
199                R^{0.5}\\
200                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
201                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
202                R_{y}^{0.5} & KA'\\
203                & P_{t+1|t}^{0.5}\\
204                \\\end{array}\right]\f]
[283]205
[583]206                Thus this object evaluates only predictors! Not filtering densities.
207                */
208                void bayes (const vec &dt);
209               
[586]210                void from_setting(const Setting &set){
[583]211                        Kalman<chmat>::from_setting(set);
212                        initialize();
213                }
[477]214};
[586]215UIREGISTER(KalmanCh);
[37]216
[477]217/*!
218\brief Extended Kalman Filter in full matrices
[62]219
[477]220An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
221*/
[583]222class EKFfull : public KalmanFull
223{
224        protected:
225                //! Internal Model f(x,u)
226                shared_ptr<diffbifn> pfxu;
[527]227
[583]228                //! Observation Model h(x,u)
229                shared_ptr<diffbifn> phxu;
[283]230
[583]231        public:
232                //! Default constructor
233                EKFfull ();
[527]234
[583]235                //! Set nonlinear functions for mean values and covariance matrices.
236                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
[527]237
[583]238                //! Here dt = [yt;ut] of appropriate dimensions
239                void bayes (const vec &dt);
240                //! set estimates
241                void set_statistics (const vec &mu0, const mat &P0) {
242                        est->set_parameters (mu0, P0);
243                };
244                const mat _R() {
245                        return est->_R().to_mat();
246                }
[653]247                void from_setting (const Setting &set) {
248                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
249                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
250                       
251                        //statistics
252                        int dim = IM->dimension();
253                        vec mu0;
254                        if ( !UI::get ( mu0, set, "mu0" ) )
255                                mu0 = zeros ( dim );
256                       
257                        mat P0;
258                        vec dP0;
259                        if ( UI::get ( dP0, set, "dP0" ) )
260                                P0 = diag ( dP0 );
261                        else if ( !UI::get ( P0, set, "P0" ) )
262                                P0 = eye ( dim );
263                       
264                        set_statistics ( mu0, P0 );
265                       
266                        //parameters
267                        vec dQ, dR;
268                        UI::get ( dQ, set, "dQ", UI::compulsory );
269                        UI::get ( dR, set, "dR", UI::compulsory );
270                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
271                       
272                        //connect
273                        shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
274                        set_drv ( *drv );
275                        shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
276                        set_rv ( *rv );
277                       
278                        string options;
279                        if ( UI::get ( options, set, "options" ) )
280                                set_options ( options );
281//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
282//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
283//                     
284//                      mat R0;
285//                      UI::get(R0, set, "R",UI::compulsory);
286//                      mat Q0;
287//                      UI::get(Q0, set, "Q",UI::compulsory);
288//                     
289//                     
290//                      mat P0; vec mu0;
291//                      UI::get(mu0, set, "mu0", UI::optional);
292//                      UI::get(P0, set,  "P0", UI::optional);
293//                      set_statistics(mu0,P0);
294//                      // Initial values
295//                      UI::get (yrv, set, "yrv", UI::optional);
296//                      UI::get (urv, set, "urv", UI::optional);
297//                      set_drv(concat(yrv,urv));
298//
299//                      // setup StateSpace
300//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
301//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
302//                     
303                        validate();
304                }
305                void validate() {
306                        // check stats and IM and OM
307                }
[477]308};
[586]309UIREGISTER(EKFfull);
[62]310
[586]311
[477]312/*!
313\brief Extended Kalman Filter in Square root
[37]314
[477]315An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
316*/
[37]317
[583]318class EKFCh : public KalmanCh
319{
320        protected:
321                //! Internal Model f(x,u)
322                shared_ptr<diffbifn> pfxu;
[527]323
[583]324                //! Observation Model h(x,u)
325                shared_ptr<diffbifn> phxu;
326        public:
327                //! copy constructor duplicated - calls different set_parameters
328                BM* _copy_() const {
329                        EKFCh* E = new EKFCh;
330                        E->set_parameters (pfxu, phxu, Q, R);
331                        E->set_statistics (est->_mu(), est->_R());
332                        return E;
333                }
334                //! Set nonlinear functions for mean values and covariance matrices.
335                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
[527]336
[583]337                //! Here dt = [yt;ut] of appropriate dimensions
338                void bayes (const vec &dt);
[357]339
[583]340                void from_setting (const Setting &set);
[357]341
[583]342                // TODO dodelat void to_setting( Setting &set ) const;
[357]343
[477]344};
[37]345
[583]346UIREGISTER (EKFCh);
347SHAREDPTR (EKFCh);
[357]348
349
[7]350//////// INstance
351
[477]352/*! \brief (Switching) Multiple Model
353The model runs several models in parallel and evaluates thier weights (fittness).
[62]354
[477]355The statistics of the resulting density are merged using (geometric?) combination.
[283]356
[477]357The next step is performed with the new statistics for all models.
358*/
[583]359class MultiModel: public BM
360{
361        protected:
362                //! List of models between which we switch
363                Array<EKFCh*> Models;
364                //! vector of model weights
365                vec w;
366                //! cache of model lls
367                vec _lls;
368                //! type of switching policy [1=maximum,2=...]
369                int policy;
370                //! internal statistics
371                enorm<chmat> est;
372        public:
373                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
374                        Models = A;//TODO: test if evalll is set
375                        w.set_length (A.length());
376                        _lls.set_length (A.length());
377                        policy = pol0;
[357]378
[583]379                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
380                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
[477]381                }
[583]382                void bayes (const vec &dt) {
383                        int n = Models.length();
384                        int i;
385                        for (i = 0; i < n; i++) {
386                                Models (i)->bayes (dt);
387                                _lls (i) = Models (i)->_ll();
388                        }
389                        double mlls = max (_lls);
390                        w = exp (_lls - mlls);
391                        w /= sum (w);   //normalization
392                        //set statistics
393                        switch (policy) {
394                                case 1: {
395                                        int mi = max_index (w);
396                                        const enorm<chmat> &st = Models (mi)->posterior() ;
397                                        est.set_parameters (st.mean(), st._R());
398                                }
399                                break;
400                                default:
401                                        bdm_error ("unknown policy");
402                        }
403                        // copy result to all models
404                        for (i = 0; i < n; i++) {
405                                Models (i)->set_statistics (est.mean(), est._R());
406                        }
[477]407                }
[583]408                //! posterior density
409                const enorm<chmat>& posterior() const {
410                        return est;
[477]411                }
[357]412
[583]413                void from_setting (const Setting &set);
[357]414
[583]415                // TODO dodelat void to_setting( Setting &set ) const;
[338]416
[477]417};
[338]418
[583]419UIREGISTER (MultiModel);
420SHAREDPTR (MultiModel);
[357]421
[586]422//! conversion of outer ARX model (mlnorm) to state space model
423/*!
424The model is constructed as:
425\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
426For example, for:
[605]427Using Frobenius form, see [].
428
429For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
[586]430*/
[605]431//template<class sq_T>
432class StateCanonical: public StateSpace<fsqmat>{
433        protected:
434                //! remember connection from theta ->A
435                datalink_part th2A;
436                //! remember connection from theta ->C
437                datalink_part th2C;
438                //! remember connection from theta ->D
439                datalink_part th2D;
440                //!cached first row of A
441                vec A1row;
442                //!cached first row of C
443                vec C1row;
444                //!cached first row of D
445                vec D1row;
446               
447        public:
448                //! set up this object to match given mlnorm
[625]449                void connect_mlnorm(const mlnorm<fsqmat> &ml){
450                        //get ids of yrv                               
[605]451                        const RV &yrv = ml._rv();
452                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
453                        RV rgr0 = ml._rvc().remove_time();
454                        RV urv = rgr0.subt(yrv); 
[586]455                       
456                        //We can do only 1d now... :(
[620]457                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
[605]458
459                        // create names for
[586]460                        RV xrv; //empty
461                        RV Crv; //empty
[605]462                        int td=ml._rvc().mint();
463                        // assuming strictly proper function!!!
[586]464                        for (int t=-1;t>=td;t--){
[605]465                                xrv.add(yrv.copy_t(t));
466                                Crv.add(urv.copy_t(t));
[586]467                        }
468                       
[605]469                        this->dimx = xrv._dsize();
470                        this->dimy = yrv._dsize();
471                        this->dimu = urv._dsize();
[586]472                       
[605]473                        // get mapp
474                        th2A.set_connection(xrv, ml._rvc());
475                        th2C.set_connection(Crv, ml._rvc());
476                        th2D.set_connection(urv, ml._rvc());
477
478                        //set matrix sizes
479                        this->A=zeros(dimx,dimx);
480                        for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} // off diagonal
481                                this->B=zeros(dimx,1);
482                                this->B(0) = 1.0;
483                                this->C=zeros(1,dimx);
484                                this->D=zeros(1,urv._dsize());
485                                this->Q = zeros(dimx,dimx);
486                        // R is set by update
[586]487                       
[605]488                        //set cache
489                        this->A1row = zeros(xrv._dsize());
490                        this->C1row = zeros(xrv._dsize());
491                        this->D1row = zeros(urv._dsize());
492                       
493                        update_from(ml);
494                        validate();
495                };
496                //! fast function to update parameters from ml - not checked for compatibility!!
497                void update_from(const mlnorm<fsqmat> &ml){
498                       
[586]499                        vec theta = ml._A().get_row(0); // this
[605]500                       
501                        th2A.filldown(theta,A1row);
502                        th2C.filldown(theta,C1row);
503                        th2D.filldown(theta,D1row);
[586]504
[605]505                        R = ml._R();
506
[586]507                        A.set_row(0,A1row);
[605]508                        C.set_row(0,C1row+D1row*A1row);
509                        D.set_row(0,D1row);
510                       
511                }
512};
[586]513
[583]514/////////// INSTANTIATION
[357]515
[477]516template<class sq_T>
[583]517void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
518{
[28]519
[583]520        A = A0;
521        B = B0;
522        C = C0;
523        D = D0;
[477]524        R = R0;
525        Q = Q0;
[583]526        validate();
[477]527}
[22]528
[477]529template<class sq_T>
[583]530void StateSpace<sq_T>::validate(){
[653]531        dimx = A.rows();
532        dimu = B.cols();
533        dimy = C.rows();
[620]534        bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
535        bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
536        bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
[653]537        bdm_assert ( (D.rows() == dimy) && (D.cols() == dimu), "KalmanFull: D is not compatible");
538        bdm_assert ( (Q.cols() == dimx) && (Q.rows() == dimx), "KalmanFull: Q is not compatible");
539        bdm_assert ( (R.cols() == dimy) && (R.rows() == dimy), "KalmanFull: R is not compatible");
[583]540}
[22]541
[254]542}
[7]543#endif // KF_H
544
Note: See TracBrowser for help on using the browser.