root/library/bdm/estim/kalman.h @ 703

Revision 703, 16.9 kB (checked in by smidl, 15 years ago)

New transformation between StateSpace? and ARX

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21//#include <../applications/pmsm/simulator_zdenek/ekf_example/pmsm_mod.h>
22
23namespace bdm
24{
25
26/*!
27 * \brief Basic elements of linear state-space model
28
29Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
30Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
31Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
32 */
33template<class sq_T>
34class StateSpace
35{
36        protected:
37                //! Matrix A
38                mat A;
39                //! Matrix B
40                mat B;
41                //! Matrix C
42                mat C;
43                //! Matrix D
44                mat D;
45                //! Matrix Q in square-root form
46                sq_T Q;
47                //! Matrix R in square-root form
48                sq_T R;
49        public:
50                StateSpace() :  A(), B(), C(), D(), Q(), R() {}
51                //!copy constructor
52                StateSpace(const StateSpace<sq_T> &S0) :  A(S0.A), B(S0.B), C(S0.C), D(S0.D), Q(S0.Q), R(S0.R) {}
53                //! set all matrix parameters
54                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
55                //! validation
56                void validate();
57                //! not virtual in this case
58                void from_setting (const Setting &set) {
59                        UI::get (A, set, "A", UI::compulsory);
60                        UI::get (B, set, "B", UI::compulsory);
61                        UI::get (C, set, "C", UI::compulsory);
62                        UI::get (D, set, "D", UI::compulsory);
63                        mat Qtm, Rtm;
64                        if(!UI::get(Qtm, set, "Q", UI::optional)){
65                                vec dq;
66                                UI::get(dq, set, "dQ", UI::compulsory);
67                                Qtm=diag(dq);
68                        }
69                        if(!UI::get(Rtm, set, "R", UI::optional)){
70                                vec dr;
71                                UI::get(dr, set, "dQ", UI::compulsory);
72                                Rtm=diag(dr);
73                        }
74                        R=Rtm; // automatic conversion
75                        Q=Qtm; 
76                       
77                        validate();
78                }               
79                //! access function
80                const mat& _A() const {return A;}
81                //! access function
82                const mat& _B()const {return B;}
83                //! access function
84                const mat& _C()const {return C;}
85                //! access function
86                const mat& _D()const {return D;}
87                //! access function
88                const sq_T& _Q()const {return Q;}
89                //! access function
90                const sq_T& _R()const {return R;}
91};
92
93//! Common abstract base for Kalman filters
94template<class sq_T>
95class Kalman: public BM, public StateSpace<sq_T>
96{
97        protected:
98                //! id of output
99                RV yrv;
100                //! Kalman gain
101                mat  _K;
102                //!posterior
103                enorm<sq_T> est;
104                //!marginal on data f(y|y)
105                enorm<sq_T>  fy;
106        public:
107                Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est(){}
108                //! Copy constructor
109                Kalman<sq_T>(const Kalman<sq_T> &K0) : BM(K0), StateSpace<sq_T>(K0), yrv(K0.yrv), _K(K0._K),  est(K0.est), fy(K0.fy){}
110                //!set statistics of the posterior
111                void set_statistics (const vec &mu0, const mat &P0) {est.set_parameters (mu0, P0); };
112                //!set statistics of the posterior
113                void set_statistics (const vec &mu0, const sq_T &P0) {est.set_parameters (mu0, P0); };
114                //! return correctly typed posterior (covariant return)
115                const enorm<sq_T>& posterior() const {return est;}
116                //! load basic elements of Kalman from structure
117                void from_setting (const Setting &set) {
118                        StateSpace<sq_T>::from_setting(set);
119                                               
120                        mat P0; vec mu0;
121                        UI::get(mu0, set, "mu0", UI::optional);
122                        UI::get(P0, set,  "P0", UI::optional);
123                        set_statistics(mu0,P0);
124                        // Initial values
125                        UI::get (yrv, set, "yrv", UI::optional);
126                        UI::get (rvc, set, "urv", UI::optional);
127                        set_yrv(concat(yrv,rvc));
128                       
129                        validate();
130                }
131                //! validate object
132                void validate() {
133                        StateSpace<sq_T>::validate();
134                        dimy = this->C.rows();
135                        dimc = this->B.cols();
136                        set_dim(this->A.rows());
137                       
138                        bdm_assert(est.dimension(), "Statistics and model parameters mismatch");
139                }
140};
141/*!
142* \brief Basic Kalman filter with full matrices
143*/
144
145class KalmanFull : public Kalman<fsqmat>
146{
147        public:
148                //! For EKFfull;
149                KalmanFull() :Kalman<fsqmat>(){};
150                //! Here dt = [yt;ut] of appropriate dimensions
151                void bayes (const vec &yt, const vec &cond=empty_vec);
152                BM* _copy_() const {
153                        KalmanFull* K = new KalmanFull;
154                        K->set_parameters (A, B, C, D, Q, R);
155                        K->set_statistics (est._mu(), est._R());
156                        return K;
157                }
158};
159UIREGISTER(KalmanFull);
160
161
162/*! \brief Kalman filter in square root form
163
164Trivial example:
165\include kalman_simple.cpp
166
167Complete constructor:
168*/
169class KalmanCh : public Kalman<chmat>
170{
171        protected:
172                //! @{ \name Internal storage - needs initialize()
173                //! pre array (triangular matrix)
174                mat preA;
175                //! post array (triangular matrix)
176                mat postA;
177                //!@}
178        public:
179                //! copy constructor
180                BM* _copy_() const {
181                        KalmanCh* K = new KalmanCh;
182                        K->set_parameters (A, B, C, D, Q, R);
183                        K->set_statistics (est._mu(), est._R());
184                        K->validate();
185                        return K;
186                }
187                //! set parameters for adapt from Kalman
188                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
189                //! initialize internal parametetrs
190                void initialize();
191
192                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
193
194                The following equality hold::\f[
195                \left[\begin{array}{cc}
196                R^{0.5}\\
197                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
198                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
199                R_{y}^{0.5} & KA'\\
200                & P_{t+1|t}^{0.5}\\
201                \\\end{array}\right]\f]
202
203                Thus this object evaluates only predictors! Not filtering densities.
204                */
205                void bayes (const vec &yt, const vec &cond=empty_vec);
206
207                void from_setting(const Setting &set){
208                        Kalman<chmat>::from_setting(set);
209                        validate();
210                }
211                void validate() {
212                        Kalman<chmat>::validate();
213                        initialize();
214                }
215};
216UIREGISTER(KalmanCh);
217
218/*!
219\brief Extended Kalman Filter in full matrices
220
221An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
222*/
223class EKFfull : public KalmanFull
224{
225        protected:
226                //! Internal Model f(x,u)
227                shared_ptr<diffbifn> pfxu;
228
229                //! Observation Model h(x,u)
230                shared_ptr<diffbifn> phxu;
231
232        public:
233                //! Default constructor
234                EKFfull ();
235
236                //! Set nonlinear functions for mean values and covariance matrices.
237                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
238
239                //! Here dt = [yt;ut] of appropriate dimensions
240                void bayes (const vec &yt, const vec &cond=empty_vec);
241                //! set estimates
242                void set_statistics (const vec &mu0, const mat &P0) {
243                        est.set_parameters (mu0, P0);
244                };
245                //! access function
246                const mat _R() {
247                        return est._R().to_mat();
248                }
249                void from_setting (const Setting &set) {                       
250                        BM::from_setting(set);
251                        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
252                        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
253                       
254                        //statistics
255                        int dim = IM->dimension();
256                        vec mu0;
257                        if ( !UI::get ( mu0, set, "mu0" ) )
258                                mu0 = zeros ( dim );
259                       
260                        mat P0;
261                        vec dP0;
262                        if ( UI::get ( dP0, set, "dP0" ) )
263                                P0 = diag ( dP0 );
264                        else if ( !UI::get ( P0, set, "P0" ) )
265                                P0 = eye ( dim );
266                       
267                        set_statistics ( mu0, P0 );
268                       
269                        //parameters
270                        vec dQ, dR;
271                        UI::get ( dQ, set, "dQ", UI::compulsory );
272                        UI::get ( dR, set, "dR", UI::compulsory );
273                        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
274                                               
275                        string options;
276                        if ( UI::get ( options, set, "options" ) )
277                                set_options ( options );
278//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
279//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
280//                     
281//                      mat R0;
282//                      UI::get(R0, set, "R",UI::compulsory);
283//                      mat Q0;
284//                      UI::get(Q0, set, "Q",UI::compulsory);
285//                     
286//                     
287//                      mat P0; vec mu0;
288//                      UI::get(mu0, set, "mu0", UI::optional);
289//                      UI::get(P0, set,  "P0", UI::optional);
290//                      set_statistics(mu0,P0);
291//                      // Initial values
292//                      UI::get (yrv, set, "yrv", UI::optional);
293//                      UI::get (urv, set, "urv", UI::optional);
294//                      set_drv(concat(yrv,urv));
295//
296//                      // setup StateSpace
297//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
298//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
299//                     
300                        validate();
301                }
302                void validate() {
303                        // check stats and IM and OM
304                }
305};
306UIREGISTER(EKFfull);
307
308
309/*!
310\brief Extended Kalman Filter in Square root
311
312An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
313*/
314
315class EKFCh : public KalmanCh
316{
317        protected:
318                //! Internal Model f(x,u)
319                shared_ptr<diffbifn> pfxu;
320
321                //! Observation Model h(x,u)
322                shared_ptr<diffbifn> phxu;
323        public:
324                //! copy constructor duplicated - calls different set_parameters
325                BM* _copy_() const {
326                        EKFCh* E = new EKFCh;
327                        E->set_parameters (pfxu, phxu, Q, R);
328                        E->set_statistics (est._mu(), est._R());
329                        return E;
330                }
331                //! Set nonlinear functions for mean values and covariance matrices.
332                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
333
334                //! Here dt = [yt;ut] of appropriate dimensions
335                void bayes (const vec &yt, const vec &cond=empty_vec);
336
337                void from_setting (const Setting &set);
338
339                void validate(){};
340                // TODO dodelat void to_setting( Setting &set ) const;
341
342};
343
344UIREGISTER (EKFCh);
345SHAREDPTR (EKFCh);
346
347
348//////// INstance
349
350/*! \brief (Switching) Multiple Model
351The model runs several models in parallel and evaluates thier weights (fittness).
352
353The statistics of the resulting density are merged using (geometric?) combination.
354
355The next step is performed with the new statistics for all models.
356*/
357class MultiModel: public BM
358{
359        protected:
360                //! List of models between which we switch
361                Array<EKFCh*> Models;
362                //! vector of model weights
363                vec w;
364                //! cache of model lls
365                vec _lls;
366                //! type of switching policy [1=maximum,2=...]
367                int policy;
368                //! internal statistics
369                enorm<chmat> est;
370        public:
371                //! set internal parameters
372                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
373                        Models = A;//TODO: test if evalll is set
374                        w.set_length (A.length());
375                        _lls.set_length (A.length());
376                        policy = pol0;
377
378                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
379                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
380                }
381                void bayes (const vec &yt, const vec &cond=empty_vec) {
382                        int n = Models.length();
383                        int i;
384                        for (i = 0; i < n; i++) {
385                                Models (i)->bayes (yt);
386                                _lls (i) = Models (i)->_ll();
387                        }
388                        double mlls = max (_lls);
389                        w = exp (_lls - mlls);
390                        w /= sum (w);   //normalization
391                        //set statistics
392                        switch (policy) {
393                                case 1: {
394                                        int mi = max_index (w);
395                                        const enorm<chmat> &st = Models (mi)->posterior() ;
396                                        est.set_parameters (st.mean(), st._R());
397                                }
398                                break;
399                                default:
400                                        bdm_error ("unknown policy");
401                        }
402                        // copy result to all models
403                        for (i = 0; i < n; i++) {
404                                Models (i)->set_statistics (est.mean(), est._R());
405                        }
406                }
407                //! return correctly typed posterior (covariant return)
408                const enorm<chmat>& posterior() const {
409                        return est;
410                }
411
412                void from_setting (const Setting &set);
413
414                // TODO dodelat void to_setting( Setting &set ) const;
415
416};
417
418UIREGISTER (MultiModel);
419SHAREDPTR (MultiModel);
420
421//! conversion of outer ARX model (mlnorm) to state space model
422/*!
423The model is constructed as:
424\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
425For example, for:
426Using Frobenius form, see [].
427
428For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
429*/
430//template<class sq_T>
431class StateCanonical: public StateSpace<fsqmat>{
432        protected:
433                //! remember connection from theta ->A
434                datalink_part th2A;
435                //! remember connection from theta ->C
436                datalink_part th2C;
437                //! remember connection from theta ->D
438                datalink_part th2D;
439                //!cached first row of A
440                vec A1row;
441                //!cached first row of C
442                vec C1row;
443                //!cached first row of D
444                vec D1row;
445               
446        public:
447                //! set up this object to match given mlnorm
448                void connect_mlnorm(const mlnorm<fsqmat> &ml){
449                        //get ids of yrv                               
450                        const RV &yrv = ml._rv();
451                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
452                        RV rgr0 = ml._rvc().remove_time();
453                        RV urv = rgr0.subt(yrv); 
454                       
455                        //We can do only 1d now... :(
456                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
457
458                        // create names for
459                        RV xrv; //empty
460                        RV Crv; //empty
461                        int td=ml._rvc().mint();
462                        // assuming strictly proper function!!!
463                        for (int t=-1;t>=td;t--){
464                                xrv.add(yrv.copy_t(t));
465                                Crv.add(urv.copy_t(t));
466                        }
467                                               
468                        // get mapp
469                        th2A.set_connection(xrv, ml._rvc());
470                        th2C.set_connection(Crv, ml._rvc());
471                        th2D.set_connection(urv, ml._rvc());
472
473                        //set matrix sizes
474                        this->A=zeros(xrv._dsize(),xrv._dsize());
475                        for (int j=1; j<xrv._dsize(); j++){A(j,j-1)=1.0;} // off diagonal
476                                this->B=zeros(xrv._dsize(),1);
477                                this->B(0) = 1.0;
478                                this->C=zeros(1,xrv._dsize());
479                                this->D=zeros(1,urv._dsize());
480                                this->Q = zeros(xrv._dsize(),xrv._dsize());
481                        // R is set by update
482                       
483                        //set cache
484                        this->A1row = zeros(xrv._dsize());
485                        this->C1row = zeros(xrv._dsize());
486                        this->D1row = zeros(urv._dsize());
487                       
488                        update_from(ml);
489                        validate();
490                };
491                //! fast function to update parameters from ml - not checked for compatibility!!
492                void update_from(const mlnorm<fsqmat> &ml){
493                       
494                        vec theta = ml._A().get_row(0); // this
495                       
496                        th2A.filldown(theta,A1row);
497                        th2C.filldown(theta,C1row);
498                        th2D.filldown(theta,D1row);
499
500                        R = ml._R();
501
502                        A.set_row(0,A1row);
503                        C.set_row(0,C1row+D1row(0)*A1row);
504                        D.set_row(0,D1row);
505                       
506                }
507};
508/*!
509State-Space representation of multivariate autoregressive model.
510The original model:
511\f[ y_t = \theta [\ldots y_{t-k}, \ldots u_{t-l}, \ldots z_{t-m}]' + \Sigma^{-1/2} e_t \f]
512where \f$ k,l,m \f$ are maximum delayes of corresponding variables in the regressor.
513
514The transformed state is:
515\f[ x_t = [y_{t} \ldots y_{t-k-1}, u_{t} \ldots u_{t-l-1}, z_{t} \ldots z_{t-m-1}]\f]
516
517The state accumulates all delayed values starting from time \f$ t \f$ .
518
519
520*/
521class StateFromARX: public StateSpace<fsqmat>{
522        protected:
523                //! remember connection from theta ->A
524                datalink_part th2A;
525                //! xrv
526                RV xrv;
527                //!function adds n diagonal elements from given starting point r,c
528                void diagonal_part(mat &A, int r, int c, int n){
529                        for(int i=0; i<n; i++){A(r,c)=1.0; r++;c++;}
530                };
531        public:
532                //! set up this object to match given mlnorm
533                void connect_mlnorm(const mlnorm<fsqmat> &ml){
534                        //get ids of yrv                               
535                        const RV &yrv = ml._rv();
536                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
537                        const RV &rgr = ml._rvc();
538                        RV rgr0 = rgr.remove_time();
539                        RV urv = rgr0.subt(yrv); 
540                                               
541                        // create names for state variables
542                        xrv=yrv; 
543                       
544                        int y_multiplicity = rgr.mint(yrv);
545                        int y_block_size = yrv.length()*(y_multiplicity+1); // current yt + all delayed yts
546                        for (int m=0;m<y_multiplicity;m++){ 
547                                xrv.add(yrv.copy_t(-m-1)); //add delayed yt
548                        }
549                        // add regressors
550                        ivec u_block_sizes(urv.length()); // no_blocks = yt + unique rgr
551                        for (int r=0;r<urv.length();r++){
552                                RV R=urv.subselect(vec_1(r)); //non-delayed element of rgr
553                                int r_size=urv.size(r);
554                                int r_multiplicity=rgr.mint(R);
555                                u_block_sizes(r)= r_size*r_multiplicity ;
556                                for(int m=0;m<r_multiplicity;m++){ 
557                                        xrv.add(R.copy_t(-m-1)); //add delayed yt
558                                }
559                        }
560                       
561                        // get mapp
562                        th2A.set_connection(xrv, ml._rvc());
563                       
564                        //set matrix sizes
565                        this->A=zeros(xrv._dsize(),xrv._dsize());
566                        //create y block
567                        diagonal_part(this->A, yrv._dsize(), 0, y_block_size-yrv._dsize() );
568                       
569                        this->B=zeros(xrv._dsize(), urv._dsize());
570                        //add diagonals for rgr
571                        int active_x=y_block_size;
572                        for (int r=0; r<urv.length(); r++){
573                                diagonal_part( this->A, active_x+urv.size(r),active_x, u_block_sizes(r)-urv.size(r));
574                                this->B.set_submatrix(active_x, 0, eye(urv.size(r)));
575                                active_x+=u_block_sizes(r);
576                        }
577                        this->C=zeros(yrv._dsize(), xrv._dsize());
578                        this->C.set_submatrix(0,0,eye(yrv._dsize()));
579                                this->D=zeros(yrv._dsize(),urv._dsize());
580                                this->R = zeros(yrv._dsize(),yrv._dsize());
581                                this->Q = zeros(xrv._dsize(),xrv._dsize());
582                                // Q is set by update
583                                                               
584                                update_from(ml);
585                                validate();
586                };
587                //! fast function to update parameters from ml - not checked for compatibility!!
588                void update_from(const mlnorm<fsqmat> &ml){
589                       
590                        vec theta = ml._A().get_row(0); // this
591                        this->Q._M().set_submatrix(0,0,ml._R());
592                       
593                }
594};
595
596/////////// INSTANTIATION
597
598template<class sq_T>
599void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
600{
601
602        A = A0;
603        B = B0;
604        C = C0;
605        D = D0;
606        R = R0;
607        Q = Q0;
608        validate();
609}
610
611template<class sq_T>
612void StateSpace<sq_T>::validate(){
613        bdm_assert (A.cols() == A.rows(), "KalmanFull: A is not square");
614        bdm_assert (B.rows() == A.rows(), "KalmanFull: B is not compatible");
615        bdm_assert (C.cols() == A.rows(), "KalmanFull: C is not compatible");
616        bdm_assert ( (D.rows() == C.rows()) && (D.cols() == B.cols()), "KalmanFull: D is not compatible");
617        bdm_assert ( (Q.cols() == A.rows()) && (Q.rows() == A.rows()), "KalmanFull: Q is not compatible");
618        bdm_assert ( (R.cols() == C.rows()) && (R.rows() == C.rows()), "KalmanFull: R is not compatible");
619}
620
621}
622#endif // KF_H
623
Note: See TracBrowser for help on using the browser.