root/library/bdm/estim/kalman.h @ 1064

Revision 1064, 16.2 kB (checked in by mido, 14 years ago)

astyle applied all over the library

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21//#include <../applications/pmsm/simulator_zdenek/ekf_example/pmsm_mod.h>
22
23namespace bdm {
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_{t+1} = A x_{t} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t} + C u_t + R^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are mutually independent vectors of Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace {
34protected:
35    //! Matrix A
36    mat A;
37    //! Matrix B
38    mat B;
39    //! Matrix C
40    mat C;
41    //! Matrix D
42    mat D;
43    //! Matrix Q in square-root form
44    sq_T Q;
45    //! Matrix R in square-root form
46    sq_T R;
47public:
48    StateSpace() :  A(), B(), C(), D(), Q(), R() {}
49    //!copy constructor
50    StateSpace ( const StateSpace<sq_T> &S0 ) :  A ( S0.A ), B ( S0.B ), C ( S0.C ), D ( S0.D ), Q ( S0.Q ), R ( S0.R ) {}
51    //! set all matrix parameters
52    void set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 );
53    //! validation
54    void validate();
55    //! not virtual in this case
56    void from_setting ( const Setting &set ) {
57        UI::get ( A, set, "A", UI::compulsory );
58        UI::get ( B, set, "B", UI::compulsory );
59        UI::get ( C, set, "C", UI::compulsory );
60        UI::get ( D, set, "D", UI::compulsory );
61        mat Qtm, Rtm; // full matrices
62        if ( !UI::get ( Qtm, set, "Q", UI::optional ) ) {
63            vec dq;
64            UI::get ( dq, set, "dQ", UI::compulsory );
65            Qtm = diag ( dq );
66        }
67        if ( !UI::get ( Rtm, set, "R", UI::optional ) ) {
68            vec dr;
69            UI::get ( dr, set, "dQ", UI::compulsory );
70            Rtm = diag ( dr );
71        }
72        R = Rtm; // automatic conversion to square-root form
73        Q = Qtm;
74
75        validate();
76    }
77    void to_setting ( Setting &set ) const {
78        UI::save( A, set, "A" );
79        UI::save( B, set, "B" );
80        UI::save( C, set, "C" );
81        UI::save( D, set, "D" );
82        UI::save( Q.to_mat(), set, "Q" );
83        UI::save( R.to_mat(), set, "R" );
84    }
85    //! access function
86    const mat& _A() const {
87        return A;
88    }
89    //! access function
90    const mat& _B() const {
91        return B;
92    }
93    //! access function
94    const mat& _C() const {
95        return C;
96    }
97    //! access function
98    const mat& _D() const {
99        return D;
100    }
101    //! access function
102    const sq_T& _Q() const {
103        return Q;
104    }
105    //! access function
106    const sq_T& _R() const {
107        return R;
108    }
109};
110
111//! Common abstract base for Kalman filters
112template<class sq_T>
113class Kalman: public BM, public StateSpace<sq_T> {
114protected:
115    //! id of output
116    RV yrv;
117    //! Kalman gain
118    mat  _K;
119    //!posterior
120    enorm<sq_T> est;
121    //!marginal on data f(y|y)
122    enorm<sq_T>  fy;
123public:
124    Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est() {}
125    //! Copy constructor
126    Kalman<sq_T> ( const Kalman<sq_T> &K0 ) : BM ( K0 ), StateSpace<sq_T> ( K0 ), yrv ( K0.yrv ), _K ( K0._K ),  est ( K0.est ), fy ( K0.fy ) {}
127    //!set statistics of the posterior
128    void set_statistics ( const vec &mu0, const mat &P0 ) {
129        est.set_parameters ( mu0, P0 );
130    };
131    //!set statistics of the posterior
132    void set_statistics ( const vec &mu0, const sq_T &P0 ) {
133        est.set_parameters ( mu0, P0 );
134    };
135    //! return correctly typed posterior (covariant return)
136    const enorm<sq_T>& posterior() const {
137        return est;
138    }
139    //! load basic elements of Kalman from structure
140    /*! \code
141        class = 'KalmanFull';
142        A     = [];                   // Matrix A
143        B     = [];                   // Matrix B
144        C     = [];                   // Matrix C
145        D     = [];                   // Matrix D
146        Q     = [];                   // Matrix Q
147        R     = [];                   // Matrix R
148        prior = struct('class','epdf_offspring');    // Prior density - will be converted to gaussian
149        yrv   = RV('some_names');     // Description of required observations
150        rvc   = RV('some_names');     // Description of required inputs
151        \endcode
152
153        */
154    void from_setting ( const Setting &set ) {
155        StateSpace<sq_T>::from_setting ( set );
156        BM::from_setting(set);
157
158        shared_ptr<epdf> pri=UI::build<epdf>(set,"prior",UI::compulsory);
159        //bdm_assert(pri->dimension()==);
160        set_statistics ( pri->mean(), pri->covariance() );
161    }
162    void to_setting ( Setting &set ) const {
163        StateSpace<sq_T>::to_setting ( set );
164        BM::to_setting(set);
165
166        UI::save(est, set, "prior");
167    }
168    //! validate object
169    void validate() {
170        StateSpace<sq_T>::validate();
171        dimy = this->C.rows();
172        dimc = this->B.cols();
173        set_dim ( this->A.rows() );
174
175        bdm_assert ( est.dimension(), "Statistics and model parameters mismatch" );
176    }
177
178};
179/*!
180* \brief Basic Kalman filter with full matrices
181*/
182
183class KalmanFull : public Kalman<fsqmat> {
184public:
185    //! For EKFfull;
186    KalmanFull() : Kalman<fsqmat>() {};
187    //! Here dt = [yt;ut] of appropriate dimensions
188    void bayes ( const vec &yt, const vec &cond = empty_vec );
189
190    virtual KalmanFull* _copy() const {
191        KalmanFull* K = new KalmanFull;
192        K->set_parameters ( A, B, C, D, Q, R );
193        K->set_statistics ( est._mu(), est._R() );
194        return K;
195    }
196};
197UIREGISTER ( KalmanFull );
198
199
200/*! \brief Kalman filter in square root form
201
202Trivial example:
203\include kalman_simple.cpp
204
205Complete constructor:
206*/
207class KalmanCh : public Kalman<chmat> {
208protected:
209    //! @{ \name Internal storage - needs initialize()
210    //! pre array (triangular matrix)
211    mat preA;
212    //! post array (triangular matrix)
213    mat postA;
214    //!@}
215public:
216    //! copy constructor
217    virtual KalmanCh* _copy() const {
218        KalmanCh* K = new KalmanCh;
219        K->set_parameters ( A, B, C, D, Q, R );
220        K->set_statistics ( est._mu(), est._R() );
221        K->validate();
222        return K;
223    }
224    //! set parameters for adapt from Kalman
225    void set_parameters ( const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0 );
226    //! initialize internal parametetrs
227    void initialize();
228
229    /*!\brief  Here dt = [yt;ut] of appropriate dimensions
230
231    The following equality hold::\f[
232    \left[\begin{array}{cc}
233    R^{0.5}\\
234    P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
235    & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
236    R_{y}^{0.5} & KA'\\
237    & P_{t+1|t}^{0.5}\\
238    \\\end{array}\right]\f]
239
240    Thus this object evaluates only predictors! Not filtering densities.
241    */
242    void bayes ( const vec &yt, const vec &cond = empty_vec );
243
244    void from_setting ( const Setting &set ) {
245        Kalman<chmat>::from_setting ( set );
246        validate();
247    }
248    void validate() {
249        Kalman<chmat>::validate();
250        initialize();
251    }
252};
253UIREGISTER ( KalmanCh );
254
255/*!
256\brief Extended Kalman Filter in full matrices
257
258An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
259*/
260class EKFfull : public KalmanFull {
261protected:
262    //! Internal Model f(x,u)
263    shared_ptr<diffbifn> pfxu;
264
265    //! Observation Model h(x,u)
266    shared_ptr<diffbifn> phxu;
267
268public:
269    //! Default constructor
270    EKFfull ();
271
272    //! Set nonlinear functions for mean values and covariance matrices.
273    void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0 );
274
275    //! Here dt = [yt;ut] of appropriate dimensions
276    void bayes ( const vec &yt, const vec &cond = empty_vec );
277    //! set estimates
278    void set_statistics ( const vec &mu0, const mat &P0 ) {
279        est.set_parameters ( mu0, P0 );
280    };
281    //! access function
282    const mat _R() {
283        return est._R().to_mat();
284    }
285    void from_setting ( const Setting &set ) {
286        BM::from_setting ( set );
287        shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
288        shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );
289
290        //statistics
291        int dim = IM->dimension();
292        vec mu0;
293        if ( !UI::get ( mu0, set, "mu0" ) )
294            mu0 = zeros ( dim );
295
296        mat P0;
297        vec dP0;
298        if ( UI::get ( dP0, set, "dP0" ) )
299            P0 = diag ( dP0 );
300        else if ( !UI::get ( P0, set, "P0" ) )
301            P0 = eye ( dim );
302
303        set_statistics ( mu0, P0 );
304
305        //parameters
306        vec dQ, dR;
307        UI::get ( dQ, set, "dQ", UI::compulsory );
308        UI::get ( dR, set, "dR", UI::compulsory );
309        set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );
310
311//                      pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
312//                      phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
313//
314//                      mat R0;
315//                      UI::get(R0, set, "R",UI::compulsory);
316//                      mat Q0;
317//                      UI::get(Q0, set, "Q",UI::compulsory);
318//
319//
320//                      mat P0; vec mu0;
321//                      UI::get(mu0, set, "mu0", UI::optional);
322//                      UI::get(P0, set,  "P0", UI::optional);
323//                      set_statistics(mu0,P0);
324//                      // Initial values
325//                      UI::get (yrv, set, "yrv", UI::optional);
326//                      UI::get (urv, set, "urv", UI::optional);
327//                      set_drv(concat(yrv,urv));
328//
329//                      // setup StateSpace
330//                      pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
331//                      phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
332//
333    }
334
335    void validate() {
336        KalmanFull::validate();
337
338        // check stats and IM and OM
339    }
340};
341UIREGISTER ( EKFfull );
342
343
344/*!
345\brief Extended Kalman Filter in Square root
346
347An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
348*/
349
350class EKFCh : public KalmanCh {
351protected:
352    //! Internal Model f(x,u)
353    shared_ptr<diffbifn> pfxu;
354
355    //! Observation Model h(x,u)
356    shared_ptr<diffbifn> phxu;
357public:
358    //! copy constructor duplicated - calls different set_parameters
359    EKFCh* _copy() const {
360        return new EKFCh(*this);
361    }
362    //! Set nonlinear functions for mean values and covariance matrices.
363    void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0 );
364
365    //! Here dt = [yt;ut] of appropriate dimensions
366    void bayes ( const vec &yt, const vec &cond = empty_vec );
367
368    void from_setting ( const Setting &set );
369
370    void validate() {};
371    // TODO dodelat void to_setting( Setting &set ) const;
372
373};
374
375UIREGISTER ( EKFCh );
376SHAREDPTR ( EKFCh );
377
378
379//////// INstance
380
381/*! \brief (Switching) Multiple Model
382The model runs several models in parallel and evaluates thier weights (fittness).
383
384The statistics of the resulting density are merged using (geometric?) combination.
385
386The next step is performed with the new statistics for all models.
387*/
388class MultiModel: public BM {
389protected:
390    //! List of models between which we switch
391    Array<EKFCh*> Models;
392    //! vector of model weights
393    vec w;
394    //! cache of model lls
395    vec _lls;
396    //! type of switching policy [1=maximum,2=...]
397    int policy;
398    //! internal statistics
399    enorm<chmat> est;
400public:
401    //! set internal parameters
402    void set_parameters ( Array<EKFCh*> A, int pol0 = 1 ) {
403        Models = A;//TODO: test if evalll is set
404        w.set_length ( A.length() );
405        _lls.set_length ( A.length() );
406        policy = pol0;
407
408        est.set_rv ( RV ( "MM", A ( 0 )->posterior().dimension(), 0 ) );
409        est.set_parameters ( A ( 0 )->posterior().mean(), A ( 0 )->posterior()._R() );
410    }
411    void bayes ( const vec &yt, const vec &cond = empty_vec ) {
412        int n = Models.length();
413        int i;
414        for ( i = 0; i < n; i++ ) {
415            Models ( i )->bayes ( yt );
416            _lls ( i ) = Models ( i )->_ll();
417        }
418        double mlls = max ( _lls );
419        w = exp ( _lls - mlls );
420        w /= sum ( w ); //normalization
421        //set statistics
422        switch ( policy ) {
423        case 1: {
424            int mi = max_index ( w );
425            const enorm<chmat> &st = Models ( mi )->posterior() ;
426            est.set_parameters ( st.mean(), st._R() );
427        }
428        break;
429        default:
430            bdm_error ( "unknown policy" );
431        }
432        // copy result to all models
433        for ( i = 0; i < n; i++ ) {
434            Models ( i )->set_statistics ( est.mean(), est._R() );
435        }
436    }
437    //! return correctly typed posterior (covariant return)
438    const enorm<chmat>& posterior() const {
439        return est;
440    }
441
442    void from_setting ( const Setting &set );
443
444};
445UIREGISTER ( MultiModel );
446SHAREDPTR ( MultiModel );
447
448//! conversion of outer ARX model (mlnorm) to state space model
449/*!
450The model is constructed as:
451\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
452For example, for:
453Using Frobenius form, see [].
454
455For easier use in the future, indices theta_in_A and theta_in_C are set. TODO - explain
456*/
457//template<class sq_T>
458class StateCanonical: public StateSpace<fsqmat> {
459protected:
460    //! remember connection from theta ->A
461    datalink_part th2A;
462    //! remember connection from theta ->C
463    datalink_part th2C;
464    //! remember connection from theta ->D
465    datalink_part th2D;
466    //!cached first row of A
467    vec A1row;
468    //!cached first row of C
469    vec C1row;
470    //!cached first row of D
471    vec D1row;
472
473public:
474    //! set up this object to match given mlnorm
475    void connect_mlnorm ( const mlnorm<fsqmat> &ml );
476
477    //! fast function to update parameters from ml - not checked for compatibility!!
478    void update_from ( const mlnorm<fsqmat> &ml );
479};
480/*!
481State-Space representation of multivariate autoregressive model.
482The original model:
483\f[ y_t = \theta [\ldots y_{t-k}, \ldots u_{t-l}, \ldots z_{t-m}]' + \Sigma^{-1/2} e_t \f]
484where \f$ k,l,m \f$ are maximum delayes of corresponding variables in the regressor.
485
486The transformed state is:
487\f[ x_t = [y_{t} \ldots y_{t-k-1}, u_{t} \ldots u_{t-l-1}, z_{t} \ldots z_{t-m-1}]\f]
488
489The state accumulates all delayed values starting from time \f$ t \f$ .
490
491
492*/
493class StateFromARX: public StateSpace<chmat> {
494protected:
495    //! remember connection from theta ->A
496    datalink_part th2A;
497    //! remember connection from theta ->B
498    datalink_part th2B;
499    //!function adds n diagonal elements from given starting point r,c
500    void diagonal_part ( mat &A, int r, int c, int n ) {
501        for ( int i = 0; i < n; i++ ) {
502            A ( r, c ) = 1.0;
503            r++;
504            c++;
505        }
506    };
507    //! similar to ARX.have_constant
508    bool have_constant;
509public:
510    //! set up this object to match given mlnorm
511    //! Note that state-space and common mpdf use different meaning of \f$ _t \f$ in \f$ u_t \f$.
512    //!While mlnorm typically assumes that \f$ u_t \rightarrow y_t \f$ in state space it is \f$ u_{t-1} \rightarrow y_t \f$
513    //! For consequences in notation of internal variable xt see arx2statespace_notes.lyx.
514    void connect_mlnorm ( const mlnorm<chmat> &ml, RV &xrv, RV &urv );
515
516    //! fast function to update parameters from ml - not checked for compatibility!!
517    void update_from ( const mlnorm<chmat> &ml );
518
519    //! access function
520    bool _have_constant() const {
521        return have_constant;
522    }
523};
524
525/////////// INSTANTIATION
526
527template<class sq_T>
528void StateSpace<sq_T>::set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 ) {
529
530    A = A0;
531    B = B0;
532    C = C0;
533    D = D0;
534    R = R0;
535    Q = Q0;
536    validate();
537}
538
539template<class sq_T>
540void StateSpace<sq_T>::validate() {
541    bdm_assert ( A.cols() == A.rows(), "KalmanFull: A is not square" );
542    bdm_assert ( B.rows() == A.rows(), "KalmanFull: B is not compatible" );
543    bdm_assert ( C.cols() == A.rows(), "KalmanFull: C is not compatible" );
544    bdm_assert ( ( D.rows() == C.rows() ) && ( D.cols() == B.cols() ), "KalmanFull: D is not compatible" );
545    bdm_assert ( ( Q.cols() == A.rows() ) && ( Q.rows() == A.rows() ), "KalmanFull: Q is not compatible" );
546    bdm_assert ( ( R.cols() == C.rows() ) && ( R.rows() == C.rows() ), "KalmanFull: R is not compatible" );
547}
548
549}
550#endif // KF_H
551
Note: See TracBrowser for help on using the browser.