root/library/bdm/estim/kalman.h @ 625

Revision 625, 11.9 kB (checked in by smidl, 15 years ago)

ARX re-designed

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21
22namespace bdm
23{
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace
34{
35        protected:
36                //! cache of rv.count()
37                int dimx;
38                //! cache of rvy.count()
39                int dimy;
40                //! cache of rvu.count()
41                int dimu;
42                //! Matrix A
43                mat A;
44                //! Matrix B
45                mat B;
46                //! Matrix C
47                mat C;
48                //! Matrix D
49                mat D;
50                //! Matrix Q in square-root form
51                sq_T Q;
52                //! Matrix R in square-root form
53                sq_T R;
54        public:
55                StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
56                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
57                void validate();
58                //! not virtual in this case
59                void from_setting (const Setting &set) {
60                        UI::get (A, set, "A", UI::compulsory);
61                        UI::get (B, set, "B", UI::compulsory);
62                        UI::get (C, set, "C", UI::compulsory);
63                        UI::get (D, set, "D", UI::compulsory);
64                        mat Qtm, Rtm;
65                        if(!UI::get(Qtm, set, "Q", UI::optional)){
66                                vec dq;
67                                UI::get(dq, set, "dQ", UI::compulsory);
68                                Qtm=diag(dq);
69                        }
70                        if(!UI::get(Rtm, set, "R", UI::optional)){
71                                vec dr;
72                                UI::get(dr, set, "dQ", UI::compulsory);
73                                Rtm=diag(dr);
74                        }
75                        R=Rtm; // automatic conversion
76                        Q=Qtm; 
77                       
78                        validate();
79                }               
80                //! access function
81                int _dimx(){return dimx;}
82                //! access function
83                int _dimy(){return dimy;}
84                //! access function
85                int _dimu(){return dimu;}
86                //! access function
87                const mat& _A() const {return A;}
88                //! access function
89                const mat& _B()const {return B;}
90                //! access function
91                const mat& _C()const {return C;}
92                //! access function
93                const mat& _D()const {return D;}
94                //! access function
95                const sq_T& _Q()const {return Q;}
96                //! access function
97                const sq_T& _R()const {return R;}
98};
99
100//! Common abstract base for Kalman filters
101template<class sq_T>
102class Kalman: public BM, public StateSpace<sq_T>
103{
104        protected:
105                //! id of output
106                RV yrv;
107                //! id of input
108                RV urv;
109                //! Kalman gain
110                mat  _K;
111                //!posterior
112                shared_ptr<enorm<sq_T> > est;
113                //!marginal on data f(y|y)
114                enorm<sq_T>  fy;
115        public:
116                Kalman() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
117                void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
118                void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
119                //! posterior
120                const enorm<sq_T>& posterior() const {return *est.get();}
121                //! shared posterior
122                shared_ptr<epdf> shared_posterior() {return est;}
123                //! load basic elements of Kalman from structure
124                void from_setting (const Setting &set) {
125                        StateSpace<sq_T>::from_setting(set);
126                                               
127                        mat P0; vec mu0;
128                        UI::get(mu0, set, "mu0", UI::optional);
129                        UI::get(P0, set,  "P0", UI::optional);
130                        set_statistics(mu0,P0);
131                        // Initial values
132                        UI::get (yrv, set, "yrv", UI::optional);
133                        UI::get (urv, set, "urv", UI::optional);
134                        set_drv(concat(yrv,urv));
135                       
136                        validate();
137                }
138                void validate() {
139                        StateSpace<sq_T>::validate();
140                        bdm_assert(est->dimension(), "Statistics and model parameters mismatch");
141                }
142};
143/*!
144* \brief Basic Kalman filter with full matrices
145*/
146
147class KalmanFull : public Kalman<fsqmat>
148{
149        public:
150                //! For EKFfull;
151                KalmanFull() :Kalman<fsqmat>(){};
152                //! Here dt = [yt;ut] of appropriate dimensions
153                void bayes (const vec &dt);
154};
155UIREGISTER(KalmanFull);
156
157
158/*! \brief Kalman filter in square root form
159
160Trivial example:
161\include kalman_simple.cpp
162
163Complete constructor:
164*/
165class KalmanCh : public Kalman<chmat>
166{
167        protected:
168                //! @{ \name Internal storage - needs initialize()
169                //! pre array (triangular matrix)
170                mat preA;
171                //! post array (triangular matrix)
172                mat postA;
173                //!@}
174        public:
175                //! copy constructor
176                BM* _copy_() const {
177                        KalmanCh* K = new KalmanCh;
178                        K->set_parameters (A, B, C, D, Q, R);
179                        K->set_statistics (est->_mu(), est->_R());
180                        return K;
181                }
182                //! set parameters for adapt from Kalman
183                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
184                //! initialize internal parametetrs
185                void initialize();
186
187                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
188
189                The following equality hold::\f[
190                \left[\begin{array}{cc}
191                R^{0.5}\\
192                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
193                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
194                R_{y}^{0.5} & KA'\\
195                & P_{t+1|t}^{0.5}\\
196                \\\end{array}\right]\f]
197
198                Thus this object evaluates only predictors! Not filtering densities.
199                */
200                void bayes (const vec &dt);
201               
202                void from_setting(const Setting &set){
203                        Kalman<chmat>::from_setting(set);
204                        initialize();
205                }
206};
207UIREGISTER(KalmanCh);
208
209/*!
210\brief Extended Kalman Filter in full matrices
211
212An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
213*/
214class EKFfull : public KalmanFull
215{
216        protected:
217                //! Internal Model f(x,u)
218                shared_ptr<diffbifn> pfxu;
219
220                //! Observation Model h(x,u)
221                shared_ptr<diffbifn> phxu;
222
223        public:
224                //! Default constructor
225                EKFfull ();
226
227                //! Set nonlinear functions for mean values and covariance matrices.
228                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
229
230                //! Here dt = [yt;ut] of appropriate dimensions
231                void bayes (const vec &dt);
232                //! set estimates
233                void set_statistics (const vec &mu0, const mat &P0) {
234                        est->set_parameters (mu0, P0);
235                };
236                const mat _R() {
237                        return est->_R().to_mat();
238                }
239};
240UIREGISTER(EKFfull);
241
242
243/*!
244\brief Extended Kalman Filter in Square root
245
246An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
247*/
248
249class EKFCh : public KalmanCh
250{
251        protected:
252                //! Internal Model f(x,u)
253                shared_ptr<diffbifn> pfxu;
254
255                //! Observation Model h(x,u)
256                shared_ptr<diffbifn> phxu;
257        public:
258                //! copy constructor duplicated - calls different set_parameters
259                BM* _copy_() const {
260                        EKFCh* E = new EKFCh;
261                        E->set_parameters (pfxu, phxu, Q, R);
262                        E->set_statistics (est->_mu(), est->_R());
263                        return E;
264                }
265                //! Set nonlinear functions for mean values and covariance matrices.
266                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
267
268                //! Here dt = [yt;ut] of appropriate dimensions
269                void bayes (const vec &dt);
270
271                void from_setting (const Setting &set);
272
273                // TODO dodelat void to_setting( Setting &set ) const;
274
275};
276
277UIREGISTER (EKFCh);
278SHAREDPTR (EKFCh);
279
280
281//////// INstance
282
283/*! \brief (Switching) Multiple Model
284The model runs several models in parallel and evaluates thier weights (fittness).
285
286The statistics of the resulting density are merged using (geometric?) combination.
287
288The next step is performed with the new statistics for all models.
289*/
290class MultiModel: public BM
291{
292        protected:
293                //! List of models between which we switch
294                Array<EKFCh*> Models;
295                //! vector of model weights
296                vec w;
297                //! cache of model lls
298                vec _lls;
299                //! type of switching policy [1=maximum,2=...]
300                int policy;
301                //! internal statistics
302                enorm<chmat> est;
303        public:
304                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
305                        Models = A;//TODO: test if evalll is set
306                        w.set_length (A.length());
307                        _lls.set_length (A.length());
308                        policy = pol0;
309
310                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
311                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
312                }
313                void bayes (const vec &dt) {
314                        int n = Models.length();
315                        int i;
316                        for (i = 0; i < n; i++) {
317                                Models (i)->bayes (dt);
318                                _lls (i) = Models (i)->_ll();
319                        }
320                        double mlls = max (_lls);
321                        w = exp (_lls - mlls);
322                        w /= sum (w);   //normalization
323                        //set statistics
324                        switch (policy) {
325                                case 1: {
326                                        int mi = max_index (w);
327                                        const enorm<chmat> &st = Models (mi)->posterior() ;
328                                        est.set_parameters (st.mean(), st._R());
329                                }
330                                break;
331                                default:
332                                        bdm_error ("unknown policy");
333                        }
334                        // copy result to all models
335                        for (i = 0; i < n; i++) {
336                                Models (i)->set_statistics (est.mean(), est._R());
337                        }
338                }
339                //! posterior density
340                const enorm<chmat>& posterior() const {
341                        return est;
342                }
343
344                void from_setting (const Setting &set);
345
346                // TODO dodelat void to_setting( Setting &set ) const;
347
348};
349
350UIREGISTER (MultiModel);
351SHAREDPTR (MultiModel);
352
353//! conversion of outer ARX model (mlnorm) to state space model
354/*!
355The model is constructed as:
356\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
357For example, for:
358Using Frobenius form, see [].
359
360For easier use in the future, indeces theta_in_A and theta_in_C are set. TODO - explain
361*/
362//template<class sq_T>
363class StateCanonical: public StateSpace<fsqmat>{
364        protected:
365                //! remember connection from theta ->A
366                datalink_part th2A;
367                //! remember connection from theta ->C
368                datalink_part th2C;
369                //! remember connection from theta ->D
370                datalink_part th2D;
371                //!cached first row of A
372                vec A1row;
373                //!cached first row of C
374                vec C1row;
375                //!cached first row of D
376                vec D1row;
377               
378        public:
379                //! set up this object to match given mlnorm
380                void connect_mlnorm(const mlnorm<fsqmat> &ml){
381                        //get ids of yrv                               
382                        const RV &yrv = ml._rv();
383                        //need to determine u_t - it is all in _rvc that is not in ml._rv()
384                        RV rgr0 = ml._rvc().remove_time();
385                        RV urv = rgr0.subt(yrv); 
386                       
387                        //We can do only 1d now... :(
388                        bdm_assert(yrv._dsize()==1, "Only for SISO so far..." );
389
390                        // create names for
391                        RV xrv; //empty
392                        RV Crv; //empty
393                        int td=ml._rvc().mint();
394                        // assuming strictly proper function!!!
395                        for (int t=-1;t>=td;t--){
396                                xrv.add(yrv.copy_t(t));
397                                Crv.add(urv.copy_t(t));
398                        }
399                       
400                        this->dimx = xrv._dsize();
401                        this->dimy = yrv._dsize();
402                        this->dimu = urv._dsize();
403                       
404                        // get mapp
405                        th2A.set_connection(xrv, ml._rvc());
406                        th2C.set_connection(Crv, ml._rvc());
407                        th2D.set_connection(urv, ml._rvc());
408
409                        //set matrix sizes
410                        this->A=zeros(dimx,dimx);
411                        for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} // off diagonal
412                                this->B=zeros(dimx,1);
413                                this->B(0) = 1.0;
414                                this->C=zeros(1,dimx);
415                                this->D=zeros(1,urv._dsize());
416                                this->Q = zeros(dimx,dimx);
417                        // R is set by update
418                       
419                        //set cache
420                        this->A1row = zeros(xrv._dsize());
421                        this->C1row = zeros(xrv._dsize());
422                        this->D1row = zeros(urv._dsize());
423                       
424                        update_from(ml);
425                        validate();
426                };
427                //! fast function to update parameters from ml - not checked for compatibility!!
428                void update_from(const mlnorm<fsqmat> &ml){
429                       
430                        vec theta = ml._A().get_row(0); // this
431                       
432                        th2A.filldown(theta,A1row);
433                        th2C.filldown(theta,C1row);
434                        th2D.filldown(theta,D1row);
435
436                        R = ml._R();
437
438                        A.set_row(0,A1row);
439                        C.set_row(0,C1row+D1row*A1row);
440                        D.set_row(0,D1row);
441                       
442                }
443};
444
445/////////// INSTANTIATION
446
447template<class sq_T>
448void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
449{
450        dimx = A0.rows();
451        dimu = B0.cols();
452        dimy = C0.rows();
453
454        A = A0;
455        B = B0;
456        C = C0;
457        D = D0;
458        R = R0;
459        Q = Q0;
460        validate();
461}
462
463template<class sq_T>
464void StateSpace<sq_T>::validate(){
465        bdm_assert (A.cols() == dimx, "KalmanFull: A is not square");
466        bdm_assert (B.rows() == dimx, "KalmanFull: B is not compatible");
467        bdm_assert (C.cols() == dimx, "KalmanFull: C is not square");
468        bdm_assert ( (D.rows() == dimy) || (D.cols() == dimu), "KalmanFull: D is not compatible");
469        bdm_assert ( (Q.cols() == dimx) || (Q.rows() == dimx), "KalmanFull: Q is not compatible");
470        bdm_assert ( (R.cols() == dimy) || (R.rows() == dimy), "KalmanFull: R is not compatible");
471}
472
473}
474#endif // KF_H
475
Note: See TracBrowser for help on using the browser.