root/library/bdm/estim/kalman.h @ 681

Revision 681, 14.1 kB (checked in by smidl, 15 years ago)

corrections of Kalman

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21
22namespace bdm
23{
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace
34{
35        protected:
36                //! Matrix A
37                mat A;
38                //! Matrix B
39                mat B;
40                //! Matrix C
41                mat C;
42                //! Matrix D
43                mat D;
44                //! Matrix Q in square-root form
45                sq_T Q;
46                //! Matrix R in square-root form
47                sq_T R;
48        public:
49                StateSpace() :  A(), B(), C(), D(), Q(), R() {}
50                //!copy constructor
51                StateSpace(const StateSpace<sq_T> &S0) :  A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
52                //! set all matrix parameters
53                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
54                //! validation
55                void validate();
56                //! not virtual in this case
57                void from_setting (const Setting &set) {
58                        UI::get (A, set, "A", UI::compulsory);
59                        UI::get (B, set, "B", UI::compulsory);
60                        UI::get (C, set, "C", UI::compulsory);
61                        UI::get (D, set, "D", UI::compulsory);
62                        mat Qtm, Rtm;
63                        if(!UI::get(Qtm, set, "Q", UI::optional)){
64                                vec dq;
65                                UI::get(dq, set, "dQ", UI::compulsory);
66                                Qtm=diag(dq);
67                        }
68                        if(!UI::get(Rtm, set, "R", UI::optional)){
69                                vec dr;
70                                UI::get(dr, set, "dQ", UI::compulsory);
71                                Rtm=diag(dr);
72                        }
73                        R=Rtm; // automatic conversion
74                        Q=Qtm; 
75                       
76                        validate();
77                }               
78                //! access function
79                const mat& _A() const {return A;}
80                //! access function
81                const mat& _B()const {return B;}
82                //! access function
83                const mat& _C()const {return C;}
84                //! access function
85                const mat& _D()const {return D;}
86                //! access function
87                const sq_T& _Q()const {return Q;}
88                //! access function
89                const sq_T& _R()const {return R;}
90};
91
92//! Common abstract base for Kalman filters
93template<class sq_T>
94class Kalman: public BM, public StateSpace<sq_T>
95{
96        protected:
97                //! id of output
98                RV yrv;
99                //! Kalman gain
100                mat  _K;
101                //!posterior
102                enorm<sq_T> est;
103                //!marginal on data f(y|y)
104                enorm<sq_T>  fy;
105        public:
106                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est(){}
107                //! Copy constructor
108                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv), _K(K0._K),  est(K0.est), fy(K0.fy){}
109                //!set statistics of the posterior
110                void set_statistics (const vec &mu0, const mat &P0) {est.set_parameters (mu0, P0); };
111                //!set statistics of the posterior
112                void set_statistics (const vec &mu0, const sq_T &P0) {est.set_parameters (mu0, P0); };
113                //! return correctly typed posterior (covariant return)
114                const enorm<sq_T>& posterior() const {return est;}
115                //! load basic elements of Kalman from structure
116                void from_setting (const Setting &set) {
117                        StateSpace<sq_T>::from_setting(set);
118                                               
119                        mat P0; vec mu0;
120                        UI::get(mu0, set, "mu0", UI::optional);
121                        UI::get(P0, set,  "P0", UI::optional);
122                        set_statistics(mu0,P0);
123                        // Initial values
124                        UI::get (yrv, set, "yrv", UI::optional);
125                        UI::get (rvc, set, "urv", UI::optional);
126                        set_yrv(concat(yrv,rvc));
127                       
128                        validate();
129                }
130                //! validate object
131                void validate() {
132                        StateSpace<sq_T>::validate();
133                        dimy = this->C.rows();
134                        dimc = this->B.cols();
135                        set_dim(this->A.rows());
136                       
137                        bdm_assert(est.dimension(), "Statistics and model parameters mismatch");
138                }
139};
140/*!
141* \brief Basic Kalman filter with full matrices
142*/
143
144class KalmanFull : public Kalman<fsqmat>
145{
146        public:
147                //! For EKFfull;
148                KalmanFull() :Kalman<fsqmat>(){};
149                //! Here dt = [yt;ut] of appropriate dimensions
150                void bayes (const vec &yt, const vec &cond=empty_vec);
151                BM* _copy_() const {
152                        KalmanFull* K = new KalmanFull;
153                        K->set_parameters (A, B, C, D, Q, R);
154                        K->set_statistics (est._mu(), est._R());
155                        return K;
156                }
157};
158UIREGISTER(KalmanFull);
159
160
161/*! \brief Kalman filter in square root form
162
163Trivial example:
164\include kalman_simple.cpp
165
166Complete constructor:
167*/
168class KalmanCh : public Kalman<chmat>
169{
170        protected:
171                //! @{ \name Internal storage - needs initialize()
172                //! pre array (triangular matrix)
173                mat preA;
174                //! post array (triangular matrix)
175                mat postA;
176                //!@}
177        public:
178                //! copy constructor
179                BM* _copy_() const {
180                        KalmanCh* K = new KalmanCh;
181                        K->set_parameters (A, B, C, D, Q, R);
182                        K->set_statistics (est._mu(), est._R());
183                        K->validate();
184                        return K;
185                }
186                //! set parameters for adapt from Kalman
187                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
188                //! initialize internal parametetrs
189                void initialize();
190
191                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
192
193                The following equality hold::\f[
194                \left[\begin{array}{cc}
195                R^{0.5}\\
196                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
197                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
198                R_{y}^{0.5} & KA'\\
199                & P_{t+1|t}^{0.5}\\
200                \\\end{array}\right]\f]
201
202                Thus this object evaluates only predictors! Not filtering densities.
203                */
204                void bayes (const vec &yt, const vec &cond=empty_vec);
205
206                void from_setting(const Setting &set){
207                        Kalman<chmat>::from_setting(set);
208                        validate();
209                }
210                void validate() {
211                        Kalman<chmat>::validate();
212                        initialize();
213                }
214};
215UIREGISTER(KalmanCh);
216
217/*!
218\brief Extended Kalman Filter in full matrices
219
220An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
221*/
222class EKFfull : public KalmanFull
223{
224        protected:
225                //! Internal Model f(x,u)
226                shared_ptr<diffbifn> pfxu;
227
228                //! Observation Model h(x,u)
229                shared_ptr<diffbifn> phxu;
230
231        public:
232                //! Default constructor
233                EKFfull ();
234
235                //! Set nonlinear functions for mean values and covariance matrices.
236                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
237
238                //! Here dt = [yt;ut] of appropriate dimensions
239                void bayes (const vec &yt, const vec &cond=empty_vec);
240                //! set estimates
241                void set_statistics (const vec &mu0, const mat &P0) {
242                        est.set_parameters (mu0, P0);
243                };
244                //! access function
245                const mat _R() {
246                        return est._R().to_mat();
247                }
248                void from_setting (const Setting &set) {
249                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
250                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
251                       
252                        //statistics
253                        int dim = IM->dimension();
254                        vec mu0;
255                        if ( !UI::get ( mu0, set, "mu0" ) )
256                                mu0 = zeros ( dim );
257                       
258                        mat P0;
259                        vec dP0;
260                        if ( UI::get ( dP0, set, "dP0" ) )
261                                P0 = diag ( dP0 );
262                        else if ( !UI::get ( P0, set, "P0" ) )
263                                P0 = eye ( dim );
264                       
265                        set_statistics ( mu0, P0 );
266                       
267                        //parameters
268                        vec dQ, dR;
269                        UI::get ( dQ, set, "dQ", UI::compulsory );
270                        UI::get ( dR, set, "dR", UI::compulsory );
271                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
272                       
273                        //connect
274                        shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
275                        set_yrv ( *drv );
276                        shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
277                        set_rv ( *rv );
278                       
279                        string options;
280                        if ( UI::get ( options, set, "options" ) )
281                                set_options ( options );
282//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
283//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
284//                     
285//                      mat R0;
286//                      UI::get(R0, set, "R",UI::compulsory);
287//                      mat Q0;
288//                      UI::get(Q0, set, "Q",UI::compulsory);
289//                     
290//                     
291//                      mat P0; vec mu0;
292//                      UI::get(mu0, set, "mu0", UI::optional);
293//                      UI::get(P0, set,  "P0", UI::optional);
294//                      set_statistics(mu0,P0);
295//                      // Initial values
296//                      UI::get (yrv, set, "yrv", UI::optional);
297//                      UI::get (urv, set, "urv", UI::optional);
298//                      set_drv(concat(yrv,urv));
299//
300//                      // setup StateSpace
301//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
302//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
303//                     
304                        validate();
305                }
306                void validate() {
307                        // check stats and IM and OM
308                }
309};
310UIREGISTER(EKFfull);
311
312
313/*!
314\brief Extended Kalman Filter in Square root
315
316An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
317*/
318
319class EKFCh : public KalmanCh
320{
321        protected:
322                //! Internal Model f(x,u)
323                shared_ptr<diffbifn> pfxu;
324
325                //! Observation Model h(x,u)
326                shared_ptr<diffbifn> phxu;
327        public:
328                //! copy constructor duplicated - calls different set_parameters
329                BM* _copy_() const {
330                        EKFCh* E = new EKFCh;
331                        E->set_parameters (pfxu, phxu, Q, R);
332                        E->set_statistics (est._mu(), est._R());
333                        return E;
334                }
335                //! Set nonlinear functions for mean values and covariance matrices.
336                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
337
338                //! Here dt = [yt;ut] of appropriate dimensions
339                void bayes (const vec &yt, const vec &cond=empty_vec);
340
341                void from_setting (const Setting &set);
342
343                void validate(){};
344                // TODO dodelat void to_setting( Setting &set ) const;
345
346};
347
348UIREGISTER (EKFCh);
349SHAREDPTR (EKFCh);
350
351
352//////// INstance
353
354/*! \brief (Switching) Multiple Model
355The model runs several models in parallel and evaluates thier weights (fittness).
356
357The statistics of the resulting density are merged using (geometric?) combination.
358
359The next step is performed with the new statistics for all models.
360*/
361class MultiModel: public BM
362{
363        protected:
364                //! List of models between which we switch
365                Array<EKFCh*> Models;
366                //! vector of model weights
367                vec w;
368                //! cache of model lls
369                vec _lls;
370                //! type of switching policy [1=maximum,2=...]
371                int policy;
372                //! internal statistics
373                enorm<chmat> est;
374        public:
375                //! set internal parameters
376                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
377                        Models = A;//TODO: test if evalll is set
378                        w.set_length (A.length());
379                        _lls.set_length (A.length());
380                        policy = pol0;
381
382                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
383                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
384                }
385                void bayes (const vec &yt, const vec &cond=empty_vec) {
386                        int n = Models.length();
387                        int i;
388                        for (i = 0; i < n; i++) {
389                                Models (i)->bayes (yt);
390                                _lls (i) = Models (i)->_ll();
391                        }
392                        double mlls = max (_lls);
393                        w = exp (_lls - mlls);
394                        w /= sum (w);   //normalization
395                        //set statistics
396                        switch (policy) {
397                                case 1: {
398                                        int mi = max_index (w);
399                                        const enorm<chmat> &st = Models (mi)->posterior() ;
400                                        est.set_parameters (st.mean(), st._R());
401                                }
402                                break;
403                                default:
404                                        bdm_error ("unknown policy");
405                        }
406                        // copy result to all models
407                        for (i = 0; i < n; i++) {
408                                Models (i)->set_statistics (est.mean(), est._R());
409                        }
410                }
411                //! return correctly typed posterior (covariant return)
412                const enorm<chmat>& posterior() const {
413                        return est;
414                }
415
416                void from_setting (const Setting &set);
417
418                // TODO dodelat void to_setting( Setting &set ) const;
419
420};
421
422UIREGISTER (MultiModel);
423SHAREDPTR (MultiModel);
424
425//! conversion of outer ARX model (mlnorm) to state space model
426/*!
427The model is constructed as:
428\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
429For example, for:
430Using Frobenius form, see [].
431
432For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
433*/
434//template<class sq_T>
435class StateCanonical: public StateSpace<fsqmat>{
436        protected:
437                //! remember connection from theta ->A
438                datalink_part th2A;
439                //! remember connection from theta ->C
440                datalink_part th2C;
441                //! remember connection from theta ->D
442                datalink_part th2D;
443                //!cached first row of A
444                vec A1row;
445                //!cached first row of C
446                vec C1row;
447                //!cached first row of D
448                vec D1row;
449               
450        public:
451                //! set up this object to match given mlnorm
452                void connect_mlnorm(const mlnorm<fsqmat> &ml){
453                        //get ids of yrv                               
454                        const RV &yrv = ml._rv();
455                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
456                        RV rgr0 = ml._rvc().remove_time();
457                        RV urv = rgr0.subt(yrv); 
458                       
459                        //We can do only 1d now... :(
460                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
461
462                        // create names for
463                        RV xrv; //empty
464                        RV Crv; //empty
465                        int td=ml._rvc().mint();
466                        // assuming strictly proper function!!!
467                        for (int t=-1;t>=td;t--){
468                                xrv.add(yrv.copy_t(t));
469                                Crv.add(urv.copy_t(t));
470                        }
471                                               
472                        // get mapp
473                        th2A.set_connection(xrv, ml._rvc());
474                        th2C.set_connection(Crv, ml._rvc());
475                        th2D.set_connection(urv, ml._rvc());
476
477                        //set matrix sizes
478                        this->A=zeros(xrv._dsize(),xrv._dsize());
479                        for (int j=1; j<xrv._dsize(); j++){A(j,j-1)=1.0;} // off diagonal
480                                this->B=zeros(xrv._dsize(),1);
481                                this->B(0) = 1.0;
482                                this->C=zeros(1,xrv._dsize());
483                                this->D=zeros(1,urv._dsize());
484                                this->Q = zeros(xrv._dsize(),xrv._dsize());
485                        // R is set by update
486                       
487                        //set cache
488                        this->A1row = zeros(xrv._dsize());
489                        this->C1row = zeros(xrv._dsize());
490                        this->D1row = zeros(urv._dsize());
491                       
492                        update_from(ml);
493                        validate();
494                };
495                //! fast function to update parameters from ml - not checked for compatibility!!
496                void update_from(const mlnorm<fsqmat> &ml){
497                       
498                        vec theta = ml._A().get_row(0); // this
499                       
500                        th2A.filldown(theta,A1row);
501                        th2C.filldown(theta,C1row);
502                        th2D.filldown(theta,D1row);
503
504                        R = ml._R();
505
506                        A.set_row(0,A1row);
507                        C.set_row(0,C1row+D1row*A1row);
508                        D.set_row(0,D1row);
509                       
510                }
511};
512
513/////////// INSTANTIATION
514
515template<class sq_T>
516void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
517{
518
519        A = A0;
520        B = B0;
521        C = C0;
522        D = D0;
523        R = R0;
524        Q = Q0;
525        validate();
526}
527
528template<class sq_T>
529void StateSpace<sq_T>::validate(){
530        bdm_assert (A.cols() == A.rows(), "KalmanFull: A is not square");
531        bdm_assert (B.rows() == A.rows(), "KalmanFull: B is not compatible");
532        bdm_assert (C.cols() == A.rows(), "KalmanFull: C is not compatible");
533        bdm_assert ( (D.rows() == C.rows()) && (D.cols() == B.cols()), "KalmanFull: D is not compatible");
534        bdm_assert ( (Q.cols() == A.rows()) && (Q.rows() == A.rows()), "KalmanFull: Q is not compatible");
535        bdm_assert ( (R.cols() == C.rows()) && (R.rows() == C.rows()), "KalmanFull: R is not compatible");
536}
537
538}
539#endif // KF_H
540
Note: See TracBrowser for help on using the browser.