root/library/bdm/estim/kalman.h @ 780

Revision 766, 14.5 kB (checked in by mido, 14 years ago)

abstract methods restored wherever they are meaningful
macros NOT_IMPLEMENTED and NOT_IMPLEMENTED_VOID defined to make sources shorter
emix::set_parameters and mmix::set_parameters removed, corresponding acces methods created and the corresponding validate methods improved appropriately
some compilator warnings were avoided
and also a few other things cleaned up

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21//#include <../applications/pmsm/simulator_zdenek/ekf_example/pmsm_mod.h>
22
23namespace bdm {
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_{t+1} = A x_{t} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t} + C u_t + R^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are mutually independent vectors of Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace {
34protected:
35        //! Matrix A
36        mat A;
37        //! Matrix B
38        mat B;
39        //! Matrix C
40        mat C;
41        //! Matrix D
42        mat D;
43        //! Matrix Q in square-root form
44        sq_T Q;
45        //! Matrix R in square-root form
46        sq_T R;
47public:
48        StateSpace() :  A(), B(), C(), D(), Q(), R() {}
49        //!copy constructor
50        StateSpace ( const StateSpace<sq_T> &S0 ) :  A ( S0.A ), B ( S0.B ), C ( S0.C ), D ( S0.D ), Q ( S0.Q ), R ( S0.R ) {}
51        //! set all matrix parameters
52        void set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 );
53        //! validation
54        void validate();
55        //! not virtual in this case
56        void from_setting ( const Setting &set ) {
57                UI::get ( A, set, "A", UI::compulsory );
58                UI::get ( B, set, "B", UI::compulsory );
59                UI::get ( C, set, "C", UI::compulsory );
60                UI::get ( D, set, "D", UI::compulsory );
61                mat Qtm, Rtm; // full matrices
62                if ( !UI::get ( Qtm, set, "Q", UI::optional ) ) {
63                        vec dq;
64                        UI::get ( dq, set, "dQ", UI::compulsory );
65                        Qtm = diag ( dq );
66                }
67                if ( !UI::get ( Rtm, set, "R", UI::optional ) ) {
68                        vec dr;
69                        UI::get ( dr, set, "dQ", UI::compulsory );
70                        Rtm = diag ( dr );
71                }
72                R = Rtm; // automatic conversion to square-root form
73                Q = Qtm;
74
75                validate();
76        }
77        //! access function
78        const mat& _A() const {
79                return A;
80        }
81        //! access function
82        const mat& _B() const {
83                return B;
84        }
85        //! access function
86        const mat& _C() const {
87                return C;
88        }
89        //! access function
90        const mat& _D() const {
91                return D;
92        }
93        //! access function
94        const sq_T& _Q() const {
95                return Q;
96        }
97        //! access function
98        const sq_T& _R() const {
99                return R;
100        }
101};
102
103//! Common abstract base for Kalman filters
104template<class sq_T>
105class Kalman: public BM, public StateSpace<sq_T> {
106protected:
107        //! id of output
108        RV yrv;
109        //! Kalman gain
110        mat  _K;
111        //!posterior
112        enorm<sq_T> est;
113        //!marginal on data f(y|y)
114        enorm<sq_T>  fy;
115public:
116        Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est() {}
117        //! Copy constructor
118        Kalman<sq_T> ( const Kalman<sq_T> &K0 ) : BM ( K0 ), StateSpace<sq_T> ( K0 ), yrv ( K0.yrv ), _K ( K0._K ),  est ( K0.est ), fy ( K0.fy ) {}
119        //!set statistics of the posterior
120        void set_statistics ( const vec &mu0, const mat &P0 ) {
121                est.set_parameters ( mu0, P0 );
122        };
123        //!set statistics of the posterior
124        void set_statistics ( const vec &mu0, const sq_T &P0 ) {
125                est.set_parameters ( mu0, P0 );
126        };
127        //! return correctly typed posterior (covariant return)
128        const enorm<sq_T>& posterior() const {
129                return est;
130        }
131        //! load basic elements of Kalman from structure
132        void from_setting ( const Setting &set ) {
133                StateSpace<sq_T>::from_setting ( set );
134
135                mat P0;
136                vec mu0;
137                UI::get ( mu0, set, "mu0", UI::optional );
138                UI::get ( P0, set,  "P0", UI::optional );
139                set_statistics ( mu0, P0 );
140                // Initial values
141                UI::get ( yrv, set, "yrv", UI::optional );
142                UI::get ( rvc, set, "urv", UI::optional );
143                set_yrv ( concat ( yrv, rvc ) );
144
145                validate();
146        }
147        //! validate object
148        void validate() {
149                StateSpace<sq_T>::validate();
150                dimy = this->C.rows();
151                dimc = this->B.cols();
152                set_dim ( this->A.rows() );
153
154                bdm_assert ( est.dimension(), "Statistics and model parameters mismatch" );
155        }
156
157        virtual double logpred ( const vec &yt ) const NOT_IMPLEMENTED(0);
158
159        virtual epdf* epredictor() const NOT_IMPLEMENTED(NULL);
160       
161        virtual pdf* predictor() const NOT_IMPLEMENTED(NULL);
162};
163/*!
164* \brief Basic Kalman filter with full matrices
165*/
166
167class KalmanFull : public Kalman<fsqmat> {
168public:
169        //! For EKFfull;
170        KalmanFull() : Kalman<fsqmat>() {};
171        //! Here dt = [yt;ut] of appropriate dimensions
172        void bayes ( const vec &yt, const vec &cond = empty_vec );
173
174        virtual KalmanFull* _copy() const {
175                KalmanFull* K = new KalmanFull;
176                K->set_parameters ( A, B, C, D, Q, R );
177                K->set_statistics ( est._mu(), est._R() );
178                return K;
179        }
180};
181UIREGISTER ( KalmanFull );
182
183
184/*! \brief Kalman filter in square root form
185
186Trivial example:
187\include kalman_simple.cpp
188
189Complete constructor:
190*/
191class KalmanCh : public Kalman<chmat> {
192protected:
193        //! @{ \name Internal storage - needs initialize()
194        //! pre array (triangular matrix)
195        mat preA;
196        //! post array (triangular matrix)
197        mat postA;
198        //!@}
199public:
200        //! copy constructor
201        virtual KalmanCh* _copy() const {
202                KalmanCh* K = new KalmanCh;
203                K->set_parameters ( A, B, C, D, Q, R );
204                K->set_statistics ( est._mu(), est._R() );
205                K->validate();
206                return K;
207        }
208        //! set parameters for adapt from Kalman
209        void set_parameters ( const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0 );
210        //! initialize internal parametetrs
211        void initialize();
212
213        /*!\brief  Here dt = [yt;ut] of appropriate dimensions
214
215        The following equality hold::\f[
216        \left[\begin{array}{cc}
217        R^{0.5}\\
218        P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
219        & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
220        R_{y}^{0.5} & KA'\\
221        & P_{t+1|t}^{0.5}\\
222        \\\end{array}\right]\f]
223
224        Thus this object evaluates only predictors! Not filtering densities.
225        */
226        void bayes ( const vec &yt, const vec &cond = empty_vec );
227
228        void from_setting ( const Setting &set ) {
229                Kalman<chmat>::from_setting ( set );
230                validate();
231        }
232        void validate() {
233                Kalman<chmat>::validate();
234                initialize();
235        }
236};
237UIREGISTER ( KalmanCh );
238
239/*!
240\brief Extended Kalman Filter in full matrices
241
242An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
243*/
244class EKFfull : public KalmanFull {
245protected:
246        //! Internal Model f(x,u)
247        shared_ptr<diffbifn> pfxu;
248
249        //! Observation Model h(x,u)
250        shared_ptr<diffbifn> phxu;
251
252public:
253        //! Default constructor
254        EKFfull ();
255
256        //! Set nonlinear functions for mean values and covariance matrices.
257        void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0 );
258
259        //! Here dt = [yt;ut] of appropriate dimensions
260        void bayes ( const vec &yt, const vec &cond = empty_vec );
261        //! set estimates
262        void set_statistics ( const vec &mu0, const mat &P0 ) {
263                est.set_parameters ( mu0, P0 );
264        };
265        //! access function
266        const mat _R() {
267                return est._R().to_mat();
268        }
269        void from_setting ( const Setting &set ) {
270                BM::from_setting ( set );
271                shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
272                shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
273
274                //statistics
275                int dim = IM->dimension();
276                vec mu0;
277                if ( !UI::get ( mu0, set, "mu0" ) )
278                        mu0 = zeros ( dim );
279
280                mat P0;
281                vec dP0;
282                if ( UI::get ( dP0, set, "dP0" ) )
283                        P0 = diag ( dP0 );
284                else if ( !UI::get ( P0, set, "P0" ) )
285                        P0 = eye ( dim );
286
287                set_statistics ( mu0, P0 );
288
289                //parameters
290                vec dQ, dR;
291                UI::get ( dQ, set, "dQ", UI::compulsory );
292                UI::get ( dR, set, "dR", UI::compulsory );
293                set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
294
295                string options;
296                if ( UI::get ( options, set, "options" ) )
297                        set_options ( options );
298//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
299//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
300//
301//                      mat R0;
302//                      UI::get(R0, set, "R",UI::compulsory);
303//                      mat Q0;
304//                      UI::get(Q0, set, "Q",UI::compulsory);
305//
306//
307//                      mat P0; vec mu0;
308//                      UI::get(mu0, set, "mu0", UI::optional);
309//                      UI::get(P0, set,  "P0", UI::optional);
310//                      set_statistics(mu0,P0);
311//                      // Initial values
312//                      UI::get (yrv, set, "yrv", UI::optional);
313//                      UI::get (urv, set, "urv", UI::optional);
314//                      set_drv(concat(yrv,urv));
315//
316//                      // setup StateSpace
317//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
318//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
319//
320                validate();
321        }
322        void validate() {
323                // check stats and IM and OM
324        }
325};
326UIREGISTER ( EKFfull );
327
328
329/*!
330\brief Extended Kalman Filter in Square root
331
332An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
333*/
334
335class EKFCh : public KalmanCh {
336protected:
337        //! Internal Model f(x,u)
338        shared_ptr<diffbifn> pfxu;
339
340        //! Observation Model h(x,u)
341        shared_ptr<diffbifn> phxu;
342public:
343        //! copy constructor duplicated - calls different set_parameters
344        EKFCh* _copy() const {
345                return new EKFCh(*this);
346        }
347        //! Set nonlinear functions for mean values and covariance matrices.
348        void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0 );
349
350        //! Here dt = [yt;ut] of appropriate dimensions
351        void bayes ( const vec &yt, const vec &cond = empty_vec );
352
353        void from_setting ( const Setting &set );
354
355        void validate() {};
356        // TODO dodelat void to_setting( Setting &set ) const;
357
358};
359
360UIREGISTER ( EKFCh );
361SHAREDPTR ( EKFCh );
362
363
364//////// INstance
365
366/*! \brief (Switching) Multiple Model
367The model runs several models in parallel and evaluates thier weights (fittness).
368
369The statistics of the resulting density are merged using (geometric?) combination.
370
371The next step is performed with the new statistics for all models.
372*/
373class MultiModel: public BM {
374protected:
375        //! List of models between which we switch
376        Array<EKFCh*> Models;
377        //! vector of model weights
378        vec w;
379        //! cache of model lls
380        vec _lls;
381        //! type of switching policy [1=maximum,2=...]
382        int policy;
383        //! internal statistics
384        enorm<chmat> est;
385public:
386        //! set internal parameters
387        void set_parameters ( Array<EKFCh*> A, int pol0 = 1 ) {
388                Models = A;//TODO: test if evalll is set
389                w.set_length ( A.length() );
390                _lls.set_length ( A.length() );
391                policy = pol0;
392
393                est.set_rv ( RV ( "MM", A ( 0 )->posterior().dimension(), 0 ) );
394                est.set_parameters ( A ( 0 )->posterior().mean(), A ( 0 )->posterior()._R() );
395        }
396        void bayes ( const vec &yt, const vec &cond = empty_vec ) {
397                int n = Models.length();
398                int i;
399                for ( i = 0; i < n; i++ ) {
400                        Models ( i )->bayes ( yt );
401                        _lls ( i ) = Models ( i )->_ll();
402                }
403                double mlls = max ( _lls );
404                w = exp ( _lls - mlls );
405                w /= sum ( w ); //normalization
406                //set statistics
407                switch ( policy ) {
408                case 1: {
409                        int mi = max_index ( w );
410                        const enorm<chmat> &st = Models ( mi )->posterior() ;
411                        est.set_parameters ( st.mean(), st._R() );
412                }
413                break;
414                default:
415                        bdm_error ( "unknown policy" );
416                }
417                // copy result to all models
418                for ( i = 0; i < n; i++ ) {
419                        Models ( i )->set_statistics ( est.mean(), est._R() );
420                }
421        }
422        //! return correctly typed posterior (covariant return)
423        const enorm<chmat>& posterior() const {
424                return est;
425        }
426
427        void from_setting ( const Setting &set );
428
429        // TODO dodelat void to_setting( Setting &set ) const;
430
431        virtual double logpred ( const vec &yt ) const NOT_IMPLEMENTED(0); 
432
433        virtual epdf* epredictor() const NOT_IMPLEMENTED(NULL); 
434
435        virtual pdf* predictor() const NOT_IMPLEMENTED(NULL); 
436};
437
438UIREGISTER ( MultiModel );
439SHAREDPTR ( MultiModel );
440
441//! conversion of outer ARX model (mlnorm) to state space model
442/*!
443The model is constructed as:
444\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
445For example, for:
446Using Frobenius form, see [].
447
448For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
449*/
450//template<class sq_T>
451class StateCanonical: public StateSpace<fsqmat> {
452protected:
453        //! remember connection from theta ->A
454        datalink_part th2A;
455        //! remember connection from theta ->C
456        datalink_part th2C;
457        //! remember connection from theta ->D
458        datalink_part th2D;
459        //!cached first row of A
460        vec A1row;
461        //!cached first row of C
462        vec C1row;
463        //!cached first row of D
464        vec D1row;
465
466public:
467        //! set up this object to match given mlnorm
468        void connect_mlnorm ( const mlnorm<fsqmat> &ml );
469
470        //! fast function to update parameters from ml - not checked for compatibility!!
471        void update_from ( const mlnorm<fsqmat> &ml );
472};
473/*!
474State-Space representation of multivariate autoregressive model.
475The original model:
476\f[ y_t = \theta [\ldots y_{t-k}, \ldots u_{t-l}, \ldots z_{t-m}]' + \Sigma^{-1/2} e_t \f]
477where \f$ k,l,m \f$ are maximum delayes of corresponding variables in the regressor.
478
479The transformed state is:
480\f[ x_t = [y_{t} \ldots y_{t-k-1}, u_{t} \ldots u_{t-l-1}, z_{t} \ldots z_{t-m-1}]\f]
481
482The state accumulates all delayed values starting from time \f$ t \f$ .
483
484
485*/
486class StateFromARX: public StateSpace<chmat> {
487protected:
488        //! remember connection from theta ->A
489        datalink_part th2A;
490        //! remember connection from theta ->B
491        datalink_part th2B;
492        //!function adds n diagonal elements from given starting point r,c
493        void diagonal_part ( mat &A, int r, int c, int n ) {
494                for ( int i = 0; i < n; i++ ) {
495                        A ( r, c ) = 1.0;
496                        r++;
497                        c++;
498                }
499        };
500        //! similar to ARX.have_constant
501        bool have_constant;
502public:
503        //! set up this object to match given mlnorm
504        //! Note that state-space and common mpdf use different meaning of \f$ _t \f$ in \f$ u_t \f$.
505        //!While mlnorm typically assumes that \f$ u_t \rightarrow y_t \f$ in state space it is \f$ u_{t-1} \rightarrow y_t \f$
506        //! For consequences in notation of internal variable xt see arx2statespace_notes.lyx.
507        void connect_mlnorm ( const mlnorm<chmat> &ml, RV &xrv, RV &urv );
508
509        //! fast function to update parameters from ml - not checked for compatibility!!
510        void update_from ( const mlnorm<chmat> &ml );
511
512        //! access function
513        bool _have_constant() const {
514                return have_constant;
515        }
516};
517
518/////////// INSTANTIATION
519
520template<class sq_T>
521void StateSpace<sq_T>::set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 ) {
522
523        A = A0;
524        B = B0;
525        C = C0;
526        D = D0;
527        R = R0;
528        Q = Q0;
529        validate();
530}
531
532template<class sq_T>
533void StateSpace<sq_T>::validate() {
534        bdm_assert ( A.cols() == A.rows(), "KalmanFull: A is not square" );
535        bdm_assert ( B.rows() == A.rows(), "KalmanFull: B is not compatible" );
536        bdm_assert ( C.cols() == A.rows(), "KalmanFull: C is not compatible" );
537        bdm_assert ( ( D.rows() == C.rows() ) && ( D.cols() == B.cols() ), "KalmanFull: D is not compatible" );
538        bdm_assert ( ( Q.cols() == A.rows() ) && ( Q.rows() == A.rows() ), "KalmanFull: Q is not compatible" );
539        bdm_assert ( ( R.cols() == C.rows() ) && ( R.rows() == C.rows() ), "KalmanFull: R is not compatible" );
540}
541
542}
543#endif // KF_H
544
Note: See TracBrowser for help on using the browser.