root/library/bdm/estim/kalman.h @ 681

Revision 681, 14.1 kB (checked in by smidl, 15 years ago)

corrections of Kalman

  • Property svn:eol-style set to native
RevLine 
[7]1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
[262]16
[384]17#include "../math/functions.h"
18#include "../stat/exp_family.h"
[37]19#include "../math/chmat.h"
[384]20#include "../base/user_info.h"
[7]21
[583]22namespace bdm
23{
[7]24
[477]25/*!
[583]26 * \brief Basic elements of linear state-space model
[32]27
[477]28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
[583]31 */
[477]32template<class sq_T>
[583]33class StateSpace
34{
35        protected:
36                //! Matrix A
37                mat A;
38                //! Matrix B
39                mat B;
40                //! Matrix C
41                mat C;
42                //! Matrix D
43                mat D;
44                //! Matrix Q in square-root form
45                sq_T Q;
46                //! Matrix R in square-root form
47                sq_T R;
48        public:
[679]49                StateSpace() :  A(), B(), C(), D(), Q(), R() {}
[660]50                //!copy constructor
[679]51                StateSpace(const StateSpace<sq_T> &S0) :  A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
[660]52                //! set all matrix parameters
[583]53                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
[660]54                //! validation
[583]55                void validate();
56                //! not virtual in this case
57                void from_setting (const Setting &set) {
58                        UI::get (A, set, "A", UI::compulsory);
59                        UI::get (B, set, "B", UI::compulsory);
60                        UI::get (C, set, "C", UI::compulsory);
61                        UI::get (D, set, "D", UI::compulsory);
62                        mat Qtm, Rtm;
63                        if(!UI::get(Qtm, set, "Q", UI::optional)){
64                                vec dq;
65                                UI::get(dq, set, "dQ", UI::compulsory);
66                                Qtm=diag(dq);
67                        }
68                        if(!UI::get(Rtm, set, "R", UI::optional)){
69                                vec dr;
70                                UI::get(dr, set, "dQ", UI::compulsory);
71                                Rtm=diag(dr);
72                        }
73                        R=Rtm; // automatic conversion
74                        Q=Qtm; 
75                       
76                        validate();
77                }               
[586]78                //! access function
[605]79                const mat& _A() const {return A;}
[586]80                //! access function
[605]81                const mat& _B()const {return B;}
[586]82                //! access function
[605]83                const mat& _C()const {return C;}
[586]84                //! access function
[605]85                const mat& _D()const {return D;}
[586]86                //! access function
[605]87                const sq_T& _Q()const {return Q;}
[586]88                //! access function
[605]89                const sq_T& _R()const {return R;}
[583]90};
[32]91
[583]92//! Common abstract base for Kalman filters
93template<class sq_T>
94class Kalman: public BM, public StateSpace<sq_T>
95{
96        protected:
97                //! id of output
98                RV yrv;
99                //! Kalman gain
100                mat  _K;
101                //!posterior
[679]102                enorm<sq_T> est;
[583]103                //!marginal on data f(y|y)
104                enorm<sq_T>  fy;
105        public:
[679]106                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est(){}
[660]107                //! Copy constructor
[679]108                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv), _K(K0._K),  est(K0.est), fy(K0.fy){}
[660]109                //!set statistics of the posterior
[679]110                void set_statistics (const vec &mu0, const mat &P0) {est.set_parameters (mu0, P0); };
[660]111                //!set statistics of the posterior
[679]112                void set_statistics (const vec &mu0, const sq_T &P0) {est.set_parameters (mu0, P0); };
[660]113                //! return correctly typed posterior (covariant return)
[679]114                const enorm<sq_T>& posterior() const {return est;}
[583]115                //! load basic elements of Kalman from structure
116                void from_setting (const Setting &set) {
117                        StateSpace<sq_T>::from_setting(set);
118                                               
119                        mat P0; vec mu0;
120                        UI::get(mu0, set, "mu0", UI::optional);
121                        UI::get(P0, set,  "P0", UI::optional);
122                        set_statistics(mu0,P0);
123                        // Initial values
124                        UI::get (yrv, set, "yrv", UI::optional);
[679]125                        UI::get (rvc, set, "urv", UI::optional);
126                        set_yrv(concat(yrv,rvc));
[583]127                       
128                        validate();
129                }
[660]130                //! validate object
[583]131                void validate() {
132                        StateSpace<sq_T>::validate();
[679]133                        dimy = this->C.rows();
134                        dimc = this->B.cols();
135                        set_dim(this->A.rows());
136                       
137                        bdm_assert(est.dimension(), "Statistics and model parameters mismatch");
[583]138                }
139};
140/*!
141* \brief Basic Kalman filter with full matrices
142*/
[32]143
[583]144class KalmanFull : public Kalman<fsqmat>
145{
146        public:
147                //! For EKFfull;
148                KalmanFull() :Kalman<fsqmat>(){};
149                //! Here dt = [yt;ut] of appropriate dimensions
[679]150                void bayes (const vec &yt, const vec &cond=empty_vec);
[653]151                BM* _copy_() const {
152                        KalmanFull* K = new KalmanFull;
153                        K->set_parameters (A, B, C, D, Q, R);
[679]154                        K->set_statistics (est._mu(), est._R());
[653]155                        return K;
156                }
[583]157};
[586]158UIREGISTER(KalmanFull);
[32]159
160
[477]161/*! \brief Kalman filter in square root form
[271]162
[477]163Trivial example:
164\include kalman_simple.cpp
165
[583]166Complete constructor:
[477]167*/
[583]168class KalmanCh : public Kalman<chmat>
169{
170        protected:
171                //! @{ \name Internal storage - needs initialize()
172                //! pre array (triangular matrix)
173                mat preA;
174                //! post array (triangular matrix)
175                mat postA;
176                //!@}
177        public:
178                //! copy constructor
179                BM* _copy_() const {
180                        KalmanCh* K = new KalmanCh;
181                        K->set_parameters (A, B, C, D, Q, R);
[679]182                        K->set_statistics (est._mu(), est._R());
[681]183                        K->validate();
[583]184                        return K;
185                }
186                //! set parameters for adapt from Kalman
187                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
188                //! initialize internal parametetrs
189                void initialize();
[37]190
[583]191                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
[283]192
[583]193                The following equality hold::\f[
194                \left[\begin{array}{cc}
195                R^{0.5}\\
196                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
197                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
198                R_{y}^{0.5} & KA'\\
199                & P_{t+1|t}^{0.5}\\
200                \\\end{array}\right]\f]
[283]201
[583]202                Thus this object evaluates only predictors! Not filtering densities.
203                */
[679]204                void bayes (const vec &yt, const vec &cond=empty_vec);
[675]205
[586]206                void from_setting(const Setting &set){
[583]207                        Kalman<chmat>::from_setting(set);
[679]208                        validate();
209                }
210                void validate() {
211                        Kalman<chmat>::validate();
[583]212                        initialize();
213                }
[477]214};
[586]215UIREGISTER(KalmanCh);
[37]216
[477]217/*!
218\brief Extended Kalman Filter in full matrices
[62]219
[477]220An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
221*/
[583]222class EKFfull : public KalmanFull
223{
224        protected:
225                //! Internal Model f(x,u)
226                shared_ptr<diffbifn> pfxu;
[527]227
[583]228                //! Observation Model h(x,u)
229                shared_ptr<diffbifn> phxu;
[283]230
[583]231        public:
232                //! Default constructor
233                EKFfull ();
[527]234
[583]235                //! Set nonlinear functions for mean values and covariance matrices.
236                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
[527]237
[583]238                //! Here dt = [yt;ut] of appropriate dimensions
[679]239                void bayes (const vec &yt, const vec &cond=empty_vec);
[583]240                //! set estimates
241                void set_statistics (const vec &mu0, const mat &P0) {
[679]242                        est.set_parameters (mu0, P0);
[583]243                };
[660]244                //! access function
[583]245                const mat _R() {
[679]246                        return est._R().to_mat();
[583]247                }
[653]248                void from_setting (const Setting &set) {
249                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
250                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
251                       
252                        //statistics
253                        int dim = IM->dimension();
254                        vec mu0;
255                        if ( !UI::get ( mu0, set, "mu0" ) )
256                                mu0 = zeros ( dim );
257                       
258                        mat P0;
259                        vec dP0;
260                        if ( UI::get ( dP0, set, "dP0" ) )
261                                P0 = diag ( dP0 );
262                        else if ( !UI::get ( P0, set, "P0" ) )
263                                P0 = eye ( dim );
264                       
265                        set_statistics ( mu0, P0 );
266                       
267                        //parameters
268                        vec dQ, dR;
269                        UI::get ( dQ, set, "dQ", UI::compulsory );
270                        UI::get ( dR, set, "dR", UI::compulsory );
271                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
272                       
273                        //connect
274                        shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
[679]275                        set_yrv ( *drv );
[653]276                        shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
277                        set_rv ( *rv );
278                       
279                        string options;
280                        if ( UI::get ( options, set, "options" ) )
281                                set_options ( options );
282//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
283//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
284//                     
285//                      mat R0;
286//                      UI::get(R0, set, "R",UI::compulsory);
287//                      mat Q0;
288//                      UI::get(Q0, set, "Q",UI::compulsory);
289//                     
290//                     
291//                      mat P0; vec mu0;
292//                      UI::get(mu0, set, "mu0", UI::optional);
293//                      UI::get(P0, set,  "P0", UI::optional);
294//                      set_statistics(mu0,P0);
295//                      // Initial values
296//                      UI::get (yrv, set, "yrv", UI::optional);
297//                      UI::get (urv, set, "urv", UI::optional);
298//                      set_drv(concat(yrv,urv));
299//
300//                      // setup StateSpace
301//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
302//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
303//                     
304                        validate();
305                }
306                void validate() {
307                        // check stats and IM and OM
308                }
[477]309};
[586]310UIREGISTER(EKFfull);
[62]311
[586]312
[477]313/*!
314\brief Extended Kalman Filter in Square root
[37]315
[477]316An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
317*/
[37]318
[583]319class EKFCh : public KalmanCh
320{
321        protected:
322                //! Internal Model f(x,u)
323                shared_ptr<diffbifn> pfxu;
[527]324
[583]325                //! Observation Model h(x,u)
326                shared_ptr<diffbifn> phxu;
327        public:
328                //! copy constructor duplicated - calls different set_parameters
329                BM* _copy_() const {
330                        EKFCh* E = new EKFCh;
331                        E->set_parameters (pfxu, phxu, Q, R);
[679]332                        E->set_statistics (est._mu(), est._R());
[583]333                        return E;
334                }
335                //! Set nonlinear functions for mean values and covariance matrices.
336                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
[527]337
[583]338                //! Here dt = [yt;ut] of appropriate dimensions
[679]339                void bayes (const vec &yt, const vec &cond=empty_vec);
[357]340
[583]341                void from_setting (const Setting &set);
[357]342
[681]343                void validate(){};
[583]344                // TODO dodelat void to_setting( Setting &set ) const;
[357]345
[477]346};
[37]347
[583]348UIREGISTER (EKFCh);
349SHAREDPTR (EKFCh);
[357]350
351
[7]352//////// INstance
353
[477]354/*! \brief (Switching) Multiple Model
355The model runs several models in parallel and evaluates thier weights (fittness).
[62]356
[477]357The statistics of the resulting density are merged using (geometric?) combination.
[283]358
[477]359The next step is performed with the new statistics for all models.
360*/
[583]361class MultiModel: public BM
362{
363        protected:
364                //! List of models between which we switch
365                Array<EKFCh*> Models;
366                //! vector of model weights
367                vec w;
368                //! cache of model lls
369                vec _lls;
370                //! type of switching policy [1=maximum,2=...]
371                int policy;
372                //! internal statistics
373                enorm<chmat> est;
374        public:
[660]375                //! set internal parameters
[583]376                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
377                        Models = A;//TODO: test if evalll is set
378                        w.set_length (A.length());
379                        _lls.set_length (A.length());
380                        policy = pol0;
[357]381
[583]382                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
383                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
[477]384                }
[679]385                void bayes (const vec &yt, const vec &cond=empty_vec) {
[583]386                        int n = Models.length();
387                        int i;
388                        for (i = 0; i < n; i++) {
[679]389                                Models (i)->bayes (yt);
[583]390                                _lls (i) = Models (i)->_ll();
391                        }
392                        double mlls = max (_lls);
393                        w = exp (_lls - mlls);
394                        w /= sum (w);   //normalization
395                        //set statistics
396                        switch (policy) {
397                                case 1: {
398                                        int mi = max_index (w);
399                                        const enorm<chmat> &st = Models (mi)->posterior() ;
400                                        est.set_parameters (st.mean(), st._R());
401                                }
402                                break;
403                                default:
404                                        bdm_error ("unknown policy");
405                        }
406                        // copy result to all models
407                        for (i = 0; i < n; i++) {
408                                Models (i)->set_statistics (est.mean(), est._R());
409                        }
[477]410                }
[660]411                //! return correctly typed posterior (covariant return)
[583]412                const enorm<chmat>& posterior() const {
413                        return est;
[477]414                }
[357]415
[583]416                void from_setting (const Setting &set);
[357]417
[583]418                // TODO dodelat void to_setting( Setting &set ) const;
[338]419
[477]420};
[338]421
[583]422UIREGISTER (MultiModel);
423SHAREDPTR (MultiModel);
[357]424
[586]425//! conversion of outer ARX model (mlnorm) to state space model
426/*!
427The model is constructed as:
428\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
429For example, for:
[605]430Using Frobenius form, see [].
431
432For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
[586]433*/
[605]434//template<class sq_T>
435class StateCanonical: public StateSpace<fsqmat>{
436        protected:
437                //! remember connection from theta ->A
438                datalink_part th2A;
439                //! remember connection from theta ->C
440                datalink_part th2C;
441                //! remember connection from theta ->D
442                datalink_part th2D;
443                //!cached first row of A
444                vec A1row;
445                //!cached first row of C
446                vec C1row;
447                //!cached first row of D
448                vec D1row;
449               
450        public:
451                //! set up this object to match given mlnorm
[625]452                void connect_mlnorm(const mlnorm<fsqmat> &ml){
453                        //get ids of yrv                               
[605]454                        const RV &yrv = ml._rv();
455                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
456                        RV rgr0 = ml._rvc().remove_time();
457                        RV urv = rgr0.subt(yrv); 
[586]458                       
459                        //We can do only 1d now... :(
[620]460                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
[605]461
462                        // create names for
[586]463                        RV xrv; //empty
464                        RV Crv; //empty
[605]465                        int td=ml._rvc().mint();
466                        // assuming strictly proper function!!!
[586]467                        for (int t=-1;t>=td;t--){
[605]468                                xrv.add(yrv.copy_t(t));
469                                Crv.add(urv.copy_t(t));
[586]470                        }
[679]471                                               
[605]472                        // get mapp
473                        th2A.set_connection(xrv, ml._rvc());
474                        th2C.set_connection(Crv, ml._rvc());
475                        th2D.set_connection(urv, ml._rvc());
476
477                        //set matrix sizes
[679]478                        this->A=zeros(xrv._dsize(),xrv._dsize());
479                        for (int j=1; j<xrv._dsize(); j++){A(j,j-1)=1.0;} // off diagonal
480                                this->B=zeros(xrv._dsize(),1);
[605]481                                this->B(0) = 1.0;
[679]482                                this->C=zeros(1,xrv._dsize());
[605]483                                this->D=zeros(1,urv._dsize());
[679]484                                this->Q = zeros(xrv._dsize(),xrv._dsize());
[605]485                        // R is set by update
[586]486                       
[605]487                        //set cache
488                        this->A1row = zeros(xrv._dsize());
489                        this->C1row = zeros(xrv._dsize());
490                        this->D1row = zeros(urv._dsize());
491                       
492                        update_from(ml);
493                        validate();
494                };
495                //! fast function to update parameters from ml - not checked for compatibility!!
496                void update_from(const mlnorm<fsqmat> &ml){
497                       
[586]498                        vec theta = ml._A().get_row(0); // this
[605]499                       
500                        th2A.filldown(theta,A1row);
501                        th2C.filldown(theta,C1row);
502                        th2D.filldown(theta,D1row);
[586]503
[605]504                        R = ml._R();
505
[586]506                        A.set_row(0,A1row);
[605]507                        C.set_row(0,C1row+D1row*A1row);
508                        D.set_row(0,D1row);
509                       
510                }
511};
[586]512
[583]513/////////// INSTANTIATION
[357]514
[477]515template<class sq_T>
[583]516void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
517{
[28]518
[583]519        A = A0;
520        B = B0;
521        C = C0;
522        D = D0;
[477]523        R = R0;
524        Q = Q0;
[583]525        validate();
[477]526}
[22]527
[477]528template<class sq_T>
[583]529void StateSpace<sq_T>::validate(){
[679]530        bdm_assert (A.cols() == A.rows(), "KalmanFull: A is not square");
531        bdm_assert (B.rows() == A.rows(), "KalmanFull: B is not compatible");
532        bdm_assert (C.cols() == A.rows(), "KalmanFull: C is not compatible");
533        bdm_assert ( (D.rows() == C.rows()) && (D.cols() == B.cols()), "KalmanFull: D is not compatible");
534        bdm_assert ( (Q.cols() == A.rows()) && (Q.rows() == A.rows()), "KalmanFull: Q is not compatible");
535        bdm_assert ( (R.cols() == C.rows()) && (R.rows() == C.rows()), "KalmanFull: R is not compatible");
[583]536}
[22]537
[254]538}
[7]539#endif // KF_H
540
Note: See TracBrowser for help on using the browser.