root/library/bdm/estim/particles.h @ 964

Revision 964, 18.6 kB (checked in by smidl, 14 years ago)

Corrections in ARX and PF

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../estim/arx_ext.h"
18#include "../stat/emix.h"
19
20namespace bdm {
21       
22//! class used in PF
23class MarginalizedParticle : public BM{
24        protected:
25        //! discrte particle
26        dirac est_emp;
27        //! internal Bayes Model
28        shared_ptr<BM> bm; 
29        //! pdf with for transitional par
30        shared_ptr<pdf> par; // pdf for non-linear part
31        //! link from this to bm
32        shared_ptr<datalink_part> cond2bm;
33        //! link from cond to par
34        shared_ptr<datalink_part> cond2par;
35        //! link from emp 2 par
36        shared_ptr<datalink_part> emp2bm;
37        //! link from emp 2 par
38        shared_ptr<datalink_part> emp2par;
39       
40        //! custom posterior - product
41        class eprod_2:public eprod_base {
42                protected:
43                MarginalizedParticle &mp;
44                public:
45                eprod_2(MarginalizedParticle &m):mp(m){}
46                const epdf* factor(int i) const {return (i==0) ? &mp.bm->posterior() : &mp.est_emp;}
47                const int no_factors() const {return 2;}
48        } est;
49       
50        public:
51                MarginalizedParticle():est(*this){};
52                MarginalizedParticle(const MarginalizedParticle &m2):est(*this){
53                        bm = m2.bm->_copy();
54                        est_emp = m2.est_emp;
55                        par = m2.par;
56                        validate();
57                };
58                BM* _copy() const{return new MarginalizedParticle(*this);};
59                void bayes(const vec &dt, const vec &cond){
60                        vec par_cond(par->dimensionc());
61                        cond2par->filldown(cond,par_cond); // copy ut
62                        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1
63                       
64                        //sample new particle
65                        est_emp.set_point(par->samplecond(par_cond));
66                        //if (evalll)
67                        vec bm_cond(bm->dimensionc());
68                        cond2bm->filldown(cond, bm_cond);// set e.g. ut
69                        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
70                        bm->bayes(dt, bm_cond);
71                        ll=bm->_ll();
72                }
73                const eprod_2& posterior() const {return est;}
74       
75                void set_prior(const epdf *pdf0){
76                        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
77                        if (ep){ // full prior
78                                bdm_assert(ep->no_factors()==2,"Incompatible prod");
79                                bm->set_prior(ep->factor(0));
80                                est_emp.set_point(ep->factor(1)->sample());
81                        } else {
82                                // assume prior is only for emp;
83                                est_emp.set_point(pdf0->sample());
84                        }
85                }
86
87                /*! parse structure
88                \code
89                class = "BootstrapParticle";
90                parameter_pdf = {class = 'epdf_offspring', ...};
91                observation_pdf = {class = 'epdf_offspring',...};
92                \endcode
93                If rvs are set, then it checks for compatibility.
94                */
95                void from_setting(const Setting &set){
96                        BM::from_setting ( set );
97                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
98                        bm = UI::build<BM> ( set, "bm", UI::compulsory );
99                }
100               
101                void to_setting(const Setting &set){
102                        if (BM::log_level[logfull]){
103                        }
104                }
105                void validate(){
106                        if (est_emp.point.length()!=par->dimension())
107                                est_emp.set_point(zeros(par->dimension()));
108                        est.validate();
109                       
110                        yrv = bm->_yrv();
111                        dimy = bm->dimensiony();
112                        set_rv( concat(bm->_rv(), par->_rv()));
113                        set_dim( par->dimension()+bm->dimension());
114
115                        rvc = par->_rvc();
116                        rvc.add(bm->_rvc());
117                        rvc=rvc.subt(par->_rv());
118                        rvc=rvc.subt(par->_rv().copy_t(-1));
119                        rvc=rvc.subt(bm->_rv().copy_t(-1)); //
120                       
121                        cond2bm=new datalink_part;
122                        cond2par=new datalink_part;
123                        emp2bm  =new datalink_part;
124                        emp2par =new datalink_part;
125                        cond2bm->set_connection(bm->_rvc(), rvc);
126                        cond2par->set_connection(par->_rvc(), rvc);
127                        emp2bm->set_connection(bm->_rvc(), par->_rv());
128                        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));
129                       
130                        dimc = rvc._dsize();
131                };
132};
133UIREGISTER(MarginalizedParticle);
134
135//! class used in PF
136class BootstrapParticle : public BM{
137        dirac est;
138        shared_ptr<pdf> par;
139        shared_ptr<pdf> obs;
140        shared_ptr<datalink_part> cond2par;
141        shared_ptr<datalink_part> cond2obs;
142        shared_ptr<datalink_part> xt2obs;
143        shared_ptr<datalink_part> xtm2par;
144        public:
145                BM* _copy() const{return new BootstrapParticle(*this);};
146                void bayes(const vec &dt, const vec &cond){
147                        vec par_cond(par->dimensionc());
148                        cond2par->filldown(cond,par_cond); // copy ut
149                        xtm2par->filldown(est._point(),par_cond); // copy xt-1
150                       
151                        //sample new particle
152                        est.set_point(par->samplecond(par_cond));
153                        //if (evalll)
154                        vec obs_cond(obs->dimensionc());
155                        cond2obs->filldown(cond, obs_cond);// set e.g. ut
156                        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
157                        ll=obs->evallogcond(dt,obs_cond);
158                }
159                const dirac& posterior() const {return est;}
160               
161                void set_prior(const epdf *pdf0){est.set_point(pdf0->sample());}
162               
163                /*! parse structure
164                \code
165                class = "BootstrapParticle";
166                parameter_pdf = {class = 'epdf_offspring', ...};
167                observation_pdf = {class = 'epdf_offspring',...};
168                \endcode
169                If rvs are set, then it checks for compatibility.
170                */
171                void from_setting(const Setting &set){
172                        BM::from_setting ( set );
173                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
174                        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
175                }
176                void validate(){
177                        yrv = obs->_rv();
178                        dimy = obs->dimension();
179                        set_rv( par->_rv());
180                        set_dim( par->dimension());
181                       
182                        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
183                        rvc.add(obs->_rvc()); //
184                       
185                        cond2obs=new datalink_part;
186                        cond2par=new datalink_part;
187                        xt2obs  =new datalink_part;
188                        xtm2par =new datalink_part;
189                        cond2obs->set_connection(obs->_rvc(), rvc);
190                        cond2par->set_connection(par->_rvc(), rvc);
191                        xt2obs->set_connection(obs->_rvc(), _rv());
192                        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));
193                       
194                        dimc = rvc._dsize();
195                };
196};
197UIREGISTER(BootstrapParticle);
198
199
200/*!
201* \brief Trivial particle filter with proposal density equal to parameter evolution model.
202
203Posterior density is represented by a weighted empirical density (\c eEmp ).
204*/
205
206class PF : public BM {
207        //! \var log_level_enums weights
208        //! all weightes will be logged
209
210        //! \var log_level_enums menas
211        //! means of particles will be logged
212        LOG_LEVEL(PF,logweights,logmeans,logvars);
213       
214        class pf_mix: public emix_base{
215                Array<BM*> &bms;
216                public:
217                        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0){}
218                        const epdf* component(const int &i)const{return &(bms(i)->posterior());}
219                        int no_coms() const {return bms.length();}
220        };
221protected:
222        //!number of particles;
223        int n;
224        //!posterior density
225        pf_mix est;
226        //! weights;
227        vec w;
228        //! particles
229        Array<BM*> particles;
230        //! internal structure storing loglikelihood of predictions
231        vec lls;
232
233        //! which resampling method will be used
234        RESAMPLING_METHOD resmethod;
235        //! resampling threshold; in this case its meaning is minimum ratio of active particles
236        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
237        double res_threshold;
238
239        //! \name Options
240        //!@{
241        //!@}
242
243public:
244        //! \name Constructors
245        //!@{
246        PF ( ) : est(w,particles) { };
247
248        void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
249                n = n0;
250                res_threshold = res_th0;
251                resmethod = rm;
252        };
253        void set_model ( const BM *particle0, const epdf *prior) {
254                if (n>0){
255                        particles.set_length(n);
256                        for (int i=0; i<n;i++){
257                                particles(i) = particle0->_copy();
258                                particles(i)->set_prior(prior);
259                        }
260                }
261                // set values for posterior
262                est.set_rv ( particle0->posterior()._rv() );
263        };
264        void set_statistics ( const vec w0, const epdf &epdf0 ) {
265                //est.set_statistics ( w0, epdf0 );
266        };
267/*      void set_statistics ( const eEmp &epdf0 ) {
268                bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
269                est = epdf0;
270        };*/
271        //!@}
272
273        //! bayes compute weights of the
274        virtual void bayes_weights();
275        //! important part of particle filtering - decide if it is time to perform resampling
276        virtual bool do_resampling() {
277                double eff = 1.0 / ( w * w );
278                return eff < ( res_threshold*n );
279        }
280        void bayes ( const vec &yt, const vec &cond );
281        //!access function
282        vec& _lls() {
283                return lls;
284        }
285        //!access function
286        RESAMPLING_METHOD _resmethod() const {
287                return resmethod;
288        }
289        //! return correctly typed posterior (covariant return)
290        const pf_mix& posterior() const {
291                return est;
292        }
293
294        /*! configuration structure for basic PF
295        \code
296        parameter_pdf   = pdf_class;         // parameter evolution pdf
297        observation_pdf = pdf_class;         // observation pdf
298        prior           = epdf_class;         // prior probability density
299        --- optional ---
300        n               = 10;                 // number of particles
301        resmethod       = 'systematic', or 'multinomial', or 'stratified'
302                                                                                  // resampling method
303        res_threshold   = 0.5;                // resample when active particles drop below 50%
304        \endcode
305        */
306        void from_setting ( const Setting &set ) {
307                BM::from_setting ( set );
308                UI::get ( log_level, set, "log_level", UI::optional );
309               
310                shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);
311               
312                shared_ptr<epdf> pri = UI::build<epdf> ( set, "prior", UI::compulsory );
313                n =0;
314                UI::get(n,set,"n",UI::optional);;
315                if (n>0){
316                        particles.set_length(n);
317                        for(int i=0;i<n;i++){particles(i)=bm0->_copy();}
318                        w = ones(n)/n;
319                }
320                set_prior(pri.get());
321                // set resampling method
322                resmethod_from_set ( set );
323                //set drv
324
325                rvc = bm0->_rvc();
326                dimc = bm0->dimensionc();
327                BM::set_rv(bm0->_rv());
328                yrv=bm0->_yrv();
329                dimy = bm0->dimensiony();
330        }
331       
332        void log_register ( bdm::logger& L, const string& prefix ){
333                BM::log_register(L,prefix);
334                if (log_level[logweights]){
335                        L.add_vector( log_level, logweights, RV ( particles.length()), prefix); 
336                }
337                if (log_level[logmeans]){
338                        for (int i=0; i<particles.length(); i++){
339                                L.add_vector( log_level, logmeans, RV ( particles(i)->dimension() ), prefix , i);
340                        }
341                }
342                if (log_level[logvars]){
343                        for (int i=0; i<particles.length(); i++){
344                                L.add_vector( log_level, logvars, RV ( particles(i)->dimension() ), prefix , i);
345                        }
346                }
347        };
348        void log_write ( ) const {
349                BM::log_write();
350                if (log_level[logweights]){
351                        log_level.store( logweights, w); 
352                }
353                if (log_level[logmeans]){
354                        for (int i=0; i<particles.length(); i++){
355                                log_level.store( logmeans, particles(i)->posterior().mean(), i);
356                        }
357                }
358                if (log_level[logvars]){
359                        for (int i=0; i<particles.length(); i++){
360                                log_level.store( logvars, particles(i)->posterior().variance(), i);
361                        }
362                }
363               
364        }
365       
366        void set_prior(const epdf *pri){
367                const emix_base *emi=dynamic_cast<const emix_base*>(pri);
368                if (emi) {
369                        bdm_assert(particles.length()>0, "initial particle is not assigned");
370                        n = emi->_w().length();
371                        int old_n = particles.length();
372                        if (n!=old_n){
373                                particles.set_length(n,true);
374                        } 
375                        for(int i=old_n;i<n;i++){particles(i)=particles(0)->_copy();}
376                       
377                        for (int i =0; i<n; i++){
378                                particles(i)->set_prior(emi->_com(i));
379                        }
380                } else {
381                        // try to find "n"
382                        bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
383                        for (int i =0; i<n; i++){
384                                particles(i)->set_prior(pri);
385                        }
386                       
387                }
388        }
389        //! auxiliary function reading parameter 'resmethod' from configuration file
390        void resmethod_from_set ( const Setting &set ) {
391                string resmeth;
392                if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
393                        if ( resmeth == "systematic" ) {
394                                resmethod = SYSTEMATIC;
395                        } else  {
396                                if ( resmeth == "multinomial" ) {
397                                        resmethod = MULTINOMIAL;
398                                } else {
399                                        if ( resmeth == "stratified" ) {
400                                                resmethod = STRATIFIED;
401                                        } else {
402                                                bdm_error ( "Unknown resampling method" );
403                                        }
404                                }
405                        }
406                } else {
407                        resmethod = SYSTEMATIC;
408                };
409                if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
410                        res_threshold = 0.9;
411                }
412                //validate();
413        }
414
415        void validate() {
416                BM::validate();
417                est.validate();
418                bdm_assert ( n>0, "empty particle pool" );
419                n = w.length();
420                lls = zeros ( n );
421
422                if ( particles(0)->_rv()._dsize() > 0 ) {
423                        bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "Mismatch of RV and dimension of posterior" );
424                }
425        }
426        //! resample posterior density (from outside - see MPF)
427        void resample ( ) {
428                ivec ind = zeros_i ( n );
429                bdm::resample(w,ind,resmethod);
430                // copy the internals according to ind
431                for (int i = 0; i < n; i++ ) {
432                        if ( ind ( i ) != i ) {
433                                particles( i ) = particles( ind ( i ) )->_copy();
434                        }
435                        w ( i ) = 1.0 / n;
436                }
437        }
438        //! access function
439        Array<BM*>& _particles() {
440                return particles;
441        }
442
443};
444UIREGISTER ( PF );
445
446/*!
447\brief Marginalized Particle filter
448
449A composition of particle filter with exact (or almost exact) bayesian models (BMs).
450The Bayesian models provide marginalized predictive density. Internaly this is achieved by virtual class MPFpdf.
451*/
452
453// class MPF : public BM  {
454//      //! Introduces new option
455//      //! \li means - meaning TODO
456//      LOG_LEVEL(MPF,means);
457// protected:
458//      //! particle filter on non-linear variable
459//      shared_ptr<PF> pf;
460//      //! Array of Bayesian models
461//      Array<BM*> BMs;
462//
463//      //! internal class for pdf providing composition of eEmp with external components
464//
465//      class mpfepdf : public epdf  {
466//              //! pointer to particle filter
467//              shared_ptr<PF> &pf;
468//              //! pointer to Array of BMs
469//              Array<BM*> &BMs;
470//      public:
471//              //! constructor
472//              mpfepdf ( shared_ptr<PF> &pf0, Array<BM*> &BMs0 ) : epdf(), pf ( pf0 ), BMs ( BMs0 ) { };
473//              //! a variant of set parameters - this time, parameters are read from BMs and pf
474//              void read_parameters() {
475//                      rv = concat ( pf->posterior()._rv(), BMs ( 0 )->posterior()._rv() );
476//                      dim = pf->posterior().dimension() + BMs ( 0 )->posterior().dimension();
477//                      bdm_assert_debug ( dim == rv._dsize(), "Wrong name " );
478//              }
479//              vec mean() const;
480//
481//              vec variance() const;
482//
483//              void qbounds ( vec &lb, vec &ub, double perc = 0.95 ) const;
484//
485//              vec sample() const NOT_IMPLEMENTED(0);
486//
487//              double evallog ( const vec &val ) const NOT_IMPLEMENTED(0);             
488//      };
489//
490//      //! Density joining PF.est with conditional parts
491//      mpfepdf jest;
492//
493//      //! datalink from global yt and cond (Up) to BMs yt and cond (Down)
494//      datalink_m2m this2bm;
495//      //! datalink from global yt and cond (Up) to PFs yt and cond (Down)
496//      datalink_m2m this2pf;
497//      //!datalink from PF part to BM
498//      datalink_part pf2bm;
499//
500// public:
501//      //! Default constructor.
502//      MPF () :  jest ( pf, BMs ) {};
503//      //! set all parameters at once
504//      void set_pf ( shared_ptr<pdf> par0, int n0, RESAMPLING_METHOD rm = SYSTEMATIC ) {
505//              if (!pf) pf=new PF;
506//              pf->set_model ( par0, par0 ); // <=== nasty!!!
507//              pf->set_parameters ( n0, rm );
508//              pf->set_rv(par0->_rv());
509//              BMs.set_length ( n0 );
510//      }
511//      //! set a prototype of BM, copy it to as many times as there is particles in pf
512//      void set_BM ( const BM &BMcond0 ) {
513//
514//              int n = pf->__w().length();
515//              BMs.set_length ( n );
516//              // copy
517//              //BMcond0 .condition ( pf->posterior()._sample ( 0 ) );
518//              for ( int i = 0; i < n; i++ ) {
519//                      BMs ( i ) = (BM*) BMcond0._copy();
520//              }
521//      };
522//
523//      void bayes ( const vec &yt, const vec &cond );
524//
525//      const epdf& posterior() const {
526//              return jest;
527//      }
528//
529//      //!Access function
530//      const BM* _BM ( int i ) {
531//              return BMs ( i );
532//      }
533//      PF& _pf() {return *pf;}
534//
535//
536//      virtual double logpred ( const vec &yt ) const NOT_IMPLEMENTED(0);
537//             
538//      virtual epdf* epredictor() const NOT_IMPLEMENTED(NULL);
539//     
540//      virtual pdf* predictor() const NOT_IMPLEMENTED(NULL);
541//
542//
543//      /*! configuration structure for basic PF
544//      \code
545//      BM              = BM_class;           // Bayesian filtr for analytical part of the model
546//      parameter_pdf   = pdf_class;         // transitional pdf for non-parametric part of the model
547//      prior           = epdf_class;         // prior probability density
548//      --- optional ---
549//      n               = 10;                 // number of particles
550//      resmethod       = 'systematic', or 'multinomial', or 'stratified'
551//                                                                                // resampling method
552//      \endcode
553//      */
554//      void from_setting ( const Setting &set ) {
555//              BM::from_setting( set );
556//
557//              shared_ptr<pdf> par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
558//
559//              pf = new PF;
560//              // prior must be set before BM
561//              pf->prior_from_set ( set );
562//              pf->resmethod_from_set ( set );
563//              pf->set_model ( par, par ); // too hackish!
564//
565//              shared_ptr<BM> BM0 = UI::build<BM> ( set, "BM", UI::compulsory );
566//              set_BM ( *BM0 );
567//
568//              //set drv
569//              //??set_yrv(concat(BM0->_yrv(),u) );
570//              set_yrv ( BM0->_yrv() );
571//              rvc = BM0->_rvc().subt ( par->_rv() );
572//              //find potential input - what remains in rvc when we subtract rv
573//              RV u = par->_rvc().subt ( par->_rv().copy_t ( -1 ) );
574//              rvc.add ( u );
575//              dimc = rvc._dsize();
576//              validate();
577//      }
578//
579//      void validate() {
580//              BM::validate();
581//              try {
582//                      pf->validate();
583//              } catch ( std::exception ) {
584//                      throw UIException ( "Error in PF part of MPF:" );
585//              }
586//              jest.read_parameters();
587//              this2bm.set_connection ( BMs ( 0 )->_yrv(), BMs ( 0 )->_rvc(), yrv, rvc );
588//              this2pf.set_connection ( pf->_yrv(), pf->_rvc(), yrv, rvc );
589//              pf2bm.set_connection ( BMs ( 0 )->_rvc(), pf->posterior()._rv() );
590//      }
591// };
592// UIREGISTER ( MPF );
593
594/*! ARXg for estimation of state-space variances
595*/
596// class MPF_ARXg :public BM{
597//      protected:
598//      shared_ptr<PF> pf;
599//      //! pointer to Array of BMs
600//      Array<ARX*> BMso;
601//      //! pointer to Array of BMs
602//      Array<ARX*> BMsp;
603//      //!parameter evolution
604//      shared_ptr<fnc> g;
605//      //!observation function
606//      shared_ptr<fnc> h;
607//     
608//      public:
609//              void bayes(const vec &yt, const vec &cond );
610//              void from_setting(const Setting &set) ;
611//              void validate() {
612//                      bdm_assert(g->dimensionc()==g->dimension(),"not supported yet");
613//                      bdm_assert(h->dimensionc()==g->dimension(),"not supported yet");                       
614//              }
615//
616//              double logpred(const vec &cond) const NOT_IMPLEMENTED(0.0);
617//              epdf* epredictor() const NOT_IMPLEMENTED(NULL);
618//              pdf* predictor() const NOT_IMPLEMENTED(NULL);
619//              const epdf& posterior() const {return pf->posterior();};
620//             
621//              void log_register( logger &L, const string &prefix ){
622//                      BM::log_register(L,prefix);
623//                      registered_logger->ids.set_size ( 3 );
624//                      registered_logger->ids(1)= L.add_vector(RV("Q",dimension()*dimension()), prefix+L.prefix_sep()+"Q");
625//                      registered_logger->ids(2)= L.add_vector(RV("R",dimensiony()*dimensiony()), prefix+L.prefix_sep()+"R");
626//                     
627//              };
628//              void log_write() const {
629//                      BM::log_write();
630//                      mat mQ=zeros(dimension(),dimension());
631//                      mat pom=zeros(dimension(),dimension());
632//                      mat mR=zeros(dimensiony(),dimensiony());
633//                      mat pom2=zeros(dimensiony(),dimensiony());
634//                      mat dum;
635//                      const vec w=pf->posterior()._w();
636//                      for (int i=0; i<w.length(); i++){
637//                              BMsp(i)->posterior().mean_mat(dum,pom);
638//                              mQ += w(i) * pom;
639//                              BMso(i)->posterior().mean_mat(dum,pom2);
640//                              mR += w(i) * pom2;
641//                             
642//                      }
643//                      registered_logger->L.log_vector ( registered_logger->ids ( 1 ), cvectorize(mQ) );
644//                      registered_logger->L.log_vector ( registered_logger->ids ( 2 ), cvectorize(mR) );
645//                     
646//              }
647// };
648// UIREGISTER(MPF_ARXg);
649
650
651}
652#endif // KF_H
653
Note: See TracBrowser for help on using the browser.