root/library/bdm/estim/particles.h @ 900

Revision 900, 17.3 kB (checked in by smidl, 14 years ago)

particle bug fixing

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../estim/arx_ext.h"
18#include "../stat/emix.h"
19
20namespace bdm {
21       
22//! class used in PF
23class MarginalizedParticle : public BM{
24        protected:
25        //! discrte particle
26        dirac est_emp;
27        //! internal Bayes Model
28        shared_ptr<BM> bm; 
29        //! pdf with for transitional par
30        shared_ptr<pdf> par; // pdf for non-linear part
31        //! link from this to bm
32        shared_ptr<datalink_part> cond2bm;
33        //! link from cond to par
34        shared_ptr<datalink_part> cond2par;
35        //! link from emp 2 par
36        shared_ptr<datalink_part> emp2bm;
37        //! link from emp 2 par
38        shared_ptr<datalink_part> emp2par;
39       
40        //! custom posterior - product
41        class eprod_2:public eprod_base {
42                protected:
43                MarginalizedParticle &mp;
44                public:
45                eprod_2(MarginalizedParticle &m):mp(m){}
46                const epdf* factor(int i) const {return (i==0) ? &mp.bm->posterior() : &mp.est_emp;}
47                const int no_factors() const {return 2;}
48        } est;
49       
50        public:
51                MarginalizedParticle():est(*this){};
52                MarginalizedParticle(const MarginalizedParticle &m2):est(*this){
53                        bm = m2.bm->_copy();
54                        est_emp = m2.est_emp;
55                        par = m2.par;
56                        validate();
57                };
58                BM* _copy() const{return new MarginalizedParticle(*this);};
59                void bayes(const vec &dt, const vec &cond){
60                        vec par_cond(par->dimensionc());
61                        cond2par->filldown(cond,par_cond); // copy ut
62                        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1
63                       
64                        //sample new particle
65                        est_emp.set_point(par->samplecond(par_cond));
66                        //if (evalll)
67                        vec bm_cond(bm->dimensionc());
68                        cond2bm->filldown(cond, bm_cond);// set e.g. ut
69                        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
70                        bm->bayes(dt, bm_cond);
71                        ll=bm->_ll();
72                }
73                const eprod_2& posterior() const {return est;}
74       
75                void set_prior(const epdf *pdf0){
76                        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
77                        if (ep){ // full prior
78                                bdm_assert(ep->no_factors()==2,"Incompatible prod");
79                                bm->set_prior(ep->factor(0));
80                                est_emp.set_point(ep->factor(1)->sample());
81                        } else {
82                                // assume prior is only for emp;
83                                est_emp.set_point(pdf0->sample());
84                        }
85                }
86
87                /*! parse structure
88                \code
89                class = "BootstrapParticle";
90                parameter_pdf = {class = 'epdf_offspring', ...};
91                observation_pdf = {class = 'epdf_offspring',...};
92                \endcode
93                If rvs are set, then it checks for compatibility.
94                */
95                void from_setting(const Setting &set){
96                        BM::from_setting ( set );
97                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
98                        bm = UI::build<BM> ( set, "bm", UI::compulsory );
99                }
100                void validate(){
101                        est_emp.set_point(zeros(par->dimension()));
102                        est.validate();
103                       
104                        yrv = bm->_yrv();
105                        dimy = bm->dimensiony();
106                        set_rv( concat(bm->_rv(), par->_rv()));
107                        set_dim( par->dimension()+bm->dimension());
108
109                        rvc = par->_rvc();
110                        rvc.add(bm->_rvc());
111                        rvc=rvc.subt(par->_rv());
112                        rvc=rvc.subt(par->_rv().copy_t(-1));
113                        rvc=rvc.subt(bm->_rv().copy_t(-1)); //
114                       
115                        cond2bm=new datalink_part;
116                        cond2par=new datalink_part;
117                        emp2bm  =new datalink_part;
118                        emp2par =new datalink_part;
119                        cond2bm->set_connection(bm->_rvc(), rvc);
120                        cond2par->set_connection(par->_rvc(), rvc);
121                        emp2bm->set_connection(bm->_rvc(), par->_rv());
122                        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));
123                       
124                        dimc = rvc._dsize();
125                };
126};
127UIREGISTER(MarginalizedParticle);
128
129//! class used in PF
130class BootstrapParticle : public BM{
131        dirac est;
132        shared_ptr<pdf> par;
133        shared_ptr<pdf> obs;
134        shared_ptr<datalink_part> cond2par;
135        shared_ptr<datalink_part> cond2obs;
136        shared_ptr<datalink_part> xt2obs;
137        shared_ptr<datalink_part> xtm2par;
138        public:
139                BM* _copy() const{return new BootstrapParticle(*this);};
140                void bayes(const vec &dt, const vec &cond){
141                        vec par_cond(par->dimensionc());
142                        cond2par->filldown(cond,par_cond); // copy ut
143                        xtm2par->filldown(est._point(),par_cond); // copy xt-1
144                       
145                        //sample new particle
146                        est.set_point(par->samplecond(par_cond));
147                        //if (evalll)
148                        vec obs_cond(obs->dimensionc());
149                        cond2obs->filldown(cond, obs_cond);// set e.g. ut
150                        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
151                        ll=obs->evallogcond(dt,obs_cond);
152                }
153                const dirac& posterior() const {return est;}
154               
155                void set_prior(const epdf *pdf0){est.set_point(pdf0->sample());}
156               
157                /*! parse structure
158                \code
159                class = "BootstrapParticle";
160                parameter_pdf = {class = 'epdf_offspring', ...};
161                observation_pdf = {class = 'epdf_offspring',...};
162                \endcode
163                If rvs are set, then it checks for compatibility.
164                */
165                void from_setting(const Setting &set){
166                        BM::from_setting ( set );
167                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
168                        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
169                }
170                void validate(){
171                        yrv = obs->_rv();
172                        dimy = obs->dimension();
173                        set_rv( par->_rv());
174                        set_dim( par->dimension());
175                       
176                        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
177                        rvc.add(obs->_rvc()); //
178                       
179                        cond2obs=new datalink_part;
180                        cond2par=new datalink_part;
181                        xt2obs  =new datalink_part;
182                        xtm2par =new datalink_part;
183                        cond2obs->set_connection(obs->_rvc(), rvc);
184                        cond2par->set_connection(par->_rvc(), rvc);
185                        xt2obs->set_connection(obs->_rvc(), _rv());
186                        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));
187                       
188                        dimc = rvc._dsize();
189                };
190};
191UIREGISTER(BootstrapParticle);
192
193
194/*!
195* \brief Trivial particle filter with proposal density equal to parameter evolution model.
196
197Posterior density is represented by a weighted empirical density (\c eEmp ).
198*/
199
200class PF : public BM {
201        //! \var log_level_enums weights
202        //! all weightes will be logged
203
204        //! \var log_level_enums samples
205        //! all samples will be logged
206        LOG_LEVEL(PF,weights,samples);
207       
208        class pf_mix: public emix_base{
209                Array<BM*> &bms;
210                public:
211                        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0){}
212                        const epdf* component(const int &i)const{return &(bms(i)->posterior());}
213                        int no_coms() const {return bms.length();}
214        };
215protected:
216        //!number of particles;
217        int n;
218        //!posterior density
219        pf_mix est;
220        //! weights;
221        vec w;
222        //! particles
223        Array<BM*> particles;
224        //! internal structure storing loglikelihood of predictions
225        vec lls;
226
227        //! which resampling method will be used
228        RESAMPLING_METHOD resmethod;
229        //! resampling threshold; in this case its meaning is minimum ratio of active particles
230        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
231        double res_threshold;
232
233        //! \name Options
234        //!@{
235        //!@}
236
237public:
238        //! \name Constructors
239        //!@{
240        PF ( ) : est(w,particles) { };
241
242        void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
243                n = n0;
244                res_threshold = res_th0;
245                resmethod = rm;
246        };
247        void set_model ( const BM *particle0, const epdf *prior) {
248                if (n>0){
249                        particles.set_length(n);
250                        for (int i=0; i<n;i++){
251                                particles(i) = particle0->_copy();
252                                particles(i)->set_prior(prior);
253                        }
254                }
255                // set values for posterior
256                est.set_rv ( particle0->posterior()._rv() );
257        };
258        void set_statistics ( const vec w0, const epdf &epdf0 ) {
259                //est.set_statistics ( w0, epdf0 );
260        };
261/*      void set_statistics ( const eEmp &epdf0 ) {
262                bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
263                est = epdf0;
264        };*/
265        //!@}
266
267        //! bayes compute weights of the
268        virtual void bayes_weights();
269        //! important part of particle filtering - decide if it is time to perform resampling
270        virtual bool do_resampling() {
271                double eff = 1.0 / ( w * w );
272                return eff < ( res_threshold*n );
273        }
274        void bayes ( const vec &yt, const vec &cond );
275        //!access function
276        vec& _lls() {
277                return lls;
278        }
279        //!access function
280        RESAMPLING_METHOD _resmethod() const {
281                return resmethod;
282        }
283        //! return correctly typed posterior (covariant return)
284        const pf_mix& posterior() const {
285                return est;
286        }
287
288        /*! configuration structure for basic PF
289        \code
290        parameter_pdf   = pdf_class;         // parameter evolution pdf
291        observation_pdf = pdf_class;         // observation pdf
292        prior           = epdf_class;         // prior probability density
293        --- optional ---
294        n               = 10;                 // number of particles
295        resmethod       = 'systematic', or 'multinomial', or 'stratified'
296                                                                                  // resampling method
297        res_threshold   = 0.5;                // resample when active particles drop below 50%
298        \endcode
299        */
300        void from_setting ( const Setting &set ) {
301                BM::from_setting ( set );
302               
303                shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);
304               
305                shared_ptr<epdf> pri = UI::build<epdf> ( set, "prior", UI::compulsory );
306                n =0;
307                UI::get(n,set,"n",UI::optional);;
308                if (n>0){
309                        particles.set_length(n);
310                        for(int i=0;i<n;i++){particles(i)=bm0->_copy();}
311                        w = ones(n)/n;
312                }
313                set_prior(pri.get());
314                // set resampling method
315                resmethod_from_set ( set );
316                //set drv
317
318                set_yrv ( bm0->_rv() );
319                rvc = bm0->_rvc();
320                BM::set_rv(bm0->_rv());
321                yrv=bm0->_yrv();
322        }
323       
324        void set_prior(const epdf *pri){
325                const emix_base *emi=dynamic_cast<const emix_base*>(pri);
326                if (emi) {
327                        bdm_assert(particles.length()>0, "initial particle is not assigned");
328                        n = emi->_w().length();
329                        int old_n = particles.length();
330                        if (n!=old_n){
331                                particles.set_length(n,true);
332                        } 
333                        for(int i=old_n;i<n;i++){particles(i)=particles(0)->_copy();}
334                       
335                        for (int i =0; i<n; i++){
336                                particles(i)->set_prior(emi->_com(i));
337                        }
338                } else {
339                        // try to find "n"
340                        bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
341                        for (int i =0; i<n; i++){
342                                particles(i)->set_prior(pri);
343                        }
344                       
345                }
346        }
347        //! auxiliary function reading parameter 'resmethod' from configuration file
348        void resmethod_from_set ( const Setting &set ) {
349                string resmeth;
350                if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
351                        if ( resmeth == "systematic" ) {
352                                resmethod = SYSTEMATIC;
353                        } else  {
354                                if ( resmeth == "multinomial" ) {
355                                        resmethod = MULTINOMIAL;
356                                } else {
357                                        if ( resmeth == "stratified" ) {
358                                                resmethod = STRATIFIED;
359                                        } else {
360                                                bdm_error ( "Unknown resampling method" );
361                                        }
362                                }
363                        }
364                } else {
365                        resmethod = SYSTEMATIC;
366                };
367                if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
368                        res_threshold = 0.9;
369                }
370                //validate();
371        }
372
373        void validate() {
374                BM::validate();
375                est.validate();
376                bdm_assert ( n>0, "empty particle pool" );
377                n = w.length();
378                lls = zeros ( n );
379
380                if ( particles(0)->_rv()._dsize() > 0 ) {
381                        bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "Mismatch of RV and dimension of posterior" );
382                }
383        }
384        //! resample posterior density (from outside - see MPF)
385        void resample ( ) {
386                ivec ind = zeros_i ( n );
387                bdm::resample(w,ind,resmethod);
388                // copy the internals according to ind
389                for (int i = 0; i < n; i++ ) {
390                        if ( ind ( i ) != i ) {
391                                particles( i ) = particles( ind ( i ) )->_copy();
392                        }
393                        w ( i ) = 1.0 / n;
394                }
395        }
396        //! access function
397        Array<BM*>& _particles() {
398                return particles;
399        }
400
401};
402UIREGISTER ( PF );
403
404/*!
405\brief Marginalized Particle filter
406
407A composition of particle filter with exact (or almost exact) bayesian models (BMs).
408The Bayesian models provide marginalized predictive density. Internaly this is achieved by virtual class MPFpdf.
409*/
410
411// class MPF : public BM  {
412//      //! Introduces new option
413//      //! \li means - meaning TODO
414//      LOG_LEVEL(MPF,means);
415// protected:
416//      //! particle filter on non-linear variable
417//      shared_ptr<PF> pf;
418//      //! Array of Bayesian models
419//      Array<BM*> BMs;
420//
421//      //! internal class for pdf providing composition of eEmp with external components
422//
423//      class mpfepdf : public epdf  {
424//              //! pointer to particle filter
425//              shared_ptr<PF> &pf;
426//              //! pointer to Array of BMs
427//              Array<BM*> &BMs;
428//      public:
429//              //! constructor
430//              mpfepdf ( shared_ptr<PF> &pf0, Array<BM*> &BMs0 ) : epdf(), pf ( pf0 ), BMs ( BMs0 ) { };
431//              //! a variant of set parameters - this time, parameters are read from BMs and pf
432//              void read_parameters() {
433//                      rv = concat ( pf->posterior()._rv(), BMs ( 0 )->posterior()._rv() );
434//                      dim = pf->posterior().dimension() + BMs ( 0 )->posterior().dimension();
435//                      bdm_assert_debug ( dim == rv._dsize(), "Wrong name " );
436//              }
437//              vec mean() const;
438//
439//              vec variance() const;
440//
441//              void qbounds ( vec &lb, vec &ub, double perc = 0.95 ) const;
442//
443//              vec sample() const NOT_IMPLEMENTED(0);
444//
445//              double evallog ( const vec &val ) const NOT_IMPLEMENTED(0);             
446//      };
447//
448//      //! Density joining PF.est with conditional parts
449//      mpfepdf jest;
450//
451//      //! datalink from global yt and cond (Up) to BMs yt and cond (Down)
452//      datalink_m2m this2bm;
453//      //! datalink from global yt and cond (Up) to PFs yt and cond (Down)
454//      datalink_m2m this2pf;
455//      //!datalink from PF part to BM
456//      datalink_part pf2bm;
457//
458// public:
459//      //! Default constructor.
460//      MPF () :  jest ( pf, BMs ) {};
461//      //! set all parameters at once
462//      void set_pf ( shared_ptr<pdf> par0, int n0, RESAMPLING_METHOD rm = SYSTEMATIC ) {
463//              if (!pf) pf=new PF;
464//              pf->set_model ( par0, par0 ); // <=== nasty!!!
465//              pf->set_parameters ( n0, rm );
466//              pf->set_rv(par0->_rv());
467//              BMs.set_length ( n0 );
468//      }
469//      //! set a prototype of BM, copy it to as many times as there is particles in pf
470//      void set_BM ( const BM &BMcond0 ) {
471//
472//              int n = pf->__w().length();
473//              BMs.set_length ( n );
474//              // copy
475//              //BMcond0 .condition ( pf->posterior()._sample ( 0 ) );
476//              for ( int i = 0; i < n; i++ ) {
477//                      BMs ( i ) = (BM*) BMcond0._copy();
478//              }
479//      };
480//
481//      void bayes ( const vec &yt, const vec &cond );
482//
483//      const epdf& posterior() const {
484//              return jest;
485//      }
486//
487//      //!Access function
488//      const BM* _BM ( int i ) {
489//              return BMs ( i );
490//      }
491//      PF& _pf() {return *pf;}
492//
493//
494//      virtual double logpred ( const vec &yt ) const NOT_IMPLEMENTED(0);
495//             
496//      virtual epdf* epredictor() const NOT_IMPLEMENTED(NULL);
497//     
498//      virtual pdf* predictor() const NOT_IMPLEMENTED(NULL);
499//
500//
501//      /*! configuration structure for basic PF
502//      \code
503//      BM              = BM_class;           // Bayesian filtr for analytical part of the model
504//      parameter_pdf   = pdf_class;         // transitional pdf for non-parametric part of the model
505//      prior           = epdf_class;         // prior probability density
506//      --- optional ---
507//      n               = 10;                 // number of particles
508//      resmethod       = 'systematic', or 'multinomial', or 'stratified'
509//                                                                                // resampling method
510//      \endcode
511//      */
512//      void from_setting ( const Setting &set ) {
513//              BM::from_setting( set );
514//
515//              shared_ptr<pdf> par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
516//
517//              pf = new PF;
518//              // prior must be set before BM
519//              pf->prior_from_set ( set );
520//              pf->resmethod_from_set ( set );
521//              pf->set_model ( par, par ); // too hackish!
522//
523//              shared_ptr<BM> BM0 = UI::build<BM> ( set, "BM", UI::compulsory );
524//              set_BM ( *BM0 );
525//
526//              //set drv
527//              //??set_yrv(concat(BM0->_yrv(),u) );
528//              set_yrv ( BM0->_yrv() );
529//              rvc = BM0->_rvc().subt ( par->_rv() );
530//              //find potential input - what remains in rvc when we subtract rv
531//              RV u = par->_rvc().subt ( par->_rv().copy_t ( -1 ) );
532//              rvc.add ( u );
533//              dimc = rvc._dsize();
534//              validate();
535//      }
536//
537//      void validate() {
538//              BM::validate();
539//              try {
540//                      pf->validate();
541//              } catch ( std::exception ) {
542//                      throw UIException ( "Error in PF part of MPF:" );
543//              }
544//              jest.read_parameters();
545//              this2bm.set_connection ( BMs ( 0 )->_yrv(), BMs ( 0 )->_rvc(), yrv, rvc );
546//              this2pf.set_connection ( pf->_yrv(), pf->_rvc(), yrv, rvc );
547//              pf2bm.set_connection ( BMs ( 0 )->_rvc(), pf->posterior()._rv() );
548//      }
549// };
550// UIREGISTER ( MPF );
551
552/*! ARXg for estimation of state-space variances
553*/
554// class MPF_ARXg :public BM{
555//      protected:
556//      shared_ptr<PF> pf;
557//      //! pointer to Array of BMs
558//      Array<ARX*> BMso;
559//      //! pointer to Array of BMs
560//      Array<ARX*> BMsp;
561//      //!parameter evolution
562//      shared_ptr<fnc> g;
563//      //!observation function
564//      shared_ptr<fnc> h;
565//     
566//      public:
567//              void bayes(const vec &yt, const vec &cond );
568//              void from_setting(const Setting &set) ;
569//              void validate() {
570//                      bdm_assert(g->dimensionc()==g->dimension(),"not supported yet");
571//                      bdm_assert(h->dimensionc()==g->dimension(),"not supported yet");                       
572//              }
573//
574//              double logpred(const vec &cond) const NOT_IMPLEMENTED(0.0);
575//              epdf* epredictor() const NOT_IMPLEMENTED(NULL);
576//              pdf* predictor() const NOT_IMPLEMENTED(NULL);
577//              const epdf& posterior() const {return pf->posterior();};
578//             
579//              void log_register( logger &L, const string &prefix ){
580//                      BM::log_register(L,prefix);
581//                      logrec->ids.set_size ( 3 );
582//                      logrec->ids(1)= L.add_vector(RV("Q",dimension()*dimension()), prefix+L.prefix_sep()+"Q");
583//                      logrec->ids(2)= L.add_vector(RV("R",dimensiony()*dimensiony()), prefix+L.prefix_sep()+"R");
584//                     
585//              };
586//              void log_write() const {
587//                      BM::log_write();
588//                      mat mQ=zeros(dimension(),dimension());
589//                      mat pom=zeros(dimension(),dimension());
590//                      mat mR=zeros(dimensiony(),dimensiony());
591//                      mat pom2=zeros(dimensiony(),dimensiony());
592//                      mat dum;
593//                      const vec w=pf->posterior()._w();
594//                      for (int i=0; i<w.length(); i++){
595//                              BMsp(i)->posterior().mean_mat(dum,pom);
596//                              mQ += w(i) * pom;
597//                              BMso(i)->posterior().mean_mat(dum,pom2);
598//                              mR += w(i) * pom2;
599//                             
600//                      }
601//                      logrec->L.log_vector ( logrec->ids ( 1 ), cvectorize(mQ) );
602//                      logrec->L.log_vector ( logrec->ids ( 2 ), cvectorize(mR) );
603//                     
604//              }
605// };
606// UIREGISTER(MPF_ARXg);
607
608
609}
610#endif // KF_H
611
Note: See TracBrowser for help on using the browser.