root/library/bdm/estim/kalman.h @ 653

Revision 653, 14.1 kB (checked in by smidl, 15 years ago)

corrections in Kalman and particles

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21
22namespace bdm
23{
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace
34{
35        protected:
36                //! cache of rv.count()
37                int dimx;
38                //! cache of rvy.count()
39                int dimy;
40                //! cache of rvu.count()
41                int dimu;
42                //! Matrix A
43                mat A;
44                //! Matrix B
45                mat B;
46                //! Matrix C
47                mat C;
48                //! Matrix D
49                mat D;
50                //! Matrix Q in square-root form
51                sq_T Q;
52                //! Matrix R in square-root form
53                sq_T R;
54        public:
55                StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
56                StateSpace(const StateSpace<sq_T> &S0) : dimx (S0.dimx), dimy (S0.dimy), dimu (S0.dimu), A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
57                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
58                void validate();
59                //! not virtual in this case
60                void from_setting (const Setting &set) {
61                        UI::get (A, set, "A", UI::compulsory);
62                        UI::get (B, set, "B", UI::compulsory);
63                        UI::get (C, set, "C", UI::compulsory);
64                        UI::get (D, set, "D", UI::compulsory);
65                        mat Qtm, Rtm;
66                        if(!UI::get(Qtm, set, "Q", UI::optional)){
67                                vec dq;
68                                UI::get(dq, set, "dQ", UI::compulsory);
69                                Qtm=diag(dq);
70                        }
71                        if(!UI::get(Rtm, set, "R", UI::optional)){
72                                vec dr;
73                                UI::get(dr, set, "dQ", UI::compulsory);
74                                Rtm=diag(dr);
75                        }
76                        R=Rtm; // automatic conversion
77                        Q=Qtm; 
78                       
79                        validate();
80                }               
81                //! access function
82                int _dimx(){return dimx;}
83                //! access function
84                int _dimy(){return dimy;}
85                //! access function
86                int _dimu(){return dimu;}
87                //! access function
88                const mat& _A() const {return A;}
89                //! access function
90                const mat& _B()const {return B;}
91                //! access function
92                const mat& _C()const {return C;}
93                //! access function
94                const mat& _D()const {return D;}
95                //! access function
96                const sq_T& _Q()const {return Q;}
97                //! access function
98                const sq_T& _R()const {return R;}
99};
100
101//! Common abstract base for Kalman filters
102template<class sq_T>
103class Kalman: public BM, public StateSpace<sq_T>
104{
105        protected:
106                //! id of output
107                RV yrv;
108                //! id of input
109                RV urv;
110                //! Kalman gain
111                mat  _K;
112                //!posterior
113                shared_ptr<enorm<sq_T> > est;
114                //!marginal on data f(y|y)
115                enorm<sq_T>  fy;
116        public:
117                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
118                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv),urv(K0.urv), _K(K0._K),  est(new enorm<sq_T>(*K0.est)), fy(K0.fy){}
119                void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
120                void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
121                //! posterior
122                const enorm<sq_T>& posterior() const {return *est.get();}
123                //! shared posterior
124                shared_ptr<epdf> shared_posterior() {return est;}
125                //! load basic elements of Kalman from structure
126                void from_setting (const Setting &set) {
127                        StateSpace<sq_T>::from_setting(set);
128                                               
129                        mat P0; vec mu0;
130                        UI::get(mu0, set, "mu0", UI::optional);
131                        UI::get(P0, set,  "P0", UI::optional);
132                        set_statistics(mu0,P0);
133                        // Initial values
134                        UI::get (yrv, set, "yrv", UI::optional);
135                        UI::get (urv, set, "urv", UI::optional);
136                        set_drv(concat(yrv,urv));
137                       
138                        validate();
139                }
140                void validate() {
141                        StateSpace<sq_T>::validate();
142                        bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
143                }
144};
145/*!
146* \brief Basic Kalman filter with full matrices
147*/
148
149class KalmanFull : public Kalman<fsqmat>
150{
151        public:
152                //! For EKFfull;
153                KalmanFull() :Kalman<fsqmat>(){};
154                //! Here dt = [yt;ut] of appropriate dimensions
155                void bayes (const vec &dt);
156                BM* _copy_() const {
157                        KalmanFull* K = new KalmanFull;
158                        K->set_parameters (A, B, C, D, Q, R);
159                        K->set_statistics (est->_mu(), est->_R());
160                        return K;
161                }
162};
163UIREGISTER(KalmanFull);
164
165
166/*! \brief Kalman filter in square root form
167
168Trivial example:
169\include kalman_simple.cpp
170
171Complete constructor:
172*/
173class KalmanCh : public Kalman<chmat>
174{
175        protected:
176                //! @{ \name Internal storage - needs initialize()
177                //! pre array (triangular matrix)
178                mat preA;
179                //! post array (triangular matrix)
180                mat postA;
181                //!@}
182        public:
183                //! copy constructor
184                BM* _copy_() const {
185                        KalmanCh* K = new KalmanCh;
186                        K->set_parameters (A, B, C, D, Q, R);
187                        K->set_statistics (est->_mu(), est->_R());
188                        return K;
189                }
190                //! set parameters for adapt from Kalman
191                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
192                //! initialize internal parametetrs
193                void initialize();
194
195                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
196
197                The following equality hold::\f[
198                \left[\begin{array}{cc}
199                R^{0.5}\\
200                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
201                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
202                R_{y}^{0.5} & KA'\\
203                & P_{t+1|t}^{0.5}\\
204                \\\end{array}\right]\f]
205
206                Thus this object evaluates only predictors! Not filtering densities.
207                */
208                void bayes (const vec &dt);
209               
210                void from_setting(const Setting &set){
211                        Kalman<chmat>::from_setting(set);
212                        initialize();
213                }
214};
215UIREGISTER(KalmanCh);
216
217/*!
218\brief Extended Kalman Filter in full matrices
219
220An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
221*/
222class EKFfull : public KalmanFull
223{
224        protected:
225                //! Internal Model f(x,u)
226                shared_ptr<diffbifn> pfxu;
227
228                //! Observation Model h(x,u)
229                shared_ptr<diffbifn> phxu;
230
231        public:
232                //! Default constructor
233                EKFfull ();
234
235                //! Set nonlinear functions for mean values and covariance matrices.
236                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
237
238                //! Here dt = [yt;ut] of appropriate dimensions
239                void bayes (const vec &dt);
240                //! set estimates
241                void set_statistics (const vec &mu0, const mat &P0) {
242                        est->set_parameters (mu0, P0);
243                };
244                const mat _R() {
245                        return est->_R().to_mat();
246                }
247                void from_setting (const Setting &set) {
248                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
249                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
250                       
251                        //statistics
252                        int dim = IM->dimension();
253                        vec mu0;
254                        if ( !UI::get ( mu0, set, "mu0" ) )
255                                mu0 = zeros ( dim );
256                       
257                        mat P0;
258                        vec dP0;
259                        if ( UI::get ( dP0, set, "dP0" ) )
260                                P0 = diag ( dP0 );
261                        else if ( !UI::get ( P0, set, "P0" ) )
262                                P0 = eye ( dim );
263                       
264                        set_statistics ( mu0, P0 );
265                       
266                        //parameters
267                        vec dQ, dR;
268                        UI::get ( dQ, set, "dQ", UI::compulsory );
269                        UI::get ( dR, set, "dR", UI::compulsory );
270                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
271                       
272                        //connect
273                        shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
274                        set_drv ( *drv );
275                        shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
276                        set_rv ( *rv );
277                       
278                        string options;
279                        if ( UI::get ( options, set, "options" ) )
280                                set_options ( options );
281//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
282//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
283//                     
284//                      mat R0;
285//                      UI::get(R0, set, "R",UI::compulsory);
286//                      mat Q0;
287//                      UI::get(Q0, set, "Q",UI::compulsory);
288//                     
289//                     
290//                      mat P0; vec mu0;
291//                      UI::get(mu0, set, "mu0", UI::optional);
292//                      UI::get(P0, set,  "P0", UI::optional);
293//                      set_statistics(mu0,P0);
294//                      // Initial values
295//                      UI::get (yrv, set, "yrv", UI::optional);
296//                      UI::get (urv, set, "urv", UI::optional);
297//                      set_drv(concat(yrv,urv));
298//
299//                      // setup StateSpace
300//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
301//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
302//                     
303                        validate();
304                }
305                void validate() {
306                        // check stats and IM and OM
307                }
308};
309UIREGISTER(EKFfull);
310
311
312/*!
313\brief Extended Kalman Filter in Square root
314
315An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
316*/
317
318class EKFCh : public KalmanCh
319{
320        protected:
321                //! Internal Model f(x,u)
322                shared_ptr<diffbifn> pfxu;
323
324                //! Observation Model h(x,u)
325                shared_ptr<diffbifn> phxu;
326        public:
327                //! copy constructor duplicated - calls different set_parameters
328                BM* _copy_() const {
329                        EKFCh* E = new EKFCh;
330                        E->set_parameters (pfxu, phxu, Q, R);
331                        E->set_statistics (est->_mu(), est->_R());
332                        return E;
333                }
334                //! Set nonlinear functions for mean values and covariance matrices.
335                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
336
337                //! Here dt = [yt;ut] of appropriate dimensions
338                void bayes (const vec &dt);
339
340                void from_setting (const Setting &set);
341
342                // TODO dodelat void to_setting( Setting &set ) const;
343
344};
345
346UIREGISTER (EKFCh);
347SHAREDPTR (EKFCh);
348
349
350//////// INstance
351
352/*! \brief (Switching) Multiple Model
353The model runs several models in parallel and evaluates thier weights (fittness).
354
355The statistics of the resulting density are merged using (geometric?) combination.
356
357The next step is performed with the new statistics for all models.
358*/
359class MultiModel: public BM
360{
361        protected:
362                //! List of models between which we switch
363                Array<EKFCh*> Models;
364                //! vector of model weights
365                vec w;
366                //! cache of model lls
367                vec _lls;
368                //! type of switching policy [1=maximum,2=...]
369                int policy;
370                //! internal statistics
371                enorm<chmat> est;
372        public:
373                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
374                        Models = A;//TODO: test if evalll is set
375                        w.set_length (A.length());
376                        _lls.set_length (A.length());
377                        policy = pol0;
378
379                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
380                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
381                }
382                void bayes (const vec &dt) {
383                        int n = Models.length();
384                        int i;
385                        for (i = 0; i < n; i++) {
386                                Models (i)->bayes (dt);
387                                _lls (i) = Models (i)->_ll();
388                        }
389                        double mlls = max (_lls);
390                        w = exp (_lls - mlls);
391                        w /= sum (w);   //normalization
392                        //set statistics
393                        switch (policy) {
394                                case 1: {
395                                        int mi = max_index (w);
396                                        const enorm<chmat> &st = Models (mi)->posterior() ;
397                                        est.set_parameters (st.mean(), st._R());
398                                }
399                                break;
400                                default:
401                                        bdm_error ("unknown policy");
402                        }
403                        // copy result to all models
404                        for (i = 0; i < n; i++) {
405                                Models (i)->set_statistics (est.mean(), est._R());
406                        }
407                }
408                //! posterior density
409                const enorm<chmat>& posterior() const {
410                        return est;
411                }
412
413                void from_setting (const Setting &set);
414
415                // TODO dodelat void to_setting( Setting &set ) const;
416
417};
418
419UIREGISTER (MultiModel);
420SHAREDPTR (MultiModel);
421
422//! conversion of outer ARX model (mlnorm) to state space model
423/*!
424The model is constructed as:
425\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
426For example, for:
427Using Frobenius form, see [].
428
429For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
430*/
431//template<class sq_T>
432class StateCanonical: public StateSpace<fsqmat>{
433        protected:
434                //! remember connection from theta ->A
435                datalink_part th2A;
436                //! remember connection from theta ->C
437                datalink_part th2C;
438                //! remember connection from theta ->D
439                datalink_part th2D;
440                //!cached first row of A
441                vec A1row;
442                //!cached first row of C
443                vec C1row;
444                //!cached first row of D
445                vec D1row;
446               
447        public:
448                //! set up this object to match given mlnorm
449                void connect_mlnorm(const mlnorm<fsqmat> &ml){
450                        //get ids of yrv                               
451                        const RV &yrv = ml._rv();
452                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
453                        RV rgr0 = ml._rvc().remove_time();
454                        RV urv = rgr0.subt(yrv); 
455                       
456                        //We can do only 1d now... :(
457                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
458
459                        // create names for
460                        RV xrv; //empty
461                        RV Crv; //empty
462                        int td=ml._rvc().mint();
463                        // assuming strictly proper function!!!
464                        for (int t=-1;t>=td;t--){
465                                xrv.add(yrv.copy_t(t));
466                                Crv.add(urv.copy_t(t));
467                        }
468                       
469                        this->dimx = xrv._dsize();
470                        this->dimy = yrv._dsize();
471                        this->dimu = urv._dsize();
472                       
473                        // get mapp
474                        th2A.set_connection(xrv, ml._rvc());
475                        th2C.set_connection(Crv, ml._rvc());
476                        th2D.set_connection(urv, ml._rvc());
477
478                        //set matrix sizes
479                        this->A=zeros(dimx,dimx);
480                        for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} // off diagonal
481                                this->B=zeros(dimx,1);
482                                this->B(0) = 1.0;
483                                this->C=zeros(1,dimx);
484                                this->D=zeros(1,urv._dsize());
485                                this->Q = zeros(dimx,dimx);
486                        // R is set by update
487                       
488                        //set cache
489                        this->A1row = zeros(xrv._dsize());
490                        this->C1row = zeros(xrv._dsize());
491                        this->D1row = zeros(urv._dsize());
492                       
493                        update_from(ml);
494                        validate();
495                };
496                //! fast function to update parameters from ml - not checked for compatibility!!
497                void update_from(const mlnorm<fsqmat> &ml){
498                       
499                        vec theta = ml._A().get_row(0); // this
500                       
501                        th2A.filldown(theta,A1row);
502                        th2C.filldown(theta,C1row);
503                        th2D.filldown(theta,D1row);
504
505                        R = ml._R();
506
507                        A.set_row(0,A1row);
508                        C.set_row(0,C1row+D1row*A1row);
509                        D.set_row(0,D1row);
510                       
511                }
512};
513
514/////////// INSTANTIATION
515
516template<class sq_T>
517void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
518{
519
520        A = A0;
521        B = B0;
522        C = C0;
523        D = D0;
524        R = R0;
525        Q = Q0;
526        validate();
527}
528
529template<class sq_T>
530void StateSpace<sq_T>::validate(){
531        dimx = A.rows();
532        dimu = B.cols();
533        dimy = C.rows();
534        bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
535        bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
536        bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
537        bdm_assert ( (D.rows() == dimy) && (D.cols() == dimu), "KalmanFull: D is not compatible");
538        bdm_assert ( (Q.cols() == dimx) && (Q.rows() == dimx), "KalmanFull: Q is not compatible");
539        bdm_assert ( (R.cols() == dimy) && (R.rows() == dimy), "KalmanFull: R is not compatible");
540}
541
542}
543#endif // KF_H
544
Note: See TracBrowser for help on using the browser.