root/library/bdm/estim/kalman.h @ 675

Revision 675, 14.4 kB (checked in by mido, 15 years ago)

experiment: epdf as a descendat of mpdf

  • Property svn:eol-style set to native
RevLine 
[7]1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
[262]16
[384]17#include "../math/functions.h"
18#include "../stat/exp_family.h"
[37]19#include "../math/chmat.h"
[384]20#include "../base/user_info.h"
[7]21
[583]22namespace bdm
23{
[7]24
[477]25/*!
[583]26 * \brief Basic elements of linear state-space model
[32]27
[477]28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
[583]31 */
[477]32template<class sq_T>
[583]33class StateSpace
34{
35        protected:
36                //! cache of rv.count()
37                int dimx;
38                //! cache of rvy.count()
39                int dimy;
40                //! cache of rvu.count()
41                int dimu;
42                //! Matrix A
43                mat A;
44                //! Matrix B
45                mat B;
46                //! Matrix C
47                mat C;
48                //! Matrix D
49                mat D;
50                //! Matrix Q in square-root form
51                sq_T Q;
52                //! Matrix R in square-root form
53                sq_T R;
54        public:
55                StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
[660]56                //!copy constructor
[653]57                StateSpace(const StateSpace<sq_T> &S0) : dimx (S0.dimx), dimy (S0.dimy), dimu (S0.dimu), A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
[660]58                //! set all matrix parameters
[583]59                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
[660]60                //! validation
[583]61                void validate();
62                //! not virtual in this case
63                void from_setting (const Setting &set) {
64                        UI::get (A, set, "A", UI::compulsory);
65                        UI::get (B, set, "B", UI::compulsory);
66                        UI::get (C, set, "C", UI::compulsory);
67                        UI::get (D, set, "D", UI::compulsory);
68                        mat Qtm, Rtm;
69                        if(!UI::get(Qtm, set, "Q", UI::optional)){
70                                vec dq;
71                                UI::get(dq, set, "dQ", UI::compulsory);
72                                Qtm=diag(dq);
73                        }
74                        if(!UI::get(Rtm, set, "R", UI::optional)){
75                                vec dr;
76                                UI::get(dr, set, "dQ", UI::compulsory);
77                                Rtm=diag(dr);
78                        }
79                        R=Rtm; // automatic conversion
80                        Q=Qtm; 
81                       
82                        validate();
83                }               
[586]84                //! access function
85                int _dimx(){return dimx;}
86                //! access function
87                int _dimy(){return dimy;}
88                //! access function
89                int _dimu(){return dimu;}
90                //! access function
[605]91                const mat& _A() const {return A;}
[586]92                //! access function
[605]93                const mat& _B()const {return B;}
[586]94                //! access function
[605]95                const mat& _C()const {return C;}
[586]96                //! access function
[605]97                const mat& _D()const {return D;}
[586]98                //! access function
[605]99                const sq_T& _Q()const {return Q;}
[586]100                //! access function
[605]101                const sq_T& _R()const {return R;}
[583]102};
[32]103
[583]104//! Common abstract base for Kalman filters
105template<class sq_T>
106class Kalman: public BM, public StateSpace<sq_T>
107{
108        protected:
109                //! id of output
110                RV yrv;
111                //! id of input
112                RV urv;
113                //! Kalman gain
114                mat  _K;
115                //!posterior
116                shared_ptr<enorm<sq_T> > est;
117                //!marginal on data f(y|y)
118                enorm<sq_T>  fy;
119        public:
[653]120                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
[660]121                //! Copy constructor
[653]122                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv),urv(K0.urv), _K(K0._K),  est(new enorm<sq_T>(*K0.est)), fy(K0.fy){}
[660]123                //!set statistics of the posterior
[583]124                void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
[660]125                //!set statistics of the posterior
[583]126                void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
[660]127                //! return correctly typed posterior (covariant return)
[583]128                const enorm<sq_T>& posterior() const {return *est.get();}
129                //! shared posterior
130                shared_ptr<epdf> shared_posterior() {return est;}
131                //! load basic elements of Kalman from structure
132                void from_setting (const Setting &set) {
133                        StateSpace<sq_T>::from_setting(set);
134                                               
135                        mat P0; vec mu0;
136                        UI::get(mu0, set, "mu0", UI::optional);
137                        UI::get(P0, set,  "P0", UI::optional);
138                        set_statistics(mu0,P0);
139                        // Initial values
140                        UI::get (yrv, set, "yrv", UI::optional);
141                        UI::get (urv, set, "urv", UI::optional);
142                        set_drv(concat(yrv,urv));
143                       
144                        validate();
145                }
[660]146                //! validate object
[583]147                void validate() {
148                        StateSpace<sq_T>::validate();
[620]149                        bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
[583]150                }
151};
152/*!
153* \brief Basic Kalman filter with full matrices
154*/
[32]155
[583]156class KalmanFull : public Kalman<fsqmat>
157{
158        public:
159                //! For EKFfull;
160                KalmanFull() :Kalman<fsqmat>(){};
161                //! Here dt = [yt;ut] of appropriate dimensions
162                void bayes (const vec &dt);
[653]163                BM* _copy_() const {
164                        KalmanFull* K = new KalmanFull;
165                        K->set_parameters (A, B, C, D, Q, R);
166                        K->set_statistics (est->_mu(), est->_R());
167                        return K;
168                }
[583]169};
[586]170UIREGISTER(KalmanFull);
[32]171
172
[477]173/*! \brief Kalman filter in square root form
[271]174
[477]175Trivial example:
176\include kalman_simple.cpp
177
[583]178Complete constructor:
[477]179*/
[583]180class KalmanCh : public Kalman<chmat>
181{
182        protected:
183                //! @{ \name Internal storage - needs initialize()
184                //! pre array (triangular matrix)
185                mat preA;
186                //! post array (triangular matrix)
187                mat postA;
188                //!@}
189        public:
190                //! copy constructor
191                BM* _copy_() const {
192                        KalmanCh* K = new KalmanCh;
193                        K->set_parameters (A, B, C, D, Q, R);
194                        K->set_statistics (est->_mu(), est->_R());
195                        return K;
196                }
197                //! set parameters for adapt from Kalman
198                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
199                //! initialize internal parametetrs
200                void initialize();
[37]201
[583]202                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
[283]203
[583]204                The following equality hold::\f[
205                \left[\begin{array}{cc}
206                R^{0.5}\\
207                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
208                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
209                R_{y}^{0.5} & KA'\\
210                & P_{t+1|t}^{0.5}\\
211                \\\end{array}\right]\f]
[283]212
[583]213                Thus this object evaluates only predictors! Not filtering densities.
214                */
215                void bayes (const vec &dt);
[675]216
[586]217                void from_setting(const Setting &set){
[583]218                        Kalman<chmat>::from_setting(set);
219                        initialize();
220                }
[477]221};
[586]222UIREGISTER(KalmanCh);
[37]223
[477]224/*!
225\brief Extended Kalman Filter in full matrices
[62]226
[477]227An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
228*/
[583]229class EKFfull : public KalmanFull
230{
231        protected:
232                //! Internal Model f(x,u)
233                shared_ptr<diffbifn> pfxu;
[527]234
[583]235                //! Observation Model h(x,u)
236                shared_ptr<diffbifn> phxu;
[283]237
[583]238        public:
239                //! Default constructor
240                EKFfull ();
[527]241
[583]242                //! Set nonlinear functions for mean values and covariance matrices.
243                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
[527]244
[583]245                //! Here dt = [yt;ut] of appropriate dimensions
246                void bayes (const vec &dt);
247                //! set estimates
248                void set_statistics (const vec &mu0, const mat &P0) {
249                        est->set_parameters (mu0, P0);
250                };
[660]251                //! access function
[583]252                const mat _R() {
253                        return est->_R().to_mat();
254                }
[653]255                void from_setting (const Setting &set) {
256                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
257                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
258                       
259                        //statistics
260                        int dim = IM->dimension();
261                        vec mu0;
262                        if ( !UI::get ( mu0, set, "mu0" ) )
263                                mu0 = zeros ( dim );
264                       
265                        mat P0;
266                        vec dP0;
267                        if ( UI::get ( dP0, set, "dP0" ) )
268                                P0 = diag ( dP0 );
269                        else if ( !UI::get ( P0, set, "P0" ) )
270                                P0 = eye ( dim );
271                       
272                        set_statistics ( mu0, P0 );
273                       
274                        //parameters
275                        vec dQ, dR;
276                        UI::get ( dQ, set, "dQ", UI::compulsory );
277                        UI::get ( dR, set, "dR", UI::compulsory );
278                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
279                       
280                        //connect
281                        shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
282                        set_drv ( *drv );
283                        shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
284                        set_rv ( *rv );
285                       
286                        string options;
287                        if ( UI::get ( options, set, "options" ) )
288                                set_options ( options );
289//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
290//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
291//                     
292//                      mat R0;
293//                      UI::get(R0, set, "R",UI::compulsory);
294//                      mat Q0;
295//                      UI::get(Q0, set, "Q",UI::compulsory);
296//                     
297//                     
298//                      mat P0; vec mu0;
299//                      UI::get(mu0, set, "mu0", UI::optional);
300//                      UI::get(P0, set,  "P0", UI::optional);
301//                      set_statistics(mu0,P0);
302//                      // Initial values
303//                      UI::get (yrv, set, "yrv", UI::optional);
304//                      UI::get (urv, set, "urv", UI::optional);
305//                      set_drv(concat(yrv,urv));
306//
307//                      // setup StateSpace
308//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
309//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
310//                     
311                        validate();
312                }
313                void validate() {
314                        // check stats and IM and OM
315                }
[477]316};
[586]317UIREGISTER(EKFfull);
[62]318
[586]319
[477]320/*!
321\brief Extended Kalman Filter in Square root
[37]322
[477]323An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
324*/
[37]325
[583]326class EKFCh : public KalmanCh
327{
328        protected:
329                //! Internal Model f(x,u)
330                shared_ptr<diffbifn> pfxu;
[527]331
[583]332                //! Observation Model h(x,u)
333                shared_ptr<diffbifn> phxu;
334        public:
335                //! copy constructor duplicated - calls different set_parameters
336                BM* _copy_() const {
337                        EKFCh* E = new EKFCh;
338                        E->set_parameters (pfxu, phxu, Q, R);
339                        E->set_statistics (est->_mu(), est->_R());
340                        return E;
341                }
342                //! Set nonlinear functions for mean values and covariance matrices.
343                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
[527]344
[583]345                //! Here dt = [yt;ut] of appropriate dimensions
346                void bayes (const vec &dt);
[357]347
[583]348                void from_setting (const Setting &set);
[357]349
[583]350                // TODO dodelat void to_setting( Setting &set ) const;
[357]351
[477]352};
[37]353
[583]354UIREGISTER (EKFCh);
355SHAREDPTR (EKFCh);
[357]356
357
[7]358//////// INstance
359
[477]360/*! \brief (Switching) Multiple Model
361The model runs several models in parallel and evaluates thier weights (fittness).
[62]362
[477]363The statistics of the resulting density are merged using (geometric?) combination.
[283]364
[477]365The next step is performed with the new statistics for all models.
366*/
[583]367class MultiModel: public BM
368{
369        protected:
370                //! List of models between which we switch
371                Array<EKFCh*> Models;
372                //! vector of model weights
373                vec w;
374                //! cache of model lls
375                vec _lls;
376                //! type of switching policy [1=maximum,2=...]
377                int policy;
378                //! internal statistics
379                enorm<chmat> est;
380        public:
[660]381                //! set internal parameters
[583]382                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
383                        Models = A;//TODO: test if evalll is set
384                        w.set_length (A.length());
385                        _lls.set_length (A.length());
386                        policy = pol0;
[357]387
[583]388                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
389                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
[477]390                }
[583]391                void bayes (const vec &dt) {
392                        int n = Models.length();
393                        int i;
394                        for (i = 0; i < n; i++) {
395                                Models (i)->bayes (dt);
396                                _lls (i) = Models (i)->_ll();
397                        }
398                        double mlls = max (_lls);
399                        w = exp (_lls - mlls);
400                        w /= sum (w);   //normalization
401                        //set statistics
402                        switch (policy) {
403                                case 1: {
404                                        int mi = max_index (w);
405                                        const enorm<chmat> &st = Models (mi)->posterior() ;
406                                        est.set_parameters (st.mean(), st._R());
407                                }
408                                break;
409                                default:
410                                        bdm_error ("unknown policy");
411                        }
412                        // copy result to all models
413                        for (i = 0; i < n; i++) {
414                                Models (i)->set_statistics (est.mean(), est._R());
415                        }
[477]416                }
[660]417                //! return correctly typed posterior (covariant return)
[583]418                const enorm<chmat>& posterior() const {
419                        return est;
[477]420                }
[357]421
[583]422                void from_setting (const Setting &set);
[357]423
[583]424                // TODO dodelat void to_setting( Setting &set ) const;
[338]425
[477]426};
[338]427
[583]428UIREGISTER (MultiModel);
429SHAREDPTR (MultiModel);
[357]430
[586]431//! conversion of outer ARX model (mlnorm) to state space model
432/*!
433The model is constructed as:
434\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
435For example, for:
[605]436Using Frobenius form, see [].
437
438For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
[586]439*/
[605]440//template<class sq_T>
441class StateCanonical: public StateSpace<fsqmat>{
442        protected:
443                //! remember connection from theta ->A
444                datalink_part th2A;
445                //! remember connection from theta ->C
446                datalink_part th2C;
447                //! remember connection from theta ->D
448                datalink_part th2D;
449                //!cached first row of A
450                vec A1row;
451                //!cached first row of C
452                vec C1row;
453                //!cached first row of D
454                vec D1row;
455               
456        public:
457                //! set up this object to match given mlnorm
[625]458                void connect_mlnorm(const mlnorm<fsqmat> &ml){
459                        //get ids of yrv                               
[605]460                        const RV &yrv = ml._rv();
461                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
462                        RV rgr0 = ml._rvc().remove_time();
463                        RV urv = rgr0.subt(yrv); 
[586]464                       
465                        //We can do only 1d now... :(
[620]466                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
[605]467
468                        // create names for
[586]469                        RV xrv; //empty
470                        RV Crv; //empty
[605]471                        int td=ml._rvc().mint();
472                        // assuming strictly proper function!!!
[586]473                        for (int t=-1;t>=td;t--){
[605]474                                xrv.add(yrv.copy_t(t));
475                                Crv.add(urv.copy_t(t));
[586]476                        }
477                       
[605]478                        this->dimx = xrv._dsize();
479                        this->dimy = yrv._dsize();
480                        this->dimu = urv._dsize();
[586]481                       
[605]482                        // get mapp
483                        th2A.set_connection(xrv, ml._rvc());
484                        th2C.set_connection(Crv, ml._rvc());
485                        th2D.set_connection(urv, ml._rvc());
486
487                        //set matrix sizes
488                        this->A=zeros(dimx,dimx);
489                        for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} // off diagonal
490                                this->B=zeros(dimx,1);
491                                this->B(0) = 1.0;
492                                this->C=zeros(1,dimx);
493                                this->D=zeros(1,urv._dsize());
494                                this->Q = zeros(dimx,dimx);
495                        // R is set by update
[586]496                       
[605]497                        //set cache
498                        this->A1row = zeros(xrv._dsize());
499                        this->C1row = zeros(xrv._dsize());
500                        this->D1row = zeros(urv._dsize());
501                       
502                        update_from(ml);
503                        validate();
504                };
505                //! fast function to update parameters from ml - not checked for compatibility!!
506                void update_from(const mlnorm<fsqmat> &ml){
507                       
[586]508                        vec theta = ml._A().get_row(0); // this
[605]509                       
510                        th2A.filldown(theta,A1row);
511                        th2C.filldown(theta,C1row);
512                        th2D.filldown(theta,D1row);
[586]513
[605]514                        R = ml._R();
515
[586]516                        A.set_row(0,A1row);
[605]517                        C.set_row(0,C1row+D1row*A1row);
518                        D.set_row(0,D1row);
519                       
520                }
521};
[586]522
[583]523/////////// INSTANTIATION
[357]524
[477]525template<class sq_T>
[583]526void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
527{
[28]528
[583]529        A = A0;
530        B = B0;
531        C = C0;
532        D = D0;
[477]533        R = R0;
534        Q = Q0;
[583]535        validate();
[477]536}
[22]537
[477]538template<class sq_T>
[583]539void StateSpace<sq_T>::validate(){
[653]540        dimx = A.rows();
541        dimu = B.cols();
542        dimy = C.rows();
[620]543        bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
544        bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
545        bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
[653]546        bdm_assert ( (D.rows() == dimy) && (D.cols() == dimu), "KalmanFull: D is not compatible");
547        bdm_assert ( (Q.cols() == dimx) && (Q.rows() == dimx), "KalmanFull: Q is not compatible");
548        bdm_assert ( (R.cols() == dimy) && (R.rows() == dimy), "KalmanFull: R is not compatible");
[583]549}
[22]550
[254]551}
[7]552#endif // KF_H
553
Note: See TracBrowser for help on using the browser.