root/library/bdm/estim/particles.h @ 951

Revision 951, 18.3 kB (checked in by smidl, 14 years ago)

Correction to particle filters

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../estim/arx_ext.h"
18#include "../stat/emix.h"
19
20namespace bdm {
21       
22//! class used in PF
23class MarginalizedParticle : public BM{
24        protected:
25        //! discrte particle
26        dirac est_emp;
27        //! internal Bayes Model
28        shared_ptr<BM> bm; 
29        //! pdf with for transitional par
30        shared_ptr<pdf> par; // pdf for non-linear part
31        //! link from this to bm
32        shared_ptr<datalink_part> cond2bm;
33        //! link from cond to par
34        shared_ptr<datalink_part> cond2par;
35        //! link from emp 2 par
36        shared_ptr<datalink_part> emp2bm;
37        //! link from emp 2 par
38        shared_ptr<datalink_part> emp2par;
39       
40        //! custom posterior - product
41        class eprod_2:public eprod_base {
42                protected:
43                MarginalizedParticle &mp;
44                public:
45                eprod_2(MarginalizedParticle &m):mp(m){}
46                const epdf* factor(int i) const {return (i==0) ? &mp.bm->posterior() : &mp.est_emp;}
47                const int no_factors() const {return 2;}
48        } est;
49       
50        public:
51                MarginalizedParticle():est(*this){};
52                MarginalizedParticle(const MarginalizedParticle &m2):est(*this){
53                        bm = m2.bm->_copy();
54                        est_emp = m2.est_emp;
55                        par = m2.par;
56                        validate();
57                };
58                BM* _copy() const{return new MarginalizedParticle(*this);};
59                void bayes(const vec &dt, const vec &cond){
60                        vec par_cond(par->dimensionc());
61                        cond2par->filldown(cond,par_cond); // copy ut
62                        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1
63                       
64                        //sample new particle
65                        est_emp.set_point(par->samplecond(par_cond));
66                        //if (evalll)
67                        vec bm_cond(bm->dimensionc());
68                        cond2bm->filldown(cond, bm_cond);// set e.g. ut
69                        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
70                        bm->bayes(dt, bm_cond);
71                        ll=bm->_ll();
72                }
73                const eprod_2& posterior() const {return est;}
74       
75                void set_prior(const epdf *pdf0){
76                        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
77                        if (ep){ // full prior
78                                bdm_assert(ep->no_factors()==2,"Incompatible prod");
79                                bm->set_prior(ep->factor(0));
80                                est_emp.set_point(ep->factor(1)->sample());
81                        } else {
82                                // assume prior is only for emp;
83                                est_emp.set_point(pdf0->sample());
84                        }
85                }
86
87                /*! parse structure
88                \code
89                class = "BootstrapParticle";
90                parameter_pdf = {class = 'epdf_offspring', ...};
91                observation_pdf = {class = 'epdf_offspring',...};
92                \endcode
93                If rvs are set, then it checks for compatibility.
94                */
95                void from_setting(const Setting &set){
96                        BM::from_setting ( set );
97                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
98                        bm = UI::build<BM> ( set, "bm", UI::compulsory );
99                }
100               
101                void to_setting(const Setting &set){
102                        if (BM::log_level[logfull]){
103                        }
104                }
105                void validate(){
106                        if (est_emp.point.length()!=par->dimension())
107                                est_emp.set_point(zeros(par->dimension()));
108                        est.validate();
109                       
110                        yrv = bm->_yrv();
111                        dimy = bm->dimensiony();
112                        set_rv( concat(bm->_rv(), par->_rv()));
113                        set_dim( par->dimension()+bm->dimension());
114
115                        rvc = par->_rvc();
116                        rvc.add(bm->_rvc());
117                        rvc=rvc.subt(par->_rv());
118                        rvc=rvc.subt(par->_rv().copy_t(-1));
119                        rvc=rvc.subt(bm->_rv().copy_t(-1)); //
120                       
121                        cond2bm=new datalink_part;
122                        cond2par=new datalink_part;
123                        emp2bm  =new datalink_part;
124                        emp2par =new datalink_part;
125                        cond2bm->set_connection(bm->_rvc(), rvc);
126                        cond2par->set_connection(par->_rvc(), rvc);
127                        emp2bm->set_connection(bm->_rvc(), par->_rv());
128                        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));
129                       
130                        dimc = rvc._dsize();
131                };
132};
133UIREGISTER(MarginalizedParticle);
134
135//! class used in PF
136class BootstrapParticle : public BM{
137        dirac est;
138        shared_ptr<pdf> par;
139        shared_ptr<pdf> obs;
140        shared_ptr<datalink_part> cond2par;
141        shared_ptr<datalink_part> cond2obs;
142        shared_ptr<datalink_part> xt2obs;
143        shared_ptr<datalink_part> xtm2par;
144        public:
145                BM* _copy() const{return new BootstrapParticle(*this);};
146                void bayes(const vec &dt, const vec &cond){
147                        vec par_cond(par->dimensionc());
148                        cond2par->filldown(cond,par_cond); // copy ut
149                        xtm2par->filldown(est._point(),par_cond); // copy xt-1
150                       
151                        //sample new particle
152                        est.set_point(par->samplecond(par_cond));
153                        //if (evalll)
154                        vec obs_cond(obs->dimensionc());
155                        cond2obs->filldown(cond, obs_cond);// set e.g. ut
156                        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
157                        ll=obs->evallogcond(dt,obs_cond);
158                }
159                const dirac& posterior() const {return est;}
160               
161                void set_prior(const epdf *pdf0){est.set_point(pdf0->sample());}
162               
163                /*! parse structure
164                \code
165                class = "BootstrapParticle";
166                parameter_pdf = {class = 'epdf_offspring', ...};
167                observation_pdf = {class = 'epdf_offspring',...};
168                \endcode
169                If rvs are set, then it checks for compatibility.
170                */
171                void from_setting(const Setting &set){
172                        BM::from_setting ( set );
173                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
174                        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
175                }
176                void validate(){
177                        yrv = obs->_rv();
178                        dimy = obs->dimension();
179                        set_rv( par->_rv());
180                        set_dim( par->dimension());
181                       
182                        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
183                        rvc.add(obs->_rvc()); //
184                       
185                        cond2obs=new datalink_part;
186                        cond2par=new datalink_part;
187                        xt2obs  =new datalink_part;
188                        xtm2par =new datalink_part;
189                        cond2obs->set_connection(obs->_rvc(), rvc);
190                        cond2par->set_connection(par->_rvc(), rvc);
191                        xt2obs->set_connection(obs->_rvc(), _rv());
192                        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));
193                       
194                        dimc = rvc._dsize();
195                };
196};
197UIREGISTER(BootstrapParticle);
198
199
200/*!
201* \brief Trivial particle filter with proposal density equal to parameter evolution model.
202
203Posterior density is represented by a weighted empirical density (\c eEmp ).
204*/
205
206class PF : public BM {
207        //! \var log_level_enums weights
208        //! all weightes will be logged
209
210        //! \var log_level_enums menas
211        //! means of particles will be logged
212        LOG_LEVEL(PF,logweights,logmeans,logvars);
213       
214        class pf_mix: public emix_base{
215                Array<BM*> &bms;
216                public:
217                        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0){}
218                        const epdf* component(const int &i)const{return &(bms(i)->posterior());}
219                        int no_coms() const {return bms.length();}
220        };
221protected:
222        //!number of particles;
223        int n;
224        //!posterior density
225        pf_mix est;
226        //! weights;
227        vec w;
228        //! particles
229        Array<BM*> particles;
230        //! internal structure storing loglikelihood of predictions
231        vec lls;
232
233        //! which resampling method will be used
234        RESAMPLING_METHOD resmethod;
235        //! resampling threshold; in this case its meaning is minimum ratio of active particles
236        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
237        double res_threshold;
238
239        //! \name Options
240        //!@{
241        //!@}
242
243public:
244        //! \name Constructors
245        //!@{
246        PF ( ) : est(w,particles) { };
247
248        void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
249                n = n0;
250                res_threshold = res_th0;
251                resmethod = rm;
252        };
253        void set_model ( const BM *particle0, const epdf *prior) {
254                if (n>0){
255                        particles.set_length(n);
256                        for (int i=0; i<n;i++){
257                                particles(i) = particle0->_copy();
258                                particles(i)->set_prior(prior);
259                        }
260                }
261                // set values for posterior
262                est.set_rv ( particle0->posterior()._rv() );
263        };
264        void set_statistics ( const vec w0, const epdf &epdf0 ) {
265                //est.set_statistics ( w0, epdf0 );
266        };
267/*      void set_statistics ( const eEmp &epdf0 ) {
268                bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
269                est = epdf0;
270        };*/
271        //!@}
272
273        //! bayes compute weights of the
274        virtual void bayes_weights();
275        //! important part of particle filtering - decide if it is time to perform resampling
276        virtual bool do_resampling() {
277                double eff = 1.0 / ( w * w );
278                return eff < ( res_threshold*n );
279        }
280        void bayes ( const vec &yt, const vec &cond );
281        //!access function
282        vec& _lls() {
283                return lls;
284        }
285        //!access function
286        RESAMPLING_METHOD _resmethod() const {
287                return resmethod;
288        }
289        //! return correctly typed posterior (covariant return)
290        const pf_mix& posterior() const {
291                return est;
292        }
293
294        /*! configuration structure for basic PF
295        \code
296        parameter_pdf   = pdf_class;         // parameter evolution pdf
297        observation_pdf = pdf_class;         // observation pdf
298        prior           = epdf_class;         // prior probability density
299        --- optional ---
300        n               = 10;                 // number of particles
301        resmethod       = 'systematic', or 'multinomial', or 'stratified'
302                                                                                  // resampling method
303        res_threshold   = 0.5;                // resample when active particles drop below 50%
304        \endcode
305        */
306        void from_setting ( const Setting &set ) {
307                BM::from_setting ( set );
308                UI::get ( log_level, set, "log_level", UI::optional );
309               
310                shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);
311               
312                shared_ptr<epdf> pri = UI::build<epdf> ( set, "prior", UI::compulsory );
313                n =0;
314                UI::get(n,set,"n",UI::optional);;
315                if (n>0){
316                        particles.set_length(n);
317                        for(int i=0;i<n;i++){particles(i)=bm0->_copy();}
318                        w = ones(n)/n;
319                }
320                set_prior(pri.get());
321                // set resampling method
322                resmethod_from_set ( set );
323                //set drv
324
325                set_yrv ( bm0->_rv() );
326                rvc = bm0->_rvc();
327                BM::set_rv(bm0->_rv());
328                yrv=bm0->_yrv();
329        }
330       
331        void log_register ( bdm::logger& L, const string& prefix ){
332                BM::log_register(L,prefix);
333                if (log_level[logweights]){
334                        L.add_vector( log_level, logweights, RV ( particles.length()), prefix); 
335                }
336                if (log_level[logmeans]){
337                        for (int i=0; i<particles.length(); i++){
338                                L.add_vector( log_level, logmeans, RV ( particles(i)->dimension() ), prefix , i);
339                        }
340                }
341        };
342        void log_write ( ) const {
343                BM::log_write();
344                if (log_level[logweights]){
345                        log_level.store( logweights, w); 
346                }
347                if (log_level[logmeans]){
348                        for (int i=0; i<particles.length(); i++){
349                                 log_level.store( logmeans, particles(i)->posterior().mean(), i);
350                        }
351                }
352               
353        }
354       
355        void set_prior(const epdf *pri){
356                const emix_base *emi=dynamic_cast<const emix_base*>(pri);
357                if (emi) {
358                        bdm_assert(particles.length()>0, "initial particle is not assigned");
359                        n = emi->_w().length();
360                        int old_n = particles.length();
361                        if (n!=old_n){
362                                particles.set_length(n,true);
363                        } 
364                        for(int i=old_n;i<n;i++){particles(i)=particles(0)->_copy();}
365                       
366                        for (int i =0; i<n; i++){
367                                particles(i)->set_prior(emi->_com(i));
368                        }
369                } else {
370                        // try to find "n"
371                        bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
372                        for (int i =0; i<n; i++){
373                                particles(i)->set_prior(pri);
374                        }
375                       
376                }
377        }
378        //! auxiliary function reading parameter 'resmethod' from configuration file
379        void resmethod_from_set ( const Setting &set ) {
380                string resmeth;
381                if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
382                        if ( resmeth == "systematic" ) {
383                                resmethod = SYSTEMATIC;
384                        } else  {
385                                if ( resmeth == "multinomial" ) {
386                                        resmethod = MULTINOMIAL;
387                                } else {
388                                        if ( resmeth == "stratified" ) {
389                                                resmethod = STRATIFIED;
390                                        } else {
391                                                bdm_error ( "Unknown resampling method" );
392                                        }
393                                }
394                        }
395                } else {
396                        resmethod = SYSTEMATIC;
397                };
398                if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
399                        res_threshold = 0.9;
400                }
401                //validate();
402        }
403
404        void validate() {
405                BM::validate();
406                est.validate();
407                bdm_assert ( n>0, "empty particle pool" );
408                n = w.length();
409                lls = zeros ( n );
410
411                if ( particles(0)->_rv()._dsize() > 0 ) {
412                        bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "Mismatch of RV and dimension of posterior" );
413                }
414        }
415        //! resample posterior density (from outside - see MPF)
416        void resample ( ) {
417                ivec ind = zeros_i ( n );
418                bdm::resample(w,ind,resmethod);
419                // copy the internals according to ind
420                for (int i = 0; i < n; i++ ) {
421                        if ( ind ( i ) != i ) {
422                                particles( i ) = particles( ind ( i ) )->_copy();
423                        }
424                        w ( i ) = 1.0 / n;
425                }
426        }
427        //! access function
428        Array<BM*>& _particles() {
429                return particles;
430        }
431
432};
433UIREGISTER ( PF );
434
435/*!
436\brief Marginalized Particle filter
437
438A composition of particle filter with exact (or almost exact) bayesian models (BMs).
439The Bayesian models provide marginalized predictive density. Internaly this is achieved by virtual class MPFpdf.
440*/
441
442// class MPF : public BM  {
443//      //! Introduces new option
444//      //! \li means - meaning TODO
445//      LOG_LEVEL(MPF,means);
446// protected:
447//      //! particle filter on non-linear variable
448//      shared_ptr<PF> pf;
449//      //! Array of Bayesian models
450//      Array<BM*> BMs;
451//
452//      //! internal class for pdf providing composition of eEmp with external components
453//
454//      class mpfepdf : public epdf  {
455//              //! pointer to particle filter
456//              shared_ptr<PF> &pf;
457//              //! pointer to Array of BMs
458//              Array<BM*> &BMs;
459//      public:
460//              //! constructor
461//              mpfepdf ( shared_ptr<PF> &pf0, Array<BM*> &BMs0 ) : epdf(), pf ( pf0 ), BMs ( BMs0 ) { };
462//              //! a variant of set parameters - this time, parameters are read from BMs and pf
463//              void read_parameters() {
464//                      rv = concat ( pf->posterior()._rv(), BMs ( 0 )->posterior()._rv() );
465//                      dim = pf->posterior().dimension() + BMs ( 0 )->posterior().dimension();
466//                      bdm_assert_debug ( dim == rv._dsize(), "Wrong name " );
467//              }
468//              vec mean() const;
469//
470//              vec variance() const;
471//
472//              void qbounds ( vec &lb, vec &ub, double perc = 0.95 ) const;
473//
474//              vec sample() const NOT_IMPLEMENTED(0);
475//
476//              double evallog ( const vec &val ) const NOT_IMPLEMENTED(0);             
477//      };
478//
479//      //! Density joining PF.est with conditional parts
480//      mpfepdf jest;
481//
482//      //! datalink from global yt and cond (Up) to BMs yt and cond (Down)
483//      datalink_m2m this2bm;
484//      //! datalink from global yt and cond (Up) to PFs yt and cond (Down)
485//      datalink_m2m this2pf;
486//      //!datalink from PF part to BM
487//      datalink_part pf2bm;
488//
489// public:
490//      //! Default constructor.
491//      MPF () :  jest ( pf, BMs ) {};
492//      //! set all parameters at once
493//      void set_pf ( shared_ptr<pdf> par0, int n0, RESAMPLING_METHOD rm = SYSTEMATIC ) {
494//              if (!pf) pf=new PF;
495//              pf->set_model ( par0, par0 ); // <=== nasty!!!
496//              pf->set_parameters ( n0, rm );
497//              pf->set_rv(par0->_rv());
498//              BMs.set_length ( n0 );
499//      }
500//      //! set a prototype of BM, copy it to as many times as there is particles in pf
501//      void set_BM ( const BM &BMcond0 ) {
502//
503//              int n = pf->__w().length();
504//              BMs.set_length ( n );
505//              // copy
506//              //BMcond0 .condition ( pf->posterior()._sample ( 0 ) );
507//              for ( int i = 0; i < n; i++ ) {
508//                      BMs ( i ) = (BM*) BMcond0._copy();
509//              }
510//      };
511//
512//      void bayes ( const vec &yt, const vec &cond );
513//
514//      const epdf& posterior() const {
515//              return jest;
516//      }
517//
518//      //!Access function
519//      const BM* _BM ( int i ) {
520//              return BMs ( i );
521//      }
522//      PF& _pf() {return *pf;}
523//
524//
525//      virtual double logpred ( const vec &yt ) const NOT_IMPLEMENTED(0);
526//             
527//      virtual epdf* epredictor() const NOT_IMPLEMENTED(NULL);
528//     
529//      virtual pdf* predictor() const NOT_IMPLEMENTED(NULL);
530//
531//
532//      /*! configuration structure for basic PF
533//      \code
534//      BM              = BM_class;           // Bayesian filtr for analytical part of the model
535//      parameter_pdf   = pdf_class;         // transitional pdf for non-parametric part of the model
536//      prior           = epdf_class;         // prior probability density
537//      --- optional ---
538//      n               = 10;                 // number of particles
539//      resmethod       = 'systematic', or 'multinomial', or 'stratified'
540//                                                                                // resampling method
541//      \endcode
542//      */
543//      void from_setting ( const Setting &set ) {
544//              BM::from_setting( set );
545//
546//              shared_ptr<pdf> par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
547//
548//              pf = new PF;
549//              // prior must be set before BM
550//              pf->prior_from_set ( set );
551//              pf->resmethod_from_set ( set );
552//              pf->set_model ( par, par ); // too hackish!
553//
554//              shared_ptr<BM> BM0 = UI::build<BM> ( set, "BM", UI::compulsory );
555//              set_BM ( *BM0 );
556//
557//              //set drv
558//              //??set_yrv(concat(BM0->_yrv(),u) );
559//              set_yrv ( BM0->_yrv() );
560//              rvc = BM0->_rvc().subt ( par->_rv() );
561//              //find potential input - what remains in rvc when we subtract rv
562//              RV u = par->_rvc().subt ( par->_rv().copy_t ( -1 ) );
563//              rvc.add ( u );
564//              dimc = rvc._dsize();
565//              validate();
566//      }
567//
568//      void validate() {
569//              BM::validate();
570//              try {
571//                      pf->validate();
572//              } catch ( std::exception ) {
573//                      throw UIException ( "Error in PF part of MPF:" );
574//              }
575//              jest.read_parameters();
576//              this2bm.set_connection ( BMs ( 0 )->_yrv(), BMs ( 0 )->_rvc(), yrv, rvc );
577//              this2pf.set_connection ( pf->_yrv(), pf->_rvc(), yrv, rvc );
578//              pf2bm.set_connection ( BMs ( 0 )->_rvc(), pf->posterior()._rv() );
579//      }
580// };
581// UIREGISTER ( MPF );
582
583/*! ARXg for estimation of state-space variances
584*/
585// class MPF_ARXg :public BM{
586//      protected:
587//      shared_ptr<PF> pf;
588//      //! pointer to Array of BMs
589//      Array<ARX*> BMso;
590//      //! pointer to Array of BMs
591//      Array<ARX*> BMsp;
592//      //!parameter evolution
593//      shared_ptr<fnc> g;
594//      //!observation function
595//      shared_ptr<fnc> h;
596//     
597//      public:
598//              void bayes(const vec &yt, const vec &cond );
599//              void from_setting(const Setting &set) ;
600//              void validate() {
601//                      bdm_assert(g->dimensionc()==g->dimension(),"not supported yet");
602//                      bdm_assert(h->dimensionc()==g->dimension(),"not supported yet");                       
603//              }
604//
605//              double logpred(const vec &cond) const NOT_IMPLEMENTED(0.0);
606//              epdf* epredictor() const NOT_IMPLEMENTED(NULL);
607//              pdf* predictor() const NOT_IMPLEMENTED(NULL);
608//              const epdf& posterior() const {return pf->posterior();};
609//             
610//              void log_register( logger &L, const string &prefix ){
611//                      BM::log_register(L,prefix);
612//                      registered_logger->ids.set_size ( 3 );
613//                      registered_logger->ids(1)= L.add_vector(RV("Q",dimension()*dimension()), prefix+L.prefix_sep()+"Q");
614//                      registered_logger->ids(2)= L.add_vector(RV("R",dimensiony()*dimensiony()), prefix+L.prefix_sep()+"R");
615//                     
616//              };
617//              void log_write() const {
618//                      BM::log_write();
619//                      mat mQ=zeros(dimension(),dimension());
620//                      mat pom=zeros(dimension(),dimension());
621//                      mat mR=zeros(dimensiony(),dimensiony());
622//                      mat pom2=zeros(dimensiony(),dimensiony());
623//                      mat dum;
624//                      const vec w=pf->posterior()._w();
625//                      for (int i=0; i<w.length(); i++){
626//                              BMsp(i)->posterior().mean_mat(dum,pom);
627//                              mQ += w(i) * pom;
628//                              BMso(i)->posterior().mean_mat(dum,pom2);
629//                              mR += w(i) * pom2;
630//                             
631//                      }
632//                      registered_logger->L.log_vector ( registered_logger->ids ( 1 ), cvectorize(mQ) );
633//                      registered_logger->L.log_vector ( registered_logger->ids ( 2 ), cvectorize(mR) );
634//                     
635//              }
636// };
637// UIREGISTER(MPF_ARXg);
638
639
640}
641#endif // KF_H
642
Note: See TracBrowser for help on using the browser.