root/library/bdm/estim/kalman.h @ 963

Revision 963, 14.8 kB (checked in by smidl, 14 years ago)

Kalman loading corrections

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21//#include <../applications/pmsm/simulator_zdenek/ekf_example/pmsm_mod.h>
22
23namespace bdm {
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_{t+1} = A x_{t} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t} + C u_t + R^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are mutually independent vectors of Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace {
34protected:
35        //! Matrix A
36        mat A;
37        //! Matrix B
38        mat B;
39        //! Matrix C
40        mat C;
41        //! Matrix D
42        mat D;
43        //! Matrix Q in square-root form
44        sq_T Q;
45        //! Matrix R in square-root form
46        sq_T R;
47public:
48        StateSpace() :  A(), B(), C(), D(), Q(), R() {}
49        //!copy constructor
50        StateSpace ( const StateSpace<sq_T> &S0 ) :  A ( S0.A ), B ( S0.B ), C ( S0.C ), D ( S0.D ), Q ( S0.Q ), R ( S0.R ) {}
51        //! set all matrix parameters
52        void set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 );
53        //! validation
54        void validate();
55        //! not virtual in this case
56        void from_setting ( const Setting &set ) {
57                UI::get ( A, set, "A", UI::compulsory );
58                UI::get ( B, set, "B", UI::compulsory );
59                UI::get ( C, set, "C", UI::compulsory );
60                UI::get ( D, set, "D", UI::compulsory );
61                mat Qtm, Rtm; // full matrices
62                if ( !UI::get ( Qtm, set, "Q", UI::optional ) ) {
63                        vec dq;
64                        UI::get ( dq, set, "dQ", UI::compulsory );
65                        Qtm = diag ( dq );
66                }
67                if ( !UI::get ( Rtm, set, "R", UI::optional ) ) {
68                        vec dr;
69                        UI::get ( dr, set, "dQ", UI::compulsory );
70                        Rtm = diag ( dr );
71                }
72                R = Rtm; // automatic conversion to square-root form
73                Q = Qtm;
74
75                validate();
76        }
77        void to_setting ( Setting &set ) const {
78                UI::save( A, set, "A" );
79                UI::save( B, set, "B" );
80                UI::save( C, set, "C" );
81                UI::save( D, set, "D" );
82                UI::save( Q.to_mat(), set, "Q" );
83                UI::save( R.to_mat(), set, "R" );
84        }
85        //! access function
86        const mat& _A() const {
87                return A;
88        }
89        //! access function
90        const mat& _B() const {
91                return B;
92        }
93        //! access function
94        const mat& _C() const {
95                return C;
96        }
97        //! access function
98        const mat& _D() const {
99                return D;
100        }
101        //! access function
102        const sq_T& _Q() const {
103                return Q;
104        }
105        //! access function
106        const sq_T& _R() const {
107                return R;
108        }
109};
110
111//! Common abstract base for Kalman filters
112template<class sq_T>
113class Kalman: public BM, public StateSpace<sq_T> {
114protected:
115        //! id of output
116        RV yrv;
117        //! Kalman gain
118        mat  _K;
119        //!posterior
120        enorm<sq_T> est;
121        //!marginal on data f(y|y)
122        enorm<sq_T>  fy;
123public:
124        Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est() {}
125        //! Copy constructor
126        Kalman<sq_T> ( const Kalman<sq_T> &K0 ) : BM ( K0 ), StateSpace<sq_T> ( K0 ), yrv ( K0.yrv ), _K ( K0._K ),  est ( K0.est ), fy ( K0.fy ) {}
127        //!set statistics of the posterior
128        void set_statistics ( const vec &mu0, const mat &P0 ) {
129                est.set_parameters ( mu0, P0 );
130        };
131        //!set statistics of the posterior
132        void set_statistics ( const vec &mu0, const sq_T &P0 ) {
133                est.set_parameters ( mu0, P0 );
134        };
135        //! return correctly typed posterior (covariant return)
136        const enorm<sq_T>& posterior() const {
137                return est;
138        }
139        //! load basic elements of Kalman from structure
140        /*! \code
141            class = 'KalmanFull';
142                A     = [];                   // Matrix A
143                B     = [];                   // Matrix B
144                C     = [];                   // Matrix C
145                D     = [];                   // Matrix D
146                Q     = [];                   // Matrix Q
147                R     = [];                   // Matrix R
148                prior = struct('class','epdf_offspring');    // Prior density - will be converted to gaussian
149                rvy   = RV('some_names');     // Description of required observations
150                rvc   = RV('some_names');     // Description of required inputs
151                \endcode
152               
153                */
154        void from_setting ( const Setting &set ) {
155                StateSpace<sq_T>::from_setting ( set );
156                BM::from_setting(set);
157               
158                shared_ptr<epdf> pri=UI::build<epdf>(set,"prior",UI::compulsory);
159                //bdm_assert(pri->dimension()==);
160                set_statistics ( pri->mean(), pri->covariance() );
161        }
162        void to_setting ( Setting &set ) const {
163                StateSpace<sq_T>::to_setting ( set );
164                BM::to_setting(set);
165
166                UI::save(est, set, "prior");
167        }
168        //! validate object
169        void validate() {
170                StateSpace<sq_T>::validate();
171                dimy = this->C.rows();
172                dimc = this->B.cols();
173                set_dim ( this->A.rows() );
174
175                bdm_assert ( est.dimension(), "Statistics and model parameters mismatch" );
176        }
177
178};
179/*!
180* \brief Basic Kalman filter with full matrices
181*/
182
183class KalmanFull : public Kalman<fsqmat> {
184public:
185        //! For EKFfull;
186        KalmanFull() : Kalman<fsqmat>() {};
187        //! Here dt = [yt;ut] of appropriate dimensions
188        void bayes ( const vec &yt, const vec &cond = empty_vec );
189
190        virtual KalmanFull* _copy() const {
191                KalmanFull* K = new KalmanFull;
192                K->set_parameters ( A, B, C, D, Q, R );
193                K->set_statistics ( est._mu(), est._R() );
194                return K;
195        }
196};
197UIREGISTER ( KalmanFull );
198
199
200/*! \brief Kalman filter in square root form
201
202Trivial example:
203\include kalman_simple.cpp
204
205Complete constructor:
206*/
207class KalmanCh : public Kalman<chmat> {
208protected:
209        //! @{ \name Internal storage - needs initialize()
210        //! pre array (triangular matrix)
211        mat preA;
212        //! post array (triangular matrix)
213        mat postA;
214        //!@}
215public:
216        //! copy constructor
217        virtual KalmanCh* _copy() const {
218                KalmanCh* K = new KalmanCh;
219                K->set_parameters ( A, B, C, D, Q, R );
220                K->set_statistics ( est._mu(), est._R() );
221                K->validate();
222                return K;
223        }
224        //! set parameters for adapt from Kalman
225        void set_parameters ( const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0 );
226        //! initialize internal parametetrs
227        void initialize();
228
229        /*!\brief  Here dt = [yt;ut] of appropriate dimensions
230
231        The following equality hold::\f[
232        \left[\begin{array}{cc}
233        R^{0.5}\\
234        P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
235        & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
236        R_{y}^{0.5} & KA'\\
237        & P_{t+1|t}^{0.5}\\
238        \\\end{array}\right]\f]
239
240        Thus this object evaluates only predictors! Not filtering densities.
241        */
242        void bayes ( const vec &yt, const vec &cond = empty_vec );
243
244        void from_setting ( const Setting &set ) {
245                Kalman<chmat>::from_setting ( set );
246                validate();
247        }
248        void validate() {
249                Kalman<chmat>::validate();
250                initialize();
251        }
252};
253UIREGISTER ( KalmanCh );
254
255/*!
256\brief Extended Kalman Filter in full matrices
257
258An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
259*/
260class EKFfull : public KalmanFull {
261protected:
262        //! Internal Model f(x,u)
263        shared_ptr<diffbifn> pfxu;
264
265        //! Observation Model h(x,u)
266        shared_ptr<diffbifn> phxu;
267
268public:
269        //! Default constructor
270        EKFfull ();
271
272        //! Set nonlinear functions for mean values and covariance matrices.
273        void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0 );
274
275        //! Here dt = [yt;ut] of appropriate dimensions
276        void bayes ( const vec &yt, const vec &cond = empty_vec );
277        //! set estimates
278        void set_statistics ( const vec &mu0, const mat &P0 ) {
279                est.set_parameters ( mu0, P0 );
280        };
281        //! access function
282        const mat _R() {
283                return est._R().to_mat();
284        }
285        void from_setting ( const Setting &set ) {
286                BM::from_setting ( set );
287                shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
288                shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
289
290                //statistics
291                int dim = IM->dimension();
292                vec mu0;
293                if ( !UI::get ( mu0, set, "mu0" ) )
294                        mu0 = zeros ( dim );
295
296                mat P0;
297                vec dP0;
298                if ( UI::get ( dP0, set, "dP0" ) )
299                        P0 = diag ( dP0 );
300                else if ( !UI::get ( P0, set, "P0" ) )
301                        P0 = eye ( dim );
302
303                set_statistics ( mu0, P0 );
304
305                //parameters
306                vec dQ, dR;
307                UI::get ( dQ, set, "dQ", UI::compulsory );
308                UI::get ( dR, set, "dR", UI::compulsory );
309                set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
310
311//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
312//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
313//
314//                      mat R0;
315//                      UI::get(R0, set, "R",UI::compulsory);
316//                      mat Q0;
317//                      UI::get(Q0, set, "Q",UI::compulsory);
318//
319//
320//                      mat P0; vec mu0;
321//                      UI::get(mu0, set, "mu0", UI::optional);
322//                      UI::get(P0, set,  "P0", UI::optional);
323//                      set_statistics(mu0,P0);
324//                      // Initial values
325//                      UI::get (yrv, set, "yrv", UI::optional);
326//                      UI::get (urv, set, "urv", UI::optional);
327//                      set_drv(concat(yrv,urv));
328//
329//                      // setup StateSpace
330//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
331//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
332//
333        }
334
335        void validate() {
336                KalmanFull::validate();
337
338                // check stats and IM and OM
339        }
340};
341UIREGISTER ( EKFfull );
342
343
344/*!
345\brief Extended Kalman Filter in Square root
346
347An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
348*/
349
350class EKFCh : public KalmanCh {
351protected:
352        //! Internal Model f(x,u)
353        shared_ptr<diffbifn> pfxu;
354
355        //! Observation Model h(x,u)
356        shared_ptr<diffbifn> phxu;
357public:
358        //! copy constructor duplicated - calls different set_parameters
359        EKFCh* _copy() const {
360                return new EKFCh(*this);
361        }
362        //! Set nonlinear functions for mean values and covariance matrices.
363        void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0 );
364
365        //! Here dt = [yt;ut] of appropriate dimensions
366        void bayes ( const vec &yt, const vec &cond = empty_vec );
367
368        void from_setting ( const Setting &set );
369
370        void validate() {};
371        // TODO dodelat void to_setting( Setting &set ) const;
372
373};
374
375UIREGISTER ( EKFCh );
376SHAREDPTR ( EKFCh );
377
378
379//////// INstance
380
381/*! \brief (Switching) Multiple Model
382The model runs several models in parallel and evaluates thier weights (fittness).
383
384The statistics of the resulting density are merged using (geometric?) combination.
385
386The next step is performed with the new statistics for all models.
387*/
388class MultiModel: public BM {
389protected:
390        //! List of models between which we switch
391        Array<EKFCh*> Models;
392        //! vector of model weights
393        vec w;
394        //! cache of model lls
395        vec _lls;
396        //! type of switching policy [1=maximum,2=...]
397        int policy;
398        //! internal statistics
399        enorm<chmat> est;
400public:
401        //! set internal parameters
402        void set_parameters ( Array<EKFCh*> A, int pol0 = 1 ) {
403                Models = A;//TODO: test if evalll is set
404                w.set_length ( A.length() );
405                _lls.set_length ( A.length() );
406                policy = pol0;
407
408                est.set_rv ( RV ( "MM", A ( 0 )->posterior().dimension(), 0 ) );
409                est.set_parameters ( A ( 0 )->posterior().mean(), A ( 0 )->posterior()._R() );
410        }
411        void bayes ( const vec &yt, const vec &cond = empty_vec ) {
412                int n = Models.length();
413                int i;
414                for ( i = 0; i < n; i++ ) {
415                        Models ( i )->bayes ( yt );
416                        _lls ( i ) = Models ( i )->_ll();
417                }
418                double mlls = max ( _lls );
419                w = exp ( _lls - mlls );
420                w /= sum ( w ); //normalization
421                //set statistics
422                switch ( policy ) {
423                case 1: {
424                        int mi = max_index ( w );
425                        const enorm<chmat> &st = Models ( mi )->posterior() ;
426                        est.set_parameters ( st.mean(), st._R() );
427                }
428                break;
429                default:
430                        bdm_error ( "unknown policy" );
431                }
432                // copy result to all models
433                for ( i = 0; i < n; i++ ) {
434                        Models ( i )->set_statistics ( est.mean(), est._R() );
435                }
436        }
437        //! return correctly typed posterior (covariant return)
438        const enorm<chmat>& posterior() const {
439                return est;
440        }
441
442        void from_setting ( const Setting &set );
443
444};
445UIREGISTER ( MultiModel );
446SHAREDPTR ( MultiModel );
447
448//! conversion of outer ARX model (mlnorm) to state space model
449/*!
450The model is constructed as:
451\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
452For example, for:
453Using Frobenius form, see [].
454
455For easier use in the future, indices theta_in_A and theta_in_C are set. TODO - explain
456*/
457//template<class sq_T>
458class StateCanonical: public StateSpace<fsqmat> {
459protected:
460        //! remember connection from theta ->A
461        datalink_part th2A;
462        //! remember connection from theta ->C
463        datalink_part th2C;
464        //! remember connection from theta ->D
465        datalink_part th2D;
466        //!cached first row of A
467        vec A1row;
468        //!cached first row of C
469        vec C1row;
470        //!cached first row of D
471        vec D1row;
472
473public:
474        //! set up this object to match given mlnorm
475        void connect_mlnorm ( const mlnorm<fsqmat> &ml );
476
477        //! fast function to update parameters from ml - not checked for compatibility!!
478        void update_from ( const mlnorm<fsqmat> &ml );
479};
480/*!
481State-Space representation of multivariate autoregressive model.
482The original model:
483\f[ y_t = \theta [\ldots y_{t-k}, \ldots u_{t-l}, \ldots z_{t-m}]' + \Sigma^{-1/2} e_t \f]
484where \f$ k,l,m \f$ are maximum delayes of corresponding variables in the regressor.
485
486The transformed state is:
487\f[ x_t = [y_{t} \ldots y_{t-k-1}, u_{t} \ldots u_{t-l-1}, z_{t} \ldots z_{t-m-1}]\f]
488
489The state accumulates all delayed values starting from time \f$ t \f$ .
490
491
492*/
493class StateFromARX: public StateSpace<chmat> {
494protected:
495        //! remember connection from theta ->A
496        datalink_part th2A;
497        //! remember connection from theta ->B
498        datalink_part th2B;
499        //!function adds n diagonal elements from given starting point r,c
500        void diagonal_part ( mat &A, int r, int c, int n ) {
501                for ( int i = 0; i < n; i++ ) {
502                        A ( r, c ) = 1.0;
503                        r++;
504                        c++;
505                }
506        };
507        //! similar to ARX.have_constant
508        bool have_constant;
509public:
510        //! set up this object to match given mlnorm
511        //! Note that state-space and common mpdf use different meaning of \f$ _t \f$ in \f$ u_t \f$.
512        //!While mlnorm typically assumes that \f$ u_t \rightarrow y_t \f$ in state space it is \f$ u_{t-1} \rightarrow y_t \f$
513        //! For consequences in notation of internal variable xt see arx2statespace_notes.lyx.
514        void connect_mlnorm ( const mlnorm<chmat> &ml, RV &xrv, RV &urv );
515
516        //! fast function to update parameters from ml - not checked for compatibility!!
517        void update_from ( const mlnorm<chmat> &ml );
518
519        //! access function
520        bool _have_constant() const {
521                return have_constant;
522        }
523};
524
525/////////// INSTANTIATION
526
527template<class sq_T>
528void StateSpace<sq_T>::set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 ) {
529
530        A = A0;
531        B = B0;
532        C = C0;
533        D = D0;
534        R = R0;
535        Q = Q0;
536        validate();
537}
538
539template<class sq_T>
540void StateSpace<sq_T>::validate() {
541        bdm_assert ( A.cols() == A.rows(), "KalmanFull: A is not square" );
542        bdm_assert ( B.rows() == A.rows(), "KalmanFull: B is not compatible" );
543        bdm_assert ( C.cols() == A.rows(), "KalmanFull: C is not compatible" );
544        bdm_assert ( ( D.rows() == C.rows() ) && ( D.cols() == B.cols() ), "KalmanFull: D is not compatible" );
545        bdm_assert ( ( Q.cols() == A.rows() ) && ( Q.rows() == A.rows() ), "KalmanFull: Q is not compatible" );
546        bdm_assert ( ( R.cols() == C.rows() ) && ( R.rows() == C.rows() ), "KalmanFull: R is not compatible" );
547}
548
549}
550#endif // KF_H
551
Note: See TracBrowser for help on using the browser.