root/library/bdm/estim/kalman.h @ 686

Revision 686, 14.0 kB (checked in by smidl, 15 years ago)

pmsm using new syntax for bayes

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21
22namespace bdm
23{
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace
34{
35        protected:
36                //! Matrix A
37                mat A;
38                //! Matrix B
39                mat B;
40                //! Matrix C
41                mat C;
42                //! Matrix D
43                mat D;
44                //! Matrix Q in square-root form
45                sq_T Q;
46                //! Matrix R in square-root form
47                sq_T R;
48        public:
49                StateSpace() :  A(), B(), C(), D(), Q(), R() {}
50                //!copy constructor
51                StateSpace(const StateSpace<sq_T> &S0) :  A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
52                //! set all matrix parameters
53                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
54                //! validation
55                void validate();
56                //! not virtual in this case
57                void from_setting (const Setting &set) {
58                        UI::get (A, set, "A", UI::compulsory);
59                        UI::get (B, set, "B", UI::compulsory);
60                        UI::get (C, set, "C", UI::compulsory);
61                        UI::get (D, set, "D", UI::compulsory);
62                        mat Qtm, Rtm;
63                        if(!UI::get(Qtm, set, "Q", UI::optional)){
64                                vec dq;
65                                UI::get(dq, set, "dQ", UI::compulsory);
66                                Qtm=diag(dq);
67                        }
68                        if(!UI::get(Rtm, set, "R", UI::optional)){
69                                vec dr;
70                                UI::get(dr, set, "dQ", UI::compulsory);
71                                Rtm=diag(dr);
72                        }
73                        R=Rtm; // automatic conversion
74                        Q=Qtm; 
75                       
76                        validate();
77                }               
78                //! access function
79                const mat& _A() const {return A;}
80                //! access function
81                const mat& _B()const {return B;}
82                //! access function
83                const mat& _C()const {return C;}
84                //! access function
85                const mat& _D()const {return D;}
86                //! access function
87                const sq_T& _Q()const {return Q;}
88                //! access function
89                const sq_T& _R()const {return R;}
90};
91
92//! Common abstract base for Kalman filters
93template<class sq_T>
94class Kalman: public BM, public StateSpace<sq_T>
95{
96        protected:
97                //! id of output
98                RV yrv;
99                //! Kalman gain
100                mat  _K;
101                //!posterior
102                enorm<sq_T> est;
103                //!marginal on data f(y|y)
104                enorm<sq_T>  fy;
105        public:
106                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est(){}
107                //! Copy constructor
108                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv), _K(K0._K),  est(K0.est), fy(K0.fy){}
109                //!set statistics of the posterior
110                void set_statistics (const vec &mu0, const mat &P0) {est.set_parameters (mu0, P0); };
111                //!set statistics of the posterior
112                void set_statistics (const vec &mu0, const sq_T &P0) {est.set_parameters (mu0, P0); };
113                //! return correctly typed posterior (covariant return)
114                const enorm<sq_T>& posterior() const {return est;}
115                //! load basic elements of Kalman from structure
116                void from_setting (const Setting &set) {
117                        StateSpace<sq_T>::from_setting(set);
118                                               
119                        mat P0; vec mu0;
120                        UI::get(mu0, set, "mu0", UI::optional);
121                        UI::get(P0, set,  "P0", UI::optional);
122                        set_statistics(mu0,P0);
123                        // Initial values
124                        UI::get (yrv, set, "yrv", UI::optional);
125                        UI::get (rvc, set, "urv", UI::optional);
126                        set_yrv(concat(yrv,rvc));
127                       
128                        validate();
129                }
130                //! validate object
131                void validate() {
132                        StateSpace<sq_T>::validate();
133                        dimy = this->C.rows();
134                        dimc = this->B.cols();
135                        set_dim(this->A.rows());
136                       
137                        bdm_assert(est.dimension(), "Statistics and model parameters mismatch");
138                }
139};
140/*!
141* \brief Basic Kalman filter with full matrices
142*/
143
144class KalmanFull : public Kalman<fsqmat>
145{
146        public:
147                //! For EKFfull;
148                KalmanFull() :Kalman<fsqmat>(){};
149                //! Here dt = [yt;ut] of appropriate dimensions
150                void bayes (const vec &yt, const vec &cond=empty_vec);
151                BM* _copy_() const {
152                        KalmanFull* K = new KalmanFull;
153                        K->set_parameters (A, B, C, D, Q, R);
154                        K->set_statistics (est._mu(), est._R());
155                        return K;
156                }
157};
158UIREGISTER(KalmanFull);
159
160
161/*! \brief Kalman filter in square root form
162
163Trivial example:
164\include kalman_simple.cpp
165
166Complete constructor:
167*/
168class KalmanCh : public Kalman<chmat>
169{
170        protected:
171                //! @{ \name Internal storage - needs initialize()
172                //! pre array (triangular matrix)
173                mat preA;
174                //! post array (triangular matrix)
175                mat postA;
176                //!@}
177        public:
178                //! copy constructor
179                BM* _copy_() const {
180                        KalmanCh* K = new KalmanCh;
181                        K->set_parameters (A, B, C, D, Q, R);
182                        K->set_statistics (est._mu(), est._R());
183                        K->validate();
184                        return K;
185                }
186                //! set parameters for adapt from Kalman
187                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
188                //! initialize internal parametetrs
189                void initialize();
190
191                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
192
193                The following equality hold::\f[
194                \left[\begin{array}{cc}
195                R^{0.5}\\
196                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
197                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
198                R_{y}^{0.5} & KA'\\
199                & P_{t+1|t}^{0.5}\\
200                \\\end{array}\right]\f]
201
202                Thus this object evaluates only predictors! Not filtering densities.
203                */
204                void bayes (const vec &yt, const vec &cond=empty_vec);
205
206                void from_setting(const Setting &set){
207                        Kalman<chmat>::from_setting(set);
208                        validate();
209                }
210                void validate() {
211                        Kalman<chmat>::validate();
212                        initialize();
213                }
214};
215UIREGISTER(KalmanCh);
216
217/*!
218\brief Extended Kalman Filter in full matrices
219
220An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
221*/
222class EKFfull : public KalmanFull
223{
224        protected:
225                //! Internal Model f(x,u)
226                shared_ptr<diffbifn> pfxu;
227
228                //! Observation Model h(x,u)
229                shared_ptr<diffbifn> phxu;
230
231        public:
232                //! Default constructor
233                EKFfull ();
234
235                //! Set nonlinear functions for mean values and covariance matrices.
236                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
237
238                //! Here dt = [yt;ut] of appropriate dimensions
239                void bayes (const vec &yt, const vec &cond=empty_vec);
240                //! set estimates
241                void set_statistics (const vec &mu0, const mat &P0) {
242                        est.set_parameters (mu0, P0);
243                };
244                //! access function
245                const mat _R() {
246                        return est._R().to_mat();
247                }
248                void from_setting (const Setting &set) {                       
249                        BM::from_setting(set);
250                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
251                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
252                       
253                        //statistics
254                        int dim = IM->dimension();
255                        vec mu0;
256                        if ( !UI::get ( mu0, set, "mu0" ) )
257                                mu0 = zeros ( dim );
258                       
259                        mat P0;
260                        vec dP0;
261                        if ( UI::get ( dP0, set, "dP0" ) )
262                                P0 = diag ( dP0 );
263                        else if ( !UI::get ( P0, set, "P0" ) )
264                                P0 = eye ( dim );
265                       
266                        set_statistics ( mu0, P0 );
267                       
268                        //parameters
269                        vec dQ, dR;
270                        UI::get ( dQ, set, "dQ", UI::compulsory );
271                        UI::get ( dR, set, "dR", UI::compulsory );
272                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
273                                               
274                        string options;
275                        if ( UI::get ( options, set, "options" ) )
276                                set_options ( options );
277//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
278//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
279//                     
280//                      mat R0;
281//                      UI::get(R0, set, "R",UI::compulsory);
282//                      mat Q0;
283//                      UI::get(Q0, set, "Q",UI::compulsory);
284//                     
285//                     
286//                      mat P0; vec mu0;
287//                      UI::get(mu0, set, "mu0", UI::optional);
288//                      UI::get(P0, set,  "P0", UI::optional);
289//                      set_statistics(mu0,P0);
290//                      // Initial values
291//                      UI::get (yrv, set, "yrv", UI::optional);
292//                      UI::get (urv, set, "urv", UI::optional);
293//                      set_drv(concat(yrv,urv));
294//
295//                      // setup StateSpace
296//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
297//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
298//                     
299                        validate();
300                }
301                void validate() {
302                        // check stats and IM and OM
303                }
304};
305UIREGISTER(EKFfull);
306
307
308/*!
309\brief Extended Kalman Filter in Square root
310
311An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
312*/
313
314class EKFCh : public KalmanCh
315{
316        protected:
317                //! Internal Model f(x,u)
318                shared_ptr<diffbifn> pfxu;
319
320                //! Observation Model h(x,u)
321                shared_ptr<diffbifn> phxu;
322        public:
323                //! copy constructor duplicated - calls different set_parameters
324                BM* _copy_() const {
325                        EKFCh* E = new EKFCh;
326                        E->set_parameters (pfxu, phxu, Q, R);
327                        E->set_statistics (est._mu(), est._R());
328                        return E;
329                }
330                //! Set nonlinear functions for mean values and covariance matrices.
331                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
332
333                //! Here dt = [yt;ut] of appropriate dimensions
334                void bayes (const vec &yt, const vec &cond=empty_vec);
335
336                void from_setting (const Setting &set);
337
338                void validate(){};
339                // TODO dodelat void to_setting( Setting &set ) const;
340
341};
342
343UIREGISTER (EKFCh);
344SHAREDPTR (EKFCh);
345
346
347//////// INstance
348
349/*! \brief (Switching) Multiple Model
350The model runs several models in parallel and evaluates thier weights (fittness).
351
352The statistics of the resulting density are merged using (geometric?) combination.
353
354The next step is performed with the new statistics for all models.
355*/
356class MultiModel: public BM
357{
358        protected:
359                //! List of models between which we switch
360                Array<EKFCh*> Models;
361                //! vector of model weights
362                vec w;
363                //! cache of model lls
364                vec _lls;
365                //! type of switching policy [1=maximum,2=...]
366                int policy;
367                //! internal statistics
368                enorm<chmat> est;
369        public:
370                //! set internal parameters
371                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
372                        Models = A;//TODO: test if evalll is set
373                        w.set_length (A.length());
374                        _lls.set_length (A.length());
375                        policy = pol0;
376
377                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
378                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
379                }
380                void bayes (const vec &yt, const vec &cond=empty_vec) {
381                        int n = Models.length();
382                        int i;
383                        for (i = 0; i < n; i++) {
384                                Models (i)->bayes (yt);
385                                _lls (i) = Models (i)->_ll();
386                        }
387                        double mlls = max (_lls);
388                        w = exp (_lls - mlls);
389                        w /= sum (w);   //normalization
390                        //set statistics
391                        switch (policy) {
392                                case 1: {
393                                        int mi = max_index (w);
394                                        const enorm<chmat> &st = Models (mi)->posterior() ;
395                                        est.set_parameters (st.mean(), st._R());
396                                }
397                                break;
398                                default:
399                                        bdm_error ("unknown policy");
400                        }
401                        // copy result to all models
402                        for (i = 0; i < n; i++) {
403                                Models (i)->set_statistics (est.mean(), est._R());
404                        }
405                }
406                //! return correctly typed posterior (covariant return)
407                const enorm<chmat>& posterior() const {
408                        return est;
409                }
410
411                void from_setting (const Setting &set);
412
413                // TODO dodelat void to_setting( Setting &set ) const;
414
415};
416
417UIREGISTER (MultiModel);
418SHAREDPTR (MultiModel);
419
420//! conversion of outer ARX model (mlnorm) to state space model
421/*!
422The model is constructed as:
423\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
424For example, for:
425Using Frobenius form, see [].
426
427For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
428*/
429//template<class sq_T>
430class StateCanonical: public StateSpace<fsqmat>{
431        protected:
432                //! remember connection from theta ->A
433                datalink_part th2A;
434                //! remember connection from theta ->C
435                datalink_part th2C;
436                //! remember connection from theta ->D
437                datalink_part th2D;
438                //!cached first row of A
439                vec A1row;
440                //!cached first row of C
441                vec C1row;
442                //!cached first row of D
443                vec D1row;
444               
445        public:
446                //! set up this object to match given mlnorm
447                void connect_mlnorm(const mlnorm<fsqmat> &ml){
448                        //get ids of yrv                               
449                        const RV &yrv = ml._rv();
450                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
451                        RV rgr0 = ml._rvc().remove_time();
452                        RV urv = rgr0.subt(yrv); 
453                       
454                        //We can do only 1d now... :(
455                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
456
457                        // create names for
458                        RV xrv; //empty
459                        RV Crv; //empty
460                        int td=ml._rvc().mint();
461                        // assuming strictly proper function!!!
462                        for (int t=-1;t>=td;t--){
463                                xrv.add(yrv.copy_t(t));
464                                Crv.add(urv.copy_t(t));
465                        }
466                                               
467                        // get mapp
468                        th2A.set_connection(xrv, ml._rvc());
469                        th2C.set_connection(Crv, ml._rvc());
470                        th2D.set_connection(urv, ml._rvc());
471
472                        //set matrix sizes
473                        this->A=zeros(xrv._dsize(),xrv._dsize());
474                        for (int j=1; j<xrv._dsize(); j++){A(j,j-1)=1.0;} // off diagonal
475                                this->B=zeros(xrv._dsize(),1);
476                                this->B(0) = 1.0;
477                                this->C=zeros(1,xrv._dsize());
478                                this->D=zeros(1,urv._dsize());
479                                this->Q = zeros(xrv._dsize(),xrv._dsize());
480                        // R is set by update
481                       
482                        //set cache
483                        this->A1row = zeros(xrv._dsize());
484                        this->C1row = zeros(xrv._dsize());
485                        this->D1row = zeros(urv._dsize());
486                       
487                        update_from(ml);
488                        validate();
489                };
490                //! fast function to update parameters from ml - not checked for compatibility!!
491                void update_from(const mlnorm<fsqmat> &ml){
492                       
493                        vec theta = ml._A().get_row(0); // this
494                       
495                        th2A.filldown(theta,A1row);
496                        th2C.filldown(theta,C1row);
497                        th2D.filldown(theta,D1row);
498
499                        R = ml._R();
500
501                        A.set_row(0,A1row);
502                        C.set_row(0,C1row+D1row*A1row);
503                        D.set_row(0,D1row);
504                       
505                }
506};
507
508/////////// INSTANTIATION
509
510template<class sq_T>
511void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
512{
513
514        A = A0;
515        B = B0;
516        C = C0;
517        D = D0;
518        R = R0;
519        Q = Q0;
520        validate();
521}
522
523template<class sq_T>
524void StateSpace<sq_T>::validate(){
525        bdm_assert (A.cols() == A.rows(), "KalmanFull: A is not square");
526        bdm_assert (B.rows() == A.rows(), "KalmanFull: B is not compatible");
527        bdm_assert (C.cols() == A.rows(), "KalmanFull: C is not compatible");
528        bdm_assert ( (D.rows() == C.rows()) && (D.cols() == B.cols()), "KalmanFull: D is not compatible");
529        bdm_assert ( (Q.cols() == A.rows()) && (Q.rows() == A.rows()), "KalmanFull: Q is not compatible");
530        bdm_assert ( (R.cols() == C.rows()) && (R.rows() == C.rows()), "KalmanFull: R is not compatible");
531}
532
533}
534#endif // KF_H
535
Note: See TracBrowser for help on using the browser.