root/library/bdm/estim/particles.h @ 950

Revision 950, 18.2 kB (checked in by smidl, 14 years ago)

Move logfull out of epdf.

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../estim/arx_ext.h"
18#include "../stat/emix.h"
19
20namespace bdm {
21       
22//! class used in PF
23class MarginalizedParticle : public BM{
24        protected:
25        //! discrte particle
26        dirac est_emp;
27        //! internal Bayes Model
28        shared_ptr<BM> bm; 
29        //! pdf with for transitional par
30        shared_ptr<pdf> par; // pdf for non-linear part
31        //! link from this to bm
32        shared_ptr<datalink_part> cond2bm;
33        //! link from cond to par
34        shared_ptr<datalink_part> cond2par;
35        //! link from emp 2 par
36        shared_ptr<datalink_part> emp2bm;
37        //! link from emp 2 par
38        shared_ptr<datalink_part> emp2par;
39       
40        //! custom posterior - product
41        class eprod_2:public eprod_base {
42                protected:
43                MarginalizedParticle &mp;
44                public:
45                eprod_2(MarginalizedParticle &m):mp(m){}
46                const epdf* factor(int i) const {return (i==0) ? &mp.bm->posterior() : &mp.est_emp;}
47                const int no_factors() const {return 2;}
48        } est;
49       
50        public:
51                MarginalizedParticle():est(*this){};
52                MarginalizedParticle(const MarginalizedParticle &m2):est(*this){
53                        bm = m2.bm->_copy();
54                        est_emp = m2.est_emp;
55                        par = m2.par;
56                        validate();
57                };
58                BM* _copy() const{return new MarginalizedParticle(*this);};
59                void bayes(const vec &dt, const vec &cond){
60                        vec par_cond(par->dimensionc());
61                        cond2par->filldown(cond,par_cond); // copy ut
62                        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1
63                       
64                        //sample new particle
65                        est_emp.set_point(par->samplecond(par_cond));
66                        //if (evalll)
67                        vec bm_cond(bm->dimensionc());
68                        cond2bm->filldown(cond, bm_cond);// set e.g. ut
69                        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
70                        bm->bayes(dt, bm_cond);
71                        ll=bm->_ll();
72                }
73                const eprod_2& posterior() const {return est;}
74       
75                void set_prior(const epdf *pdf0){
76                        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
77                        if (ep){ // full prior
78                                bdm_assert(ep->no_factors()==2,"Incompatible prod");
79                                bm->set_prior(ep->factor(0));
80                                est_emp.set_point(ep->factor(1)->sample());
81                        } else {
82                                // assume prior is only for emp;
83                                est_emp.set_point(pdf0->sample());
84                        }
85                }
86
87                /*! parse structure
88                \code
89                class = "BootstrapParticle";
90                parameter_pdf = {class = 'epdf_offspring', ...};
91                observation_pdf = {class = 'epdf_offspring',...};
92                \endcode
93                If rvs are set, then it checks for compatibility.
94                */
95                void from_setting(const Setting &set){
96                        BM::from_setting ( set );
97                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
98                        bm = UI::build<BM> ( set, "bm", UI::compulsory );
99                }
100               
101                void to_setting(const Setting &set){
102                        if (BM::log_level[logfull]){
103                        }
104                }
105                void validate(){
106                        est_emp.set_point(zeros(par->dimension()));
107                        est.validate();
108                       
109                        yrv = bm->_yrv();
110                        dimy = bm->dimensiony();
111                        set_rv( concat(bm->_rv(), par->_rv()));
112                        set_dim( par->dimension()+bm->dimension());
113
114                        rvc = par->_rvc();
115                        rvc.add(bm->_rvc());
116                        rvc=rvc.subt(par->_rv());
117                        rvc=rvc.subt(par->_rv().copy_t(-1));
118                        rvc=rvc.subt(bm->_rv().copy_t(-1)); //
119                       
120                        cond2bm=new datalink_part;
121                        cond2par=new datalink_part;
122                        emp2bm  =new datalink_part;
123                        emp2par =new datalink_part;
124                        cond2bm->set_connection(bm->_rvc(), rvc);
125                        cond2par->set_connection(par->_rvc(), rvc);
126                        emp2bm->set_connection(bm->_rvc(), par->_rv());
127                        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));
128                       
129                        dimc = rvc._dsize();
130                };
131};
132UIREGISTER(MarginalizedParticle);
133
134//! class used in PF
135class BootstrapParticle : public BM{
136        dirac est;
137        shared_ptr<pdf> par;
138        shared_ptr<pdf> obs;
139        shared_ptr<datalink_part> cond2par;
140        shared_ptr<datalink_part> cond2obs;
141        shared_ptr<datalink_part> xt2obs;
142        shared_ptr<datalink_part> xtm2par;
143        public:
144                BM* _copy() const{return new BootstrapParticle(*this);};
145                void bayes(const vec &dt, const vec &cond){
146                        vec par_cond(par->dimensionc());
147                        cond2par->filldown(cond,par_cond); // copy ut
148                        xtm2par->filldown(est._point(),par_cond); // copy xt-1
149                       
150                        //sample new particle
151                        est.set_point(par->samplecond(par_cond));
152                        //if (evalll)
153                        vec obs_cond(obs->dimensionc());
154                        cond2obs->filldown(cond, obs_cond);// set e.g. ut
155                        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
156                        ll=obs->evallogcond(dt,obs_cond);
157                }
158                const dirac& posterior() const {return est;}
159               
160                void set_prior(const epdf *pdf0){est.set_point(pdf0->sample());}
161               
162                /*! parse structure
163                \code
164                class = "BootstrapParticle";
165                parameter_pdf = {class = 'epdf_offspring', ...};
166                observation_pdf = {class = 'epdf_offspring',...};
167                \endcode
168                If rvs are set, then it checks for compatibility.
169                */
170                void from_setting(const Setting &set){
171                        BM::from_setting ( set );
172                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
173                        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
174                }
175                void validate(){
176                        yrv = obs->_rv();
177                        dimy = obs->dimension();
178                        set_rv( par->_rv());
179                        set_dim( par->dimension());
180                       
181                        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
182                        rvc.add(obs->_rvc()); //
183                       
184                        cond2obs=new datalink_part;
185                        cond2par=new datalink_part;
186                        xt2obs  =new datalink_part;
187                        xtm2par =new datalink_part;
188                        cond2obs->set_connection(obs->_rvc(), rvc);
189                        cond2par->set_connection(par->_rvc(), rvc);
190                        xt2obs->set_connection(obs->_rvc(), _rv());
191                        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));
192                       
193                        dimc = rvc._dsize();
194                };
195};
196UIREGISTER(BootstrapParticle);
197
198
199/*!
200* \brief Trivial particle filter with proposal density equal to parameter evolution model.
201
202Posterior density is represented by a weighted empirical density (\c eEmp ).
203*/
204
205class PF : public BM {
206        //! \var log_level_enums weights
207        //! all weightes will be logged
208
209        //! \var log_level_enums menas
210        //! means of particles will be logged
211        LOG_LEVEL(PF,logweights,logmeans,logvars);
212       
213        class pf_mix: public emix_base{
214                Array<BM*> &bms;
215                public:
216                        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0){}
217                        const epdf* component(const int &i)const{return &(bms(i)->posterior());}
218                        int no_coms() const {return bms.length();}
219        };
220protected:
221        //!number of particles;
222        int n;
223        //!posterior density
224        pf_mix est;
225        //! weights;
226        vec w;
227        //! particles
228        Array<BM*> particles;
229        //! internal structure storing loglikelihood of predictions
230        vec lls;
231
232        //! which resampling method will be used
233        RESAMPLING_METHOD resmethod;
234        //! resampling threshold; in this case its meaning is minimum ratio of active particles
235        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
236        double res_threshold;
237
238        //! \name Options
239        //!@{
240        //!@}
241
242public:
243        //! \name Constructors
244        //!@{
245        PF ( ) : est(w,particles) { };
246
247        void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
248                n = n0;
249                res_threshold = res_th0;
250                resmethod = rm;
251        };
252        void set_model ( const BM *particle0, const epdf *prior) {
253                if (n>0){
254                        particles.set_length(n);
255                        for (int i=0; i<n;i++){
256                                particles(i) = particle0->_copy();
257                                particles(i)->set_prior(prior);
258                        }
259                }
260                // set values for posterior
261                est.set_rv ( particle0->posterior()._rv() );
262        };
263        void set_statistics ( const vec w0, const epdf &epdf0 ) {
264                //est.set_statistics ( w0, epdf0 );
265        };
266/*      void set_statistics ( const eEmp &epdf0 ) {
267                bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
268                est = epdf0;
269        };*/
270        //!@}
271
272        //! bayes compute weights of the
273        virtual void bayes_weights();
274        //! important part of particle filtering - decide if it is time to perform resampling
275        virtual bool do_resampling() {
276                double eff = 1.0 / ( w * w );
277                return eff < ( res_threshold*n );
278        }
279        void bayes ( const vec &yt, const vec &cond );
280        //!access function
281        vec& _lls() {
282                return lls;
283        }
284        //!access function
285        RESAMPLING_METHOD _resmethod() const {
286                return resmethod;
287        }
288        //! return correctly typed posterior (covariant return)
289        const pf_mix& posterior() const {
290                return est;
291        }
292
293        /*! configuration structure for basic PF
294        \code
295        parameter_pdf   = pdf_class;         // parameter evolution pdf
296        observation_pdf = pdf_class;         // observation pdf
297        prior           = epdf_class;         // prior probability density
298        --- optional ---
299        n               = 10;                 // number of particles
300        resmethod       = 'systematic', or 'multinomial', or 'stratified'
301                                                                                  // resampling method
302        res_threshold   = 0.5;                // resample when active particles drop below 50%
303        \endcode
304        */
305        void from_setting ( const Setting &set ) {
306                BM::from_setting ( set );
307                UI::get ( log_level, set, "log_level", UI::optional );
308               
309                shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);
310               
311                shared_ptr<epdf> pri = UI::build<epdf> ( set, "prior", UI::compulsory );
312                n =0;
313                UI::get(n,set,"n",UI::optional);;
314                if (n>0){
315                        particles.set_length(n);
316                        for(int i=0;i<n;i++){particles(i)=bm0->_copy();}
317                        w = ones(n)/n;
318                }
319                set_prior(pri.get());
320                // set resampling method
321                resmethod_from_set ( set );
322                //set drv
323
324                set_yrv ( bm0->_rv() );
325                rvc = bm0->_rvc();
326                BM::set_rv(bm0->_rv());
327                yrv=bm0->_yrv();
328        }
329       
330        void log_register ( bdm::logger& L, const string& prefix ){
331                BM::log_register(L,prefix);
332                if (log_level[logweights]){
333                        L.add_vector( log_level, logweights, RV ( particles.length()), prefix); 
334                }
335                if (log_level[logmeans]){
336                        for (int i=0; i<particles.length(); i++){
337                                L.add_vector( log_level, logmeans, RV ( particles(i)->dimension() ), prefix , i);
338                        }
339                }
340        };
341        void log_write ( ) const {
342                BM::log_write();
343                if (log_level[logweights]){
344                        log_level.store( logweights, w); 
345                }
346                if (log_level[logmeans]){
347                        for (int i=0; i<particles.length(); i++){
348                                 log_level.store( logmeans, particles(i)->posterior().mean(), i);
349                        }
350                }
351               
352        }
353       
354        void set_prior(const epdf *pri){
355                const emix_base *emi=dynamic_cast<const emix_base*>(pri);
356                if (emi) {
357                        bdm_assert(particles.length()>0, "initial particle is not assigned");
358                        n = emi->_w().length();
359                        int old_n = particles.length();
360                        if (n!=old_n){
361                                particles.set_length(n,true);
362                        } 
363                        for(int i=old_n;i<n;i++){particles(i)=particles(0)->_copy();}
364                       
365                        for (int i =0; i<n; i++){
366                                particles(i)->set_prior(emi->_com(i));
367                        }
368                } else {
369                        // try to find "n"
370                        bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
371                        for (int i =0; i<n; i++){
372                                particles(i)->set_prior(pri);
373                        }
374                       
375                }
376        }
377        //! auxiliary function reading parameter 'resmethod' from configuration file
378        void resmethod_from_set ( const Setting &set ) {
379                string resmeth;
380                if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
381                        if ( resmeth == "systematic" ) {
382                                resmethod = SYSTEMATIC;
383                        } else  {
384                                if ( resmeth == "multinomial" ) {
385                                        resmethod = MULTINOMIAL;
386                                } else {
387                                        if ( resmeth == "stratified" ) {
388                                                resmethod = STRATIFIED;
389                                        } else {
390                                                bdm_error ( "Unknown resampling method" );
391                                        }
392                                }
393                        }
394                } else {
395                        resmethod = SYSTEMATIC;
396                };
397                if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
398                        res_threshold = 0.9;
399                }
400                //validate();
401        }
402
403        void validate() {
404                BM::validate();
405                est.validate();
406                bdm_assert ( n>0, "empty particle pool" );
407                n = w.length();
408                lls = zeros ( n );
409
410                if ( particles(0)->_rv()._dsize() > 0 ) {
411                        bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "Mismatch of RV and dimension of posterior" );
412                }
413        }
414        //! resample posterior density (from outside - see MPF)
415        void resample ( ) {
416                ivec ind = zeros_i ( n );
417                bdm::resample(w,ind,resmethod);
418                // copy the internals according to ind
419                for (int i = 0; i < n; i++ ) {
420                        if ( ind ( i ) != i ) {
421                                particles( i ) = particles( ind ( i ) )->_copy();
422                        }
423                        w ( i ) = 1.0 / n;
424                }
425        }
426        //! access function
427        Array<BM*>& _particles() {
428                return particles;
429        }
430
431};
432UIREGISTER ( PF );
433
434/*!
435\brief Marginalized Particle filter
436
437A composition of particle filter with exact (or almost exact) bayesian models (BMs).
438The Bayesian models provide marginalized predictive density. Internaly this is achieved by virtual class MPFpdf.
439*/
440
441// class MPF : public BM  {
442//      //! Introduces new option
443//      //! \li means - meaning TODO
444//      LOG_LEVEL(MPF,means);
445// protected:
446//      //! particle filter on non-linear variable
447//      shared_ptr<PF> pf;
448//      //! Array of Bayesian models
449//      Array<BM*> BMs;
450//
451//      //! internal class for pdf providing composition of eEmp with external components
452//
453//      class mpfepdf : public epdf  {
454//              //! pointer to particle filter
455//              shared_ptr<PF> &pf;
456//              //! pointer to Array of BMs
457//              Array<BM*> &BMs;
458//      public:
459//              //! constructor
460//              mpfepdf ( shared_ptr<PF> &pf0, Array<BM*> &BMs0 ) : epdf(), pf ( pf0 ), BMs ( BMs0 ) { };
461//              //! a variant of set parameters - this time, parameters are read from BMs and pf
462//              void read_parameters() {
463//                      rv = concat ( pf->posterior()._rv(), BMs ( 0 )->posterior()._rv() );
464//                      dim = pf->posterior().dimension() + BMs ( 0 )->posterior().dimension();
465//                      bdm_assert_debug ( dim == rv._dsize(), "Wrong name " );
466//              }
467//              vec mean() const;
468//
469//              vec variance() const;
470//
471//              void qbounds ( vec &lb, vec &ub, double perc = 0.95 ) const;
472//
473//              vec sample() const NOT_IMPLEMENTED(0);
474//
475//              double evallog ( const vec &val ) const NOT_IMPLEMENTED(0);             
476//      };
477//
478//      //! Density joining PF.est with conditional parts
479//      mpfepdf jest;
480//
481//      //! datalink from global yt and cond (Up) to BMs yt and cond (Down)
482//      datalink_m2m this2bm;
483//      //! datalink from global yt and cond (Up) to PFs yt and cond (Down)
484//      datalink_m2m this2pf;
485//      //!datalink from PF part to BM
486//      datalink_part pf2bm;
487//
488// public:
489//      //! Default constructor.
490//      MPF () :  jest ( pf, BMs ) {};
491//      //! set all parameters at once
492//      void set_pf ( shared_ptr<pdf> par0, int n0, RESAMPLING_METHOD rm = SYSTEMATIC ) {
493//              if (!pf) pf=new PF;
494//              pf->set_model ( par0, par0 ); // <=== nasty!!!
495//              pf->set_parameters ( n0, rm );
496//              pf->set_rv(par0->_rv());
497//              BMs.set_length ( n0 );
498//      }
499//      //! set a prototype of BM, copy it to as many times as there is particles in pf
500//      void set_BM ( const BM &BMcond0 ) {
501//
502//              int n = pf->__w().length();
503//              BMs.set_length ( n );
504//              // copy
505//              //BMcond0 .condition ( pf->posterior()._sample ( 0 ) );
506//              for ( int i = 0; i < n; i++ ) {
507//                      BMs ( i ) = (BM*) BMcond0._copy();
508//              }
509//      };
510//
511//      void bayes ( const vec &yt, const vec &cond );
512//
513//      const epdf& posterior() const {
514//              return jest;
515//      }
516//
517//      //!Access function
518//      const BM* _BM ( int i ) {
519//              return BMs ( i );
520//      }
521//      PF& _pf() {return *pf;}
522//
523//
524//      virtual double logpred ( const vec &yt ) const NOT_IMPLEMENTED(0);
525//             
526//      virtual epdf* epredictor() const NOT_IMPLEMENTED(NULL);
527//     
528//      virtual pdf* predictor() const NOT_IMPLEMENTED(NULL);
529//
530//
531//      /*! configuration structure for basic PF
532//      \code
533//      BM              = BM_class;           // Bayesian filtr for analytical part of the model
534//      parameter_pdf   = pdf_class;         // transitional pdf for non-parametric part of the model
535//      prior           = epdf_class;         // prior probability density
536//      --- optional ---
537//      n               = 10;                 // number of particles
538//      resmethod       = 'systematic', or 'multinomial', or 'stratified'
539//                                                                                // resampling method
540//      \endcode
541//      */
542//      void from_setting ( const Setting &set ) {
543//              BM::from_setting( set );
544//
545//              shared_ptr<pdf> par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
546//
547//              pf = new PF;
548//              // prior must be set before BM
549//              pf->prior_from_set ( set );
550//              pf->resmethod_from_set ( set );
551//              pf->set_model ( par, par ); // too hackish!
552//
553//              shared_ptr<BM> BM0 = UI::build<BM> ( set, "BM", UI::compulsory );
554//              set_BM ( *BM0 );
555//
556//              //set drv
557//              //??set_yrv(concat(BM0->_yrv(),u) );
558//              set_yrv ( BM0->_yrv() );
559//              rvc = BM0->_rvc().subt ( par->_rv() );
560//              //find potential input - what remains in rvc when we subtract rv
561//              RV u = par->_rvc().subt ( par->_rv().copy_t ( -1 ) );
562//              rvc.add ( u );
563//              dimc = rvc._dsize();
564//              validate();
565//      }
566//
567//      void validate() {
568//              BM::validate();
569//              try {
570//                      pf->validate();
571//              } catch ( std::exception ) {
572//                      throw UIException ( "Error in PF part of MPF:" );
573//              }
574//              jest.read_parameters();
575//              this2bm.set_connection ( BMs ( 0 )->_yrv(), BMs ( 0 )->_rvc(), yrv, rvc );
576//              this2pf.set_connection ( pf->_yrv(), pf->_rvc(), yrv, rvc );
577//              pf2bm.set_connection ( BMs ( 0 )->_rvc(), pf->posterior()._rv() );
578//      }
579// };
580// UIREGISTER ( MPF );
581
582/*! ARXg for estimation of state-space variances
583*/
584// class MPF_ARXg :public BM{
585//      protected:
586//      shared_ptr<PF> pf;
587//      //! pointer to Array of BMs
588//      Array<ARX*> BMso;
589//      //! pointer to Array of BMs
590//      Array<ARX*> BMsp;
591//      //!parameter evolution
592//      shared_ptr<fnc> g;
593//      //!observation function
594//      shared_ptr<fnc> h;
595//     
596//      public:
597//              void bayes(const vec &yt, const vec &cond );
598//              void from_setting(const Setting &set) ;
599//              void validate() {
600//                      bdm_assert(g->dimensionc()==g->dimension(),"not supported yet");
601//                      bdm_assert(h->dimensionc()==g->dimension(),"not supported yet");                       
602//              }
603//
604//              double logpred(const vec &cond) const NOT_IMPLEMENTED(0.0);
605//              epdf* epredictor() const NOT_IMPLEMENTED(NULL);
606//              pdf* predictor() const NOT_IMPLEMENTED(NULL);
607//              const epdf& posterior() const {return pf->posterior();};
608//             
609//              void log_register( logger &L, const string &prefix ){
610//                      BM::log_register(L,prefix);
611//                      registered_logger->ids.set_size ( 3 );
612//                      registered_logger->ids(1)= L.add_vector(RV("Q",dimension()*dimension()), prefix+L.prefix_sep()+"Q");
613//                      registered_logger->ids(2)= L.add_vector(RV("R",dimensiony()*dimensiony()), prefix+L.prefix_sep()+"R");
614//                     
615//              };
616//              void log_write() const {
617//                      BM::log_write();
618//                      mat mQ=zeros(dimension(),dimension());
619//                      mat pom=zeros(dimension(),dimension());
620//                      mat mR=zeros(dimensiony(),dimensiony());
621//                      mat pom2=zeros(dimensiony(),dimensiony());
622//                      mat dum;
623//                      const vec w=pf->posterior()._w();
624//                      for (int i=0; i<w.length(); i++){
625//                              BMsp(i)->posterior().mean_mat(dum,pom);
626//                              mQ += w(i) * pom;
627//                              BMso(i)->posterior().mean_mat(dum,pom2);
628//                              mR += w(i) * pom2;
629//                             
630//                      }
631//                      registered_logger->L.log_vector ( registered_logger->ids ( 1 ), cvectorize(mQ) );
632//                      registered_logger->L.log_vector ( registered_logger->ids ( 2 ), cvectorize(mR) );
633//                     
634//              }
635// };
636// UIREGISTER(MPF_ARXg);
637
638
639}
640#endif // KF_H
641
Note: See TracBrowser for help on using the browser.