root/library/bdm/estim/kalman.h @ 679

Revision 679, 14.1 kB (checked in by smidl, 15 years ago)

Major changes in BM -- OK is only test suite and tests/tutorial -- the rest is broken!!!

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21
22namespace bdm
23{
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace
34{
35        protected:
36                //! Matrix A
37                mat A;
38                //! Matrix B
39                mat B;
40                //! Matrix C
41                mat C;
42                //! Matrix D
43                mat D;
44                //! Matrix Q in square-root form
45                sq_T Q;
46                //! Matrix R in square-root form
47                sq_T R;
48        public:
49                StateSpace() :  A(), B(), C(), D(), Q(), R() {}
50                //!copy constructor
51                StateSpace(const StateSpace<sq_T> &S0) :  A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
52                //! set all matrix parameters
53                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
54                //! validation
55                void validate();
56                //! not virtual in this case
57                void from_setting (const Setting &set) {
58                        UI::get (A, set, "A", UI::compulsory);
59                        UI::get (B, set, "B", UI::compulsory);
60                        UI::get (C, set, "C", UI::compulsory);
61                        UI::get (D, set, "D", UI::compulsory);
62                        mat Qtm, Rtm;
63                        if(!UI::get(Qtm, set, "Q", UI::optional)){
64                                vec dq;
65                                UI::get(dq, set, "dQ", UI::compulsory);
66                                Qtm=diag(dq);
67                        }
68                        if(!UI::get(Rtm, set, "R", UI::optional)){
69                                vec dr;
70                                UI::get(dr, set, "dQ", UI::compulsory);
71                                Rtm=diag(dr);
72                        }
73                        R=Rtm; // automatic conversion
74                        Q=Qtm; 
75                       
76                        validate();
77                }               
78                //! access function
79                const mat& _A() const {return A;}
80                //! access function
81                const mat& _B()const {return B;}
82                //! access function
83                const mat& _C()const {return C;}
84                //! access function
85                const mat& _D()const {return D;}
86                //! access function
87                const sq_T& _Q()const {return Q;}
88                //! access function
89                const sq_T& _R()const {return R;}
90};
91
92//! Common abstract base for Kalman filters
93template<class sq_T>
94class Kalman: public BM, public StateSpace<sq_T>
95{
96        protected:
97                //! id of output
98                RV yrv;
99                //! Kalman gain
100                mat  _K;
101                //!posterior
102                enorm<sq_T> est;
103                //!marginal on data f(y|y)
104                enorm<sq_T>  fy;
105        public:
106                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est(){}
107                //! Copy constructor
108                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv), _K(K0._K),  est(K0.est), fy(K0.fy){}
109                //!set statistics of the posterior
110                void set_statistics (const vec &mu0, const mat &P0) {est.set_parameters (mu0, P0); };
111                //!set statistics of the posterior
112                void set_statistics (const vec &mu0, const sq_T &P0) {est.set_parameters (mu0, P0); };
113                //! return correctly typed posterior (covariant return)
114                const enorm<sq_T>& posterior() const {return est;}
115                //! load basic elements of Kalman from structure
116                void from_setting (const Setting &set) {
117                        StateSpace<sq_T>::from_setting(set);
118                                               
119                        mat P0; vec mu0;
120                        UI::get(mu0, set, "mu0", UI::optional);
121                        UI::get(P0, set,  "P0", UI::optional);
122                        set_statistics(mu0,P0);
123                        // Initial values
124                        UI::get (yrv, set, "yrv", UI::optional);
125                        UI::get (rvc, set, "urv", UI::optional);
126                        set_yrv(concat(yrv,rvc));
127                       
128                        validate();
129                }
130                //! validate object
131                void validate() {
132                        StateSpace<sq_T>::validate();
133                        dimy = this->C.rows();
134                        dimc = this->B.cols();
135                        set_dim(this->A.rows());
136                       
137                        bdm_assert(est.dimension(), "Statistics and model parameters mismatch");
138                }
139};
140/*!
141* \brief Basic Kalman filter with full matrices
142*/
143
144class KalmanFull : public Kalman<fsqmat>
145{
146        public:
147                //! For EKFfull;
148                KalmanFull() :Kalman<fsqmat>(){};
149                //! Here dt = [yt;ut] of appropriate dimensions
150                void bayes (const vec &yt, const vec &cond=empty_vec);
151                BM* _copy_() const {
152                        KalmanFull* K = new KalmanFull;
153                        K->set_parameters (A, B, C, D, Q, R);
154                        K->set_statistics (est._mu(), est._R());
155                        return K;
156                }
157};
158UIREGISTER(KalmanFull);
159
160
161/*! \brief Kalman filter in square root form
162
163Trivial example:
164\include kalman_simple.cpp
165
166Complete constructor:
167*/
168class KalmanCh : public Kalman<chmat>
169{
170        protected:
171                //! @{ \name Internal storage - needs initialize()
172                //! pre array (triangular matrix)
173                mat preA;
174                //! post array (triangular matrix)
175                mat postA;
176                //!@}
177        public:
178                //! copy constructor
179                BM* _copy_() const {
180                        KalmanCh* K = new KalmanCh;
181                        K->set_parameters (A, B, C, D, Q, R);
182                        K->set_statistics (est._mu(), est._R());
183                        return K;
184                }
185                //! set parameters for adapt from Kalman
186                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
187                //! initialize internal parametetrs
188                void initialize();
189
190                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
191
192                The following equality hold::\f[
193                \left[\begin{array}{cc}
194                R^{0.5}\\
195                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
196                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
197                R_{y}^{0.5} & KA'\\
198                & P_{t+1|t}^{0.5}\\
199                \\\end{array}\right]\f]
200
201                Thus this object evaluates only predictors! Not filtering densities.
202                */
203                void bayes (const vec &yt, const vec &cond=empty_vec);
204
205                void from_setting(const Setting &set){
206                        Kalman<chmat>::from_setting(set);
207                        validate();
208                }
209                void validate() {
210                        Kalman<chmat>::validate();
211                        initialize();
212                }
213};
214UIREGISTER(KalmanCh);
215
216/*!
217\brief Extended Kalman Filter in full matrices
218
219An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
220*/
221class EKFfull : public KalmanFull
222{
223        protected:
224                //! Internal Model f(x,u)
225                shared_ptr<diffbifn> pfxu;
226
227                //! Observation Model h(x,u)
228                shared_ptr<diffbifn> phxu;
229
230        public:
231                //! Default constructor
232                EKFfull ();
233
234                //! Set nonlinear functions for mean values and covariance matrices.
235                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
236
237                //! Here dt = [yt;ut] of appropriate dimensions
238                void bayes (const vec &yt, const vec &cond=empty_vec);
239                //! set estimates
240                void set_statistics (const vec &mu0, const mat &P0) {
241                        est.set_parameters (mu0, P0);
242                };
243                //! access function
244                const mat _R() {
245                        return est._R().to_mat();
246                }
247                void from_setting (const Setting &set) {
248                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
249                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
250                       
251                        //statistics
252                        int dim = IM->dimension();
253                        vec mu0;
254                        if ( !UI::get ( mu0, set, "mu0" ) )
255                                mu0 = zeros ( dim );
256                       
257                        mat P0;
258                        vec dP0;
259                        if ( UI::get ( dP0, set, "dP0" ) )
260                                P0 = diag ( dP0 );
261                        else if ( !UI::get ( P0, set, "P0" ) )
262                                P0 = eye ( dim );
263                       
264                        set_statistics ( mu0, P0 );
265                       
266                        //parameters
267                        vec dQ, dR;
268                        UI::get ( dQ, set, "dQ", UI::compulsory );
269                        UI::get ( dR, set, "dR", UI::compulsory );
270                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
271                       
272                        //connect
273                        shared_ptr<RV> drv = UI::build<RV> ( set, "drv", UI::compulsory );
274                        set_yrv ( *drv );
275                        shared_ptr<RV> rv = UI::build<RV> ( set, "rv", UI::compulsory );
276                        set_rv ( *rv );
277                       
278                        string options;
279                        if ( UI::get ( options, set, "options" ) )
280                                set_options ( options );
281//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
282//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
283//                     
284//                      mat R0;
285//                      UI::get(R0, set, "R",UI::compulsory);
286//                      mat Q0;
287//                      UI::get(Q0, set, "Q",UI::compulsory);
288//                     
289//                     
290//                      mat P0; vec mu0;
291//                      UI::get(mu0, set, "mu0", UI::optional);
292//                      UI::get(P0, set,  "P0", UI::optional);
293//                      set_statistics(mu0,P0);
294//                      // Initial values
295//                      UI::get (yrv, set, "yrv", UI::optional);
296//                      UI::get (urv, set, "urv", UI::optional);
297//                      set_drv(concat(yrv,urv));
298//
299//                      // setup StateSpace
300//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
301//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
302//                     
303                        validate();
304                }
305                void validate() {
306                        // check stats and IM and OM
307                }
308};
309UIREGISTER(EKFfull);
310
311
312/*!
313\brief Extended Kalman Filter in Square root
314
315An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
316*/
317
318class EKFCh : public KalmanCh
319{
320        protected:
321                //! Internal Model f(x,u)
322                shared_ptr<diffbifn> pfxu;
323
324                //! Observation Model h(x,u)
325                shared_ptr<diffbifn> phxu;
326        public:
327                //! copy constructor duplicated - calls different set_parameters
328                BM* _copy_() const {
329                        EKFCh* E = new EKFCh;
330                        E->set_parameters (pfxu, phxu, Q, R);
331                        E->set_statistics (est._mu(), est._R());
332                        return E;
333                }
334                //! Set nonlinear functions for mean values and covariance matrices.
335                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
336
337                //! Here dt = [yt;ut] of appropriate dimensions
338                void bayes (const vec &yt, const vec &cond=empty_vec);
339
340                void from_setting (const Setting &set);
341
342                // TODO dodelat void to_setting( Setting &set ) const;
343
344};
345
346UIREGISTER (EKFCh);
347SHAREDPTR (EKFCh);
348
349
350//////// INstance
351
352/*! \brief (Switching) Multiple Model
353The model runs several models in parallel and evaluates thier weights (fittness).
354
355The statistics of the resulting density are merged using (geometric?) combination.
356
357The next step is performed with the new statistics for all models.
358*/
359class MultiModel: public BM
360{
361        protected:
362                //! List of models between which we switch
363                Array<EKFCh*> Models;
364                //! vector of model weights
365                vec w;
366                //! cache of model lls
367                vec _lls;
368                //! type of switching policy [1=maximum,2=...]
369                int policy;
370                //! internal statistics
371                enorm<chmat> est;
372        public:
373                //! set internal parameters
374                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
375                        Models = A;//TODO: test if evalll is set
376                        w.set_length (A.length());
377                        _lls.set_length (A.length());
378                        policy = pol0;
379
380                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
381                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
382                }
383                void bayes (const vec &yt, const vec &cond=empty_vec) {
384                        int n = Models.length();
385                        int i;
386                        for (i = 0; i < n; i++) {
387                                Models (i)->bayes (yt);
388                                _lls (i) = Models (i)->_ll();
389                        }
390                        double mlls = max (_lls);
391                        w = exp (_lls - mlls);
392                        w /= sum (w);   //normalization
393                        //set statistics
394                        switch (policy) {
395                                case 1: {
396                                        int mi = max_index (w);
397                                        const enorm<chmat> &st = Models (mi)->posterior() ;
398                                        est.set_parameters (st.mean(), st._R());
399                                }
400                                break;
401                                default:
402                                        bdm_error ("unknown policy");
403                        }
404                        // copy result to all models
405                        for (i = 0; i < n; i++) {
406                                Models (i)->set_statistics (est.mean(), est._R());
407                        }
408                }
409                //! return correctly typed posterior (covariant return)
410                const enorm<chmat>& posterior() const {
411                        return est;
412                }
413
414                void from_setting (const Setting &set);
415
416                // TODO dodelat void to_setting( Setting &set ) const;
417
418};
419
420UIREGISTER (MultiModel);
421SHAREDPTR (MultiModel);
422
423//! conversion of outer ARX model (mlnorm) to state space model
424/*!
425The model is constructed as:
426\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
427For example, for:
428Using Frobenius form, see [].
429
430For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
431*/
432//template<class sq_T>
433class StateCanonical: public StateSpace<fsqmat>{
434        protected:
435                //! remember connection from theta ->A
436                datalink_part th2A;
437                //! remember connection from theta ->C
438                datalink_part th2C;
439                //! remember connection from theta ->D
440                datalink_part th2D;
441                //!cached first row of A
442                vec A1row;
443                //!cached first row of C
444                vec C1row;
445                //!cached first row of D
446                vec D1row;
447               
448        public:
449                //! set up this object to match given mlnorm
450                void connect_mlnorm(const mlnorm<fsqmat> &ml){
451                        //get ids of yrv                               
452                        const RV &yrv = ml._rv();
453                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
454                        RV rgr0 = ml._rvc().remove_time();
455                        RV urv = rgr0.subt(yrv); 
456                       
457                        //We can do only 1d now... :(
458                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
459
460                        // create names for
461                        RV xrv; //empty
462                        RV Crv; //empty
463                        int td=ml._rvc().mint();
464                        // assuming strictly proper function!!!
465                        for (int t=-1;t>=td;t--){
466                                xrv.add(yrv.copy_t(t));
467                                Crv.add(urv.copy_t(t));
468                        }
469                                               
470                        // get mapp
471                        th2A.set_connection(xrv, ml._rvc());
472                        th2C.set_connection(Crv, ml._rvc());
473                        th2D.set_connection(urv, ml._rvc());
474
475                        //set matrix sizes
476                        this->A=zeros(xrv._dsize(),xrv._dsize());
477                        for (int j=1; j<xrv._dsize(); j++){A(j,j-1)=1.0;} // off diagonal
478                                this->B=zeros(xrv._dsize(),1);
479                                this->B(0) = 1.0;
480                                this->C=zeros(1,xrv._dsize());
481                                this->D=zeros(1,urv._dsize());
482                                this->Q = zeros(xrv._dsize(),xrv._dsize());
483                        // R is set by update
484                       
485                        //set cache
486                        this->A1row = zeros(xrv._dsize());
487                        this->C1row = zeros(xrv._dsize());
488                        this->D1row = zeros(urv._dsize());
489                       
490                        update_from(ml);
491                        validate();
492                };
493                //! fast function to update parameters from ml - not checked for compatibility!!
494                void update_from(const mlnorm<fsqmat> &ml){
495                       
496                        vec theta = ml._A().get_row(0); // this
497                       
498                        th2A.filldown(theta,A1row);
499                        th2C.filldown(theta,C1row);
500                        th2D.filldown(theta,D1row);
501
502                        R = ml._R();
503
504                        A.set_row(0,A1row);
505                        C.set_row(0,C1row+D1row*A1row);
506                        D.set_row(0,D1row);
507                       
508                }
509};
510
511/////////// INSTANTIATION
512
513template<class sq_T>
514void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
515{
516
517        A = A0;
518        B = B0;
519        C = C0;
520        D = D0;
521        R = R0;
522        Q = Q0;
523        validate();
524}
525
526template<class sq_T>
527void StateSpace<sq_T>::validate(){
528        bdm_assert (A.cols() == A.rows(), "KalmanFull: A is not square");
529        bdm_assert (B.rows() == A.rows(), "KalmanFull: B is not compatible");
530        bdm_assert (C.cols() == A.rows(), "KalmanFull: C is not compatible");
531        bdm_assert ( (D.rows() == C.rows()) && (D.cols() == B.cols()), "KalmanFull: D is not compatible");
532        bdm_assert ( (Q.cols() == A.rows()) && (Q.rows() == A.rows()), "KalmanFull: Q is not compatible");
533        bdm_assert ( (R.cols() == C.rows()) && (R.rows() == C.rows()), "KalmanFull: R is not compatible");
534}
535
536}
537#endif // KF_H
538
Note: See TracBrowser for help on using the browser.