root/library/bdm/estim/kalman.h @ 586

Revision 586, 11.1 kB (checked in by smidl, 15 years ago)

redesign of ctrl LQ control for arx

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef KF_H
14#define KF_H
15
16
17#include "../math/functions.h"
18#include "../stat/exp_family.h"
19#include "../math/chmat.h"
20#include "../base/user_info.h"
21
22namespace bdm
23{
24
25/*!
26 * \brief Basic elements of linear state-space model
27
28Parameter evolution model:\f[ x_t = A x_{t-1} + B u_t + Q^{1/2} e_t \f]
29Observation model: \f[ y_t = C x_{t-1} + C u_t + Q^{1/2} w_t. \f]
30Where $e_t$ and $w_t$ are independent vectors Normal(0,1)-distributed disturbances.
31 */
32template<class sq_T>
33class StateSpace
34{
35        protected:
36                //! cache of rv.count()
37                int dimx;
38                //! cache of rvy.count()
39                int dimy;
40                //! cache of rvu.count()
41                int dimu;
42                //! Matrix A
43                mat A;
44                //! Matrix B
45                mat B;
46                //! Matrix C
47                mat C;
48                //! Matrix D
49                mat D;
50                //! Matrix Q in square-root form
51                sq_T Q;
52                //! Matrix R in square-root form
53                sq_T R;
54        public:
55                StateSpace() : dimx (0), dimy (0), dimu (0), A(), B(), C(), D(), Q(), R() {}
56                void set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0);
57                void validate();
58                //! not virtual in this case
59                void from_setting (const Setting &set) {
60                        UI::get (A, set, "A", UI::compulsory);
61                        UI::get (B, set, "B", UI::compulsory);
62                        UI::get (C, set, "C", UI::compulsory);
63                        UI::get (D, set, "D", UI::compulsory);
64                        mat Qtm, Rtm;
65                        if(!UI::get(Qtm, set, "Q", UI::optional)){
66                                vec dq;
67                                UI::get(dq, set, "dQ", UI::compulsory);
68                                Qtm=diag(dq);
69                        }
70                        if(!UI::get(Rtm, set, "R", UI::optional)){
71                                vec dr;
72                                UI::get(dr, set, "dQ", UI::compulsory);
73                                Rtm=diag(dr);
74                        }
75                        R=Rtm; // automatic conversion
76                        Q=Qtm; 
77                       
78                        validate();
79                }               
80                //! access function
81                int _dimx(){return dimx;}
82                //! access function
83                int _dimy(){return dimy;}
84                //! access function
85                int _dimu(){return dimu;}
86                //! access function
87                mat _A(){return A;}
88                //! access function
89                mat _B(){return B;}
90                //! access function
91                mat _C(){return C;}
92                //! access function
93                mat _D(){return D;}
94                //! access function
95                sq_T _Q(){return Q;}
96                //! access function
97                sq_T _R(){return R;}
98};
99
100//! Common abstract base for Kalman filters
101template<class sq_T>
102class Kalman: public BM, public StateSpace<sq_T>
103{
104        protected:
105                //! id of output
106                RV yrv;
107                //! id of input
108                RV urv;
109                //! Kalman gain
110                mat  _K;
111                //!posterior
112                shared_ptr<enorm<sq_T> > est;
113                //!marginal on data f(y|y)
114                enorm<sq_T>  fy;
115        public:
116                Kalman() : BM(), StateSpace<sq_T>(), yrv(),urv(), _K(),  est(new enorm<sq_T>){}
117                void set_statistics (const vec &mu0, const mat &P0) {est->set_parameters (mu0, P0); };
118                void set_statistics (const vec &mu0, const sq_T &P0) {est->set_parameters (mu0, P0); };
119                //! posterior
120                const enorm<sq_T>& posterior() const {return *est.get();}
121                //! shared posterior
122                shared_ptr<epdf> shared_posterior() {return est;}
123                //! load basic elements of Kalman from structure
124                void from_setting (const Setting &set) {
125                        StateSpace<sq_T>::from_setting(set);
126                                               
127                        mat P0; vec mu0;
128                        UI::get(mu0, set, "mu0", UI::optional);
129                        UI::get(P0, set,  "P0", UI::optional);
130                        set_statistics(mu0,P0);
131                        // Initial values
132                        UI::get (yrv, set, "yrv", UI::optional);
133                        UI::get (urv, set, "urv", UI::optional);
134                        set_drv(concat(yrv,urv));
135                       
136                        validate();
137                }
138                void validate() {
139                        StateSpace<sq_T>::validate();
140                        bdm_assert_debug(est->dimension(), "Statistics and model parameters mismatch");
141                }
142};
143/*!
144* \brief Basic Kalman filter with full matrices
145*/
146
147class KalmanFull : public Kalman<fsqmat>
148{
149        public:
150                //! For EKFfull;
151                KalmanFull() :Kalman<fsqmat>(){};
152                //! Here dt = [yt;ut] of appropriate dimensions
153                void bayes (const vec &dt);
154};
155UIREGISTER(KalmanFull);
156
157
158/*! \brief Kalman filter in square root form
159
160Trivial example:
161\include kalman_simple.cpp
162
163Complete constructor:
164*/
165class KalmanCh : public Kalman<chmat>
166{
167        protected:
168                //! @{ \name Internal storage - needs initialize()
169                //! pre array (triangular matrix)
170                mat preA;
171                //! post array (triangular matrix)
172                mat postA;
173                //!@}
174        public:
175                //! copy constructor
176                BM* _copy_() const {
177                        KalmanCh* K = new KalmanCh;
178                        K->set_parameters (A, B, C, D, Q, R);
179                        K->set_statistics (est->_mu(), est->_R());
180                        return K;
181                }
182                //! set parameters for adapt from Kalman
183                void set_parameters (const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0);
184                //! initialize internal parametetrs
185                void initialize();
186
187                /*!\brief  Here dt = [yt;ut] of appropriate dimensions
188
189                The following equality hold::\f[
190                \left[\begin{array}{cc}
191                R^{0.5}\\
192                P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
193                & Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
194                R_{y}^{0.5} & KA'\\
195                & P_{t+1|t}^{0.5}\\
196                \\\end{array}\right]\f]
197
198                Thus this object evaluates only predictors! Not filtering densities.
199                */
200                void bayes (const vec &dt);
201               
202                void from_setting(const Setting &set){
203                        Kalman<chmat>::from_setting(set);
204                        initialize();
205                }
206};
207UIREGISTER(KalmanCh);
208
209/*!
210\brief Extended Kalman Filter in full matrices
211
212An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
213*/
214class EKFfull : public KalmanFull
215{
216        protected:
217                //! Internal Model f(x,u)
218                shared_ptr<diffbifn> pfxu;
219
220                //! Observation Model h(x,u)
221                shared_ptr<diffbifn> phxu;
222
223        public:
224                //! Default constructor
225                EKFfull ();
226
227                //! Set nonlinear functions for mean values and covariance matrices.
228                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0);
229
230                //! Here dt = [yt;ut] of appropriate dimensions
231                void bayes (const vec &dt);
232                //! set estimates
233                void set_statistics (const vec &mu0, const mat &P0) {
234                        est->set_parameters (mu0, P0);
235                };
236                const mat _R() {
237                        return est->_R().to_mat();
238                }
239};
240UIREGISTER(EKFfull);
241
242
243/*!
244\brief Extended Kalman Filter in Square root
245
246An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
247*/
248
249class EKFCh : public KalmanCh
250{
251        protected:
252                //! Internal Model f(x,u)
253                shared_ptr<diffbifn> pfxu;
254
255                //! Observation Model h(x,u)
256                shared_ptr<diffbifn> phxu;
257        public:
258                //! copy constructor duplicated - calls different set_parameters
259                BM* _copy_() const {
260                        EKFCh* E = new EKFCh;
261                        E->set_parameters (pfxu, phxu, Q, R);
262                        E->set_statistics (est->_mu(), est->_R());
263                        return E;
264                }
265                //! Set nonlinear functions for mean values and covariance matrices.
266                void set_parameters (const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0);
267
268                //! Here dt = [yt;ut] of appropriate dimensions
269                void bayes (const vec &dt);
270
271                void from_setting (const Setting &set);
272
273                // TODO dodelat void to_setting( Setting &set ) const;
274
275};
276
277UIREGISTER (EKFCh);
278SHAREDPTR (EKFCh);
279
280
281//////// INstance
282
283/*! \brief (Switching) Multiple Model
284The model runs several models in parallel and evaluates thier weights (fittness).
285
286The statistics of the resulting density are merged using (geometric?) combination.
287
288The next step is performed with the new statistics for all models.
289*/
290class MultiModel: public BM
291{
292        protected:
293                //! List of models between which we switch
294                Array<EKFCh*> Models;
295                //! vector of model weights
296                vec w;
297                //! cache of model lls
298                vec _lls;
299                //! type of switching policy [1=maximum,2=...]
300                int policy;
301                //! internal statistics
302                enorm<chmat> est;
303        public:
304                void set_parameters (Array<EKFCh*> A, int pol0 = 1) {
305                        Models = A;//TODO: test if evalll is set
306                        w.set_length (A.length());
307                        _lls.set_length (A.length());
308                        policy = pol0;
309
310                        est.set_rv (RV ("MM", A (0)->posterior().dimension(), 0));
311                        est.set_parameters (A (0)->posterior().mean(), A (0)->posterior()._R());
312                }
313                void bayes (const vec &dt) {
314                        int n = Models.length();
315                        int i;
316                        for (i = 0; i < n; i++) {
317                                Models (i)->bayes (dt);
318                                _lls (i) = Models (i)->_ll();
319                        }
320                        double mlls = max (_lls);
321                        w = exp (_lls - mlls);
322                        w /= sum (w);   //normalization
323                        //set statistics
324                        switch (policy) {
325                                case 1: {
326                                        int mi = max_index (w);
327                                        const enorm<chmat> &st = Models (mi)->posterior() ;
328                                        est.set_parameters (st.mean(), st._R());
329                                }
330                                break;
331                                default:
332                                        bdm_error ("unknown policy");
333                        }
334                        // copy result to all models
335                        for (i = 0; i < n; i++) {
336                                Models (i)->set_statistics (est.mean(), est._R());
337                        }
338                }
339                //! posterior density
340                const enorm<chmat>& posterior() const {
341                        return est;
342                }
343
344                void from_setting (const Setting &set);
345
346                // TODO dodelat void to_setting( Setting &set ) const;
347
348};
349
350UIREGISTER (MultiModel);
351SHAREDPTR (MultiModel);
352
353//! conversion of outer ARX model (mlnorm) to state space model
354/*!
355The model is constructed as:
356\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
357For example, for:
358\f[ y_t = a y_{t-1} + b u_{t-1}\f]
359the state \f$ x_t = [y_{t-1}, u_{t-1}] \f$
360*/
361template<class sq_T>
362shared_ptr<StateSpace<fsqmat> > to_state_space(const mlnorm<sq_T > &ml, ivec &theta_in_A, ivec &theta_in_C){
363        //get ids of yrv
364                       
365                        ivec yids= unique(ml._rv()._ids()); //yrv is now stored in _yrv
366                        ivec dids= unique(ml._rvc()._ids()); //yrv is now stored in _yrv
367                        ivec uids=unique_complement(dids, yids);
368       
369                        RV yrv_uni = RV(yids,zeros_i(yids.length()));
370                        RV urv_uni = RV(uids,zeros_i(yids.length()));
371                        //We can do only 1d now... :(
372                        bdm_assert_debug(yrv_uni._dsize()==urv_uni._dsize()==1, "Only for SISO so far..." );
373                                       
374                        RV xrv; //empty
375                        RV Crv; //empty
376                        int td=ml._rvc().mintd();
377                        for (int t=-1;t>=td;t--){
378                                xrv.add(yrv_uni);
379                                Crv.add(urv_uni);
380                        }
381                       
382                        int dimx = xrv._dsize();
383                       
384                        theta_in_A = ml._rv().dataind(xrv);
385                        theta_in_C = ml._rvc().dataind(xrv);
386                        // some chcek of corretness
387                       
388                        vec A1row = zeros(xrv._dsize());
389                        vec C1row = zeros(xrv._dsize());
390                        vec theta = ml._A().get_row(0); // this
391                        set_subvector( A1row, theta_in_A, theta);
392                        set_subvector( C1row, theta_in_C, theta);
393
394                        StateSpace<fsqmat> stsp=new StateSpace<fsqmat>();
395                        mat A=zeros(dimx,dimx);
396                        A.set_row(0,A1row);
397                        for (int j=1; j<dimx; j++){A(j,j-1)=1.0;} // off diagonal
398                        mat B=zeros(dimx,1);
399                        B(0) = 1.0;
400                        mat C=zeros(1,dimx);
401                        C.set_row(0,C1row);
402                        stsp.set_parameters(A,B,C,zeros(1,1), zeros(dimx,dimx), ml._R());
403                        return stsp;
404}
405
406/////////// INSTANTIATION
407
408template<class sq_T>
409void StateSpace<sq_T>::set_parameters (const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0)
410{
411        dimx = A0.rows();
412        dimu = B0.cols();
413        dimy = C0.rows();
414
415        A = A0;
416        B = B0;
417        C = C0;
418        D = D0;
419        R = R0;
420        Q = Q0;
421        validate();
422}
423
424template<class sq_T>
425void StateSpace<sq_T>::validate(){
426        bdm_assert_debug (A.cols() == dimx, "KalmanFull: A is not square");
427        bdm_assert_debug (B.rows() == dimx, "KalmanFull: B is not compatible");
428        bdm_assert_debug (C.cols() == dimx, "KalmanFull: C is not square");
429        bdm_assert_debug ( (D.rows() == dimy) || (D.cols() == dimu), "KalmanFull: D is not compatible");
430        bdm_assert_debug ( (Q.cols() == dimx) || (Q.rows() == dimx), "KalmanFull: Q is not compatible");
431        bdm_assert_debug ( (R.cols() == dimy) || (R.rows() == dimy), "KalmanFull: R is not compatible");
432}
433
434}
435#endif // KF_H
436
Note: See TracBrowser for help on using the browser.