root/library/bdm/estim/kalman.h @ 675

Revision 675, 14.4 kB (checked in by mido, 15 years ago)

experiment: epdf as a descendat of mpdf

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21
22namespace bdm
23{
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace
34{
35        protected:
36                //! cache of rv.count()
37                int dimx;
38                //! cache of rvy.count()
39                int dimy;
40                //! cache of rvu.count()
41                int dimu;
42                //! Matrix A
43                mat A;
44                //! Matrix B
45                mat B;
46                //! Matrix C
47                mat C;
48                //! Matrix D
49                mat D;
50                //! Matrix Q in square-root form
51                sq_T Q;
52                //! Matrix R in square-root form
53                sq_T R;
54        public:
55                StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
56                //!copy constructor
57                StateSpace(const StateSpace<sq_T> &S0) : dimx (S0.dimx), dimy (S0.dimy), dimu (S0.dimu), A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
58                //! set all matrix parameters
59                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
60                //! validation
61                void validate();
62                //! not virtual in this case
63                void from_setting (const Setting &set) {
64                        UI::get (A, set, "A", UI::compulsory);
65                        UI::get (B, set, "B", UI::compulsory);
66                        UI::get (C, set, "C", UI::compulsory);
67                        UI::get (D, set, "D", UI::compulsory);
68                        mat Qtm, Rtm;
69                        if(!UI::get(Qtm, set, "Q", UI::optional)){
70                                vec dq;
71                                UI::get(dq, set, "dQ", UI::compulsory);
72                                Qtm=diag(dq);
73                        }
74                        if(!UI::get(Rtm, set, "R", UI::optional)){
75                                vec dr;
76                                UI::get(dr, set, "dQ", UI::compulsory);
77                                Rtm=diag(dr);
78                        }
79                        R=Rtm; // automatic conversion
80                        Q=Qtm; 
81                       
82                        validate();
83                }               
84                //! access function
85                int _dimx(){return dimx;}
86                //! access function
87                int _dimy(){return dimy;}
88                //! access function
89                int _dimu(){return dimu;}
90                //! access function
91                const mat& _A() const {return A;}
92                //! access function
93                const mat& _B()const {return B;}
94                //! access function
95                const mat& _C()const {return C;}
96                //! access function
97                const mat& _D()const {return D;}
98                //! access function
99                const sq_T& _Q()const {return Q;}
100                //! access function
101                const sq_T& _R()const {return R;}
102};
103
104//! Common abstract base for Kalman filters
105template<class sq_T>
106class Kalman: public BM, public StateSpace<sq_T>
107{
108        protected:
109                //! id of output
110                RV yrv;
111                //! id of input
112                RV urv;
113                //! Kalman gain
114                mat  _K;
115                //!posterior
116                shared_ptr<enorm<sq_T> > est;
117                //!marginal on data f(y|y)
118                enorm<sq_T>  fy;
119        public:
120                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
121                //! Copy constructor
122                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv),urv(K0.urv), _K(K0._K),  est(new enorm<sq_T>(*K0.est)), fy(K0.fy){}
123                //!set statistics of the posterior
124                void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
125                //!set statistics of the posterior
126                void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
127                //! return correctly typed posterior (covariant return)
128                const enorm<sq_T>& posterior() const {return *est.get();}
129                //! shared posterior
130                shared_ptr<epdf> shared_posterior() {return est;}
131                //! load basic elements of Kalman from structure
132                void from_setting (const Setting &set) {
133                        StateSpace<sq_T>::from_setting(set);
134                                               
135                        mat P0; vec mu0;
136                        UI::get(mu0, set, "mu0", UI::optional);
137                        UI::get(P0, set,  "P0", UI::optional);
138                        set_statistics(mu0,P0);
139                        // Initial values
140                        UI::get (yrv, set, "yrv", UI::optional);
141                        UI::get (urv, set, "urv", UI::optional);
142                        set_drv(concat(yrv,urv));
143                       
144                        validate();
145                }
146                //! validate object
147                void validate() {
148                        StateSpace<sq_T>::validate();
149                        bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
150                }
151};
152/*!
153* \brief Basic Kalman filter with full matrices
154*/
155
156class KalmanFull : public Kalman<fsqmat>
157{
158        public:
159                //! For EKFfull;
160                KalmanFull() :Kalman<fsqmat>(){};
161                //! Here dt = [yt;ut] of appropriate dimensions
162                void bayes (const vec &dt);
163                BM* _copy_() const {
164                        KalmanFull* K = new KalmanFull;
165                        K->set_parameters (A, B, C, D, Q, R);
166                        K->set_statistics (est->_mu(), est->_R());
167                        return K;
168                }
169};
170UIREGISTER(KalmanFull);
171
172
173/*! \brief Kalman filter in square root form
174
175Trivial example:
176\include kalman_simple.cpp
177
178Complete constructor:
179*/
180class KalmanCh : public Kalman<chmat>
181{
182        protected:
183                //! @{ \name Internal storage - needs initialize()
184                //! pre array (triangular matrix)
185                mat preA;
186                //! post array (triangular matrix)
187                mat postA;
188                //!@}
189        public:
190                //! copy constructor
191                BM* _copy_() const {
192                        KalmanCh* K = new KalmanCh;
193                        K->set_parameters (A, B, C, D, Q, R);
194                        K->set_statistics (est->_mu(), est->_R());
195                        return K;
196                }
197                //! set parameters for adapt from Kalman
198                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
199                //! initialize internal parametetrs
200                void initialize();
201
202                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
203
204                The following equality hold::\f[
205                \left[\begin{array}{cc}
206                R^{0.5}\\
207                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
208                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
209                R_{y}^{0.5} & KA'\\
210                & P_{t+1|t}^{0.5}\\
211                \\\end{array}\right]\f]
212
213                Thus this object evaluates only predictors! Not filtering densities.
214                */
215                void bayes (const vec &dt);
216
217                void from_setting(const Setting &set){
218                        Kalman<chmat>::from_setting(set);
219                        initialize();
220                }
221};
222UIREGISTER(KalmanCh);
223
224/*!
225\brief Extended Kalman Filter in full matrices
226
227An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
228*/
229class EKFfull : public KalmanFull
230{
231        protected:
232                //! Internal Model f(x,u)
233                shared_ptr<diffbifn> pfxu;
234
235                //! Observation Model h(x,u)
236                shared_ptr<diffbifn> phxu;
237
238        public:
239                //! Default constructor
240                EKFfull ();
241
242                //! Set nonlinear functions for mean values and covariance matrices.
243                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
244
245                //! Here dt = [yt;ut] of appropriate dimensions
246                void bayes (const vec &dt);
247                //! set estimates
248                void set_statistics (const vec &mu0, const mat &P0) {
249                        est->set_parameters (mu0, P0);
250                };
251                //! access function
252                const mat _R() {
253                        return est->_R().to_mat();
254                }
255                void from_setting (const Setting &set) {
256                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
257                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
258                       
259                        //statistics
260                        int dim = IM->dimension();
261                        vec mu0;
262                        if ( !UI::get ( mu0, set, "mu0" ) )
263                                mu0 = zeros ( dim );
264                       
265                        mat P0;
266                        vec dP0;
267                        if ( UI::get ( dP0, set, "dP0" ) )
268                                P0 = diag ( dP0 );
269                        else if ( !UI::get ( P0, set, "P0" ) )
270                                P0 = eye ( dim );
271                       
272                        set_statistics ( mu0, P0 );
273                       
274                        //parameters
275                        vec dQ, dR;
276                        UI::get ( dQ, set, "dQ", UI::compulsory );
277                        UI::get ( dR, set, "dR", UI::compulsory );
278                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
279                       
280                        //connect
281                        shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
282                        set_drv ( *drv );
283                        shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
284                        set_rv ( *rv );
285                       
286                        string options;
287                        if ( UI::get ( options, set, "options" ) )
288                                set_options ( options );
289//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
290//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
291//                     
292//                      mat R0;
293//                      UI::get(R0, set, "R",UI::compulsory);
294//                      mat Q0;
295//                      UI::get(Q0, set, "Q",UI::compulsory);
296//                     
297//                     
298//                      mat P0; vec mu0;
299//                      UI::get(mu0, set, "mu0", UI::optional);
300//                      UI::get(P0, set,  "P0", UI::optional);
301//                      set_statistics(mu0,P0);
302//                      // Initial values
303//                      UI::get (yrv, set, "yrv", UI::optional);
304//                      UI::get (urv, set, "urv", UI::optional);
305//                      set_drv(concat(yrv,urv));
306//
307//                      // setup StateSpace
308//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
309//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
310//                     
311                        validate();
312                }
313                void validate() {
314                        // check stats and IM and OM
315                }
316};
317UIREGISTER(EKFfull);
318
319
320/*!
321\brief Extended Kalman Filter in Square root
322
323An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
324*/
325
326class EKFCh : public KalmanCh
327{
328        protected:
329                //! Internal Model f(x,u)
330                shared_ptr<diffbifn> pfxu;
331
332                //! Observation Model h(x,u)
333                shared_ptr<diffbifn> phxu;
334        public:
335                //! copy constructor duplicated - calls different set_parameters
336                BM* _copy_() const {
337                        EKFCh* E = new EKFCh;
338                        E->set_parameters (pfxu, phxu, Q, R);
339                        E->set_statistics (est->_mu(), est->_R());
340                        return E;
341                }
342                //! Set nonlinear functions for mean values and covariance matrices.
343                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
344
345                //! Here dt = [yt;ut] of appropriate dimensions
346                void bayes (const vec &dt);
347
348                void from_setting (const Setting &set);
349
350                // TODO dodelat void to_setting( Setting &set ) const;
351
352};
353
354UIREGISTER (EKFCh);
355SHAREDPTR (EKFCh);
356
357
358//////// INstance
359
360/*! \brief (Switching) Multiple Model
361The model runs several models in parallel and evaluates thier weights (fittness).
362
363The statistics of the resulting density are merged using (geometric?) combination.
364
365The next step is performed with the new statistics for all models.
366*/
367class MultiModel: public BM
368{
369        protected:
370                //! List of models between which we switch
371                Array<EKFCh*> Models;
372                //! vector of model weights
373                vec w;
374                //! cache of model lls
375                vec _lls;
376                //! type of switching policy [1=maximum,2=...]
377                int policy;
378                //! internal statistics
379                enorm<chmat> est;
380        public:
381                //! set internal parameters
382                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
383                        Models = A;//TODO: test if evalll is set
384                        w.set_length (A.length());
385                        _lls.set_length (A.length());
386                        policy = pol0;
387
388                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
389                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
390                }
391                void bayes (const vec &dt) {
392                        int n = Models.length();
393                        int i;
394                        for (i = 0; i < n; i++) {
395                                Models (i)->bayes (dt);
396                                _lls (i) = Models (i)->_ll();
397                        }
398                        double mlls = max (_lls);
399                        w = exp (_lls - mlls);
400                        w /= sum (w);   //normalization
401                        //set statistics
402                        switch (policy) {
403                                case 1: {
404                                        int mi = max_index (w);
405                                        const enorm<chmat> &st = Models (mi)->posterior() ;
406                                        est.set_parameters (st.mean(), st._R());
407                                }
408                                break;
409                                default:
410                                        bdm_error ("unknown policy");
411                        }
412                        // copy result to all models
413                        for (i = 0; i < n; i++) {
414                                Models (i)->set_statistics (est.mean(), est._R());
415                        }
416                }
417                //! return correctly typed posterior (covariant return)
418                const enorm<chmat>& posterior() const {
419                        return est;
420                }
421
422                void from_setting (const Setting &set);
423
424                // TODO dodelat void to_setting( Setting &set ) const;
425
426};
427
428UIREGISTER (MultiModel);
429SHAREDPTR (MultiModel);
430
431//! conversion of outer ARX model (mlnorm) to state space model
432/*!
433The model is constructed as:
434\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
435For example, for:
436Using Frobenius form, see [].
437
438For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
439*/
440//template<class sq_T>
441class StateCanonical: public StateSpace<fsqmat>{
442        protected:
443                //! remember connection from theta ->A
444                datalink_part th2A;
445                //! remember connection from theta ->C
446                datalink_part th2C;
447                //! remember connection from theta ->D
448                datalink_part th2D;
449                //!cached first row of A
450                vec A1row;
451                //!cached first row of C
452                vec C1row;
453                //!cached first row of D
454                vec D1row;
455               
456        public:
457                //! set up this object to match given mlnorm
458                void connect_mlnorm(const mlnorm<fsqmat> &ml){
459                        //get ids of yrv                               
460                        const RV &yrv = ml._rv();
461                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
462                        RV rgr0 = ml._rvc().remove_time();
463                        RV urv = rgr0.subt(yrv); 
464                       
465                        //We can do only 1d now... :(
466                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
467
468                        // create names for
469                        RV xrv; //empty
470                        RV Crv; //empty
471                        int td=ml._rvc().mint();
472                        // assuming strictly proper function!!!
473                        for (int t=-1;t>=td;t--){
474                                xrv.add(yrv.copy_t(t));
475                                Crv.add(urv.copy_t(t));
476                        }
477                       
478                        this->dimx = xrv._dsize();
479                        this->dimy = yrv._dsize();
480                        this->dimu = urv._dsize();
481                       
482                        // get mapp
483                        th2A.set_connection(xrv, ml._rvc());
484                        th2C.set_connection(Crv, ml._rvc());
485                        th2D.set_connection(urv, ml._rvc());
486
487                        //set matrix sizes
488                        this->A=zeros(dimx,dimx);
489                        for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} // off diagonal
490                                this->B=zeros(dimx,1);
491                                this->B(0) = 1.0;
492                                this->C=zeros(1,dimx);
493                                this->D=zeros(1,urv._dsize());
494                                this->Q = zeros(dimx,dimx);
495                        // R is set by update
496                       
497                        //set cache
498                        this->A1row = zeros(xrv._dsize());
499                        this->C1row = zeros(xrv._dsize());
500                        this->D1row = zeros(urv._dsize());
501                       
502                        update_from(ml);
503                        validate();
504                };
505                //! fast function to update parameters from ml - not checked for compatibility!!
506                void update_from(const mlnorm<fsqmat> &ml){
507                       
508                        vec theta = ml._A().get_row(0); // this
509                       
510                        th2A.filldown(theta,A1row);
511                        th2C.filldown(theta,C1row);
512                        th2D.filldown(theta,D1row);
513
514                        R = ml._R();
515
516                        A.set_row(0,A1row);
517                        C.set_row(0,C1row+D1row*A1row);
518                        D.set_row(0,D1row);
519                       
520                }
521};
522
523/////////// INSTANTIATION
524
525template<class sq_T>
526void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
527{
528
529        A = A0;
530        B = B0;
531        C = C0;
532        D = D0;
533        R = R0;
534        Q = Q0;
535        validate();
536}
537
538template<class sq_T>
539void StateSpace<sq_T>::validate(){
540        dimx = A.rows();
541        dimu = B.cols();
542        dimy = C.rows();
543        bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
544        bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
545        bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
546        bdm_assert ( (D.rows() == dimy) && (D.cols() == dimu), "KalmanFull: D is not compatible");
547        bdm_assert ( (Q.cols() == dimx) && (Q.rows() == dimx), "KalmanFull: Q is not compatible");
548        bdm_assert ( (R.cols() == dimy) && (R.rows() == dimy), "KalmanFull: R is not compatible");
549}
550
551}
552#endif // KF_H
553
Note: See TracBrowser for help on using the browser.