root/library/bdm/estim/particles.h @ 971

Revision 971, 14.8 kB (checked in by smidl, 14 years ago)

New MPF

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../estim/arx_ext.h"
18#include "../stat/emix.h"
19
20namespace bdm {
21       
22//! class used in PF
23class MarginalizedParticleBase : public BM{
24        protected:
25        //! discrte particle
26        dirac est_emp;
27        //! internal Bayes Model
28        shared_ptr<BM> bm; 
29       
30        //! custom posterior - product of empirical and exact part
31        class eprod_2:public eprod_base {
32                protected:
33                MarginalizedParticleBase &mp;
34                public:
35                eprod_2(MarginalizedParticleBase &m):mp(m){}
36                const epdf* factor(int i) const {return (i==0) ? &mp.bm->posterior() : &mp.est_emp;}
37                const int no_factors() const {return 2;}
38        } est;
39       
40        public:
41                MarginalizedParticleBase():est(*this){};
42                MarginalizedParticleBase(const MarginalizedParticleBase &m2):est(*this){
43                        bm = m2.bm->_copy();
44                        est_emp = m2.est_emp;
45                        validate();
46                };
47                void bayes(const vec &dt, const vec &cond) NOT_IMPLEMENTED_VOID;
48               
49                const eprod_2& posterior() const {return est;}
50       
51                void set_prior(const epdf *pdf0){
52                        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
53                        if (ep){ // full prior
54                                bdm_assert(ep->no_factors()==2,"Incompatible prod");
55                                bm->set_prior(ep->factor(0));
56                                est_emp.set_point(ep->factor(1)->sample());
57                        } else {
58                                // assume prior is only for emp;
59                                est_emp.set_point(pdf0->sample());
60                        }
61                }
62                void from_setting(const Setting &set){
63                        BM::from_setting ( set ); 
64                        bm = UI::build<BM> ( set, "bm", UI::compulsory );
65                }
66                void validate() {
67                        BM::validate();
68                        bdm_assert(bm,"Internal BM is not given");
69                }
70};
71
72class MarginalizedParticle : public MarginalizedParticleBase{
73        protected:
74                //! pdf with for transitional par
75                shared_ptr<pdf> par; // pdf for non-linear part
76                //! link from this to bm
77                shared_ptr<datalink_part> cond2bm;
78                //! link from cond to par
79                shared_ptr<datalink_part> cond2par;
80                //! link from emp 2 par
81                shared_ptr<datalink_part> emp2bm;
82                //! link from emp 2 par
83                shared_ptr<datalink_part> emp2par;
84               
85        public:
86                BM* _copy() const{return new MarginalizedParticle(*this);};
87                void bayes(const vec &dt, const vec &cond){
88                        vec par_cond(par->dimensionc());
89                        cond2par->filldown(cond,par_cond); // copy ut
90                        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1
91                       
92                        //sample new particle
93                        est_emp.set_point(par->samplecond(par_cond));
94                        //if (evalll)
95                        vec bm_cond(bm->dimensionc());
96                        cond2bm->filldown(cond, bm_cond);// set e.g. ut
97                        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
98                        bm->bayes(dt, bm_cond);
99                        ll=bm->_ll();
100                }
101                               
102                /*! parse structure
103                \code
104                class = "MarginalizedParticle";
105                parameter_pdf = {class = 'epdf_offspring', ...};
106                bm = {class = 'bm_offspring',...};
107                \endcode
108                If rvs are set, then it checks for compatibility.
109                */
110                void from_setting(const Setting &set){
111                        MarginalizedParticleBase::from_setting ( set );
112                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
113                }
114               
115                void to_setting(Setting &set){
116                        MarginalizedParticleBase::to_setting(set);
117                        UI::save(par,set,"parameter_pdf");
118                }
119                void validate(){
120                        if (est_emp.point.length()!=par->dimension())
121                                est_emp.set_point(zeros(par->dimension()));
122                        est.validate();
123                       
124                        yrv = bm->_yrv();
125                        dimy = bm->dimensiony();
126                        set_rv( concat(bm->_rv(), par->_rv()));
127                        set_dim( par->dimension()+bm->dimension());
128                       
129                        rvc = par->_rvc();
130                        rvc.add(bm->_rvc());
131                        rvc=rvc.subt(par->_rv());
132                        rvc=rvc.subt(par->_rv().copy_t(-1));
133                        rvc=rvc.subt(bm->_rv().copy_t(-1)); //
134                       
135                        cond2bm=new datalink_part;
136                        cond2par=new datalink_part;
137                        emp2bm  =new datalink_part;
138                        emp2par =new datalink_part;
139                        cond2bm->set_connection(bm->_rvc(), rvc);
140                        cond2par->set_connection(par->_rvc(), rvc);
141                        emp2bm->set_connection(bm->_rvc(), par->_rv());
142                        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));
143                       
144                        dimc = rvc._dsize();
145                };
146};
147UIREGISTER(MarginalizedParticle);
148
149//! class used in PF
150class BootstrapParticle : public BM{
151        dirac est;
152        shared_ptr<pdf> par;
153        shared_ptr<pdf> obs;
154        shared_ptr<datalink_part> cond2par;
155        shared_ptr<datalink_part> cond2obs;
156        shared_ptr<datalink_part> xt2obs;
157        shared_ptr<datalink_part> xtm2par;
158        public:
159                BM* _copy() const{return new BootstrapParticle(*this);};
160                void bayes(const vec &dt, const vec &cond){
161                        vec par_cond(par->dimensionc());
162                        cond2par->filldown(cond,par_cond); // copy ut
163                        xtm2par->filldown(est._point(),par_cond); // copy xt-1
164                       
165                        //sample new particle
166                        est.set_point(par->samplecond(par_cond));
167                        //if (evalll)
168                        vec obs_cond(obs->dimensionc());
169                        cond2obs->filldown(cond, obs_cond);// set e.g. ut
170                        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
171                        ll=obs->evallogcond(dt,obs_cond);
172                }
173                const dirac& posterior() const {return est;}
174               
175                void set_prior(const epdf *pdf0){est.set_point(pdf0->sample());}
176               
177                /*! parse structure
178                \code
179                class = "BootstrapParticle";
180                parameter_pdf = {class = 'epdf_offspring', ...};
181                observation_pdf = {class = 'epdf_offspring',...};
182                \endcode
183                If rvs are set, then it checks for compatibility.
184                */
185                void from_setting(const Setting &set){
186                        BM::from_setting ( set );
187                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
188                        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
189                }
190                void validate(){
191                        yrv = obs->_rv();
192                        dimy = obs->dimension();
193                        set_rv( par->_rv());
194                        set_dim( par->dimension());
195                       
196                        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
197                        rvc.add(obs->_rvc()); //
198                       
199                        cond2obs=new datalink_part;
200                        cond2par=new datalink_part;
201                        xt2obs  =new datalink_part;
202                        xtm2par =new datalink_part;
203                        cond2obs->set_connection(obs->_rvc(), rvc);
204                        cond2par->set_connection(par->_rvc(), rvc);
205                        xt2obs->set_connection(obs->_rvc(), _rv());
206                        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));
207                       
208                        dimc = rvc._dsize();
209                };
210};
211UIREGISTER(BootstrapParticle);
212
213
214/*!
215* \brief Trivial particle filter with proposal density equal to parameter evolution model.
216
217Posterior density is represented by a weighted empirical density (\c eEmp ).
218*/
219
220class PF : public BM {
221        //! \var log_level_enums weights
222        //! all weightes will be logged
223
224        //! \var log_level_enums menas
225        //! means of particles will be logged
226        LOG_LEVEL(PF,logweights,logmeans,logvars);
227       
228        class pf_mix: public emix_base{
229                Array<BM*> &bms;
230                public:
231                        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0){}
232                        const epdf* component(const int &i)const{return &(bms(i)->posterior());}
233                        int no_coms() const {return bms.length();}
234        };
235protected:
236        //!number of particles;
237        int n;
238        //!posterior density
239        pf_mix est;
240        //! weights;
241        vec w;
242        //! particles
243        Array<BM*> particles;
244        //! internal structure storing loglikelihood of predictions
245        vec lls;
246
247        //! which resampling method will be used
248        RESAMPLING_METHOD resmethod;
249        //! resampling threshold; in this case its meaning is minimum ratio of active particles
250        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
251        double res_threshold;
252
253        //! \name Options
254        //!@{
255        //!@}
256
257public:
258        //! \name Constructors
259        //!@{
260        PF ( ) : est(w,particles) { };
261
262        void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
263                n = n0;
264                res_threshold = res_th0;
265                resmethod = rm;
266        };
267        void set_model ( const BM *particle0, const epdf *prior) {
268                if (n>0){
269                        particles.set_length(n);
270                        for (int i=0; i<n;i++){
271                                particles(i) = particle0->_copy();
272                                particles(i)->set_prior(prior);
273                        }
274                }
275                // set values for posterior
276                est.set_rv ( particle0->posterior()._rv() );
277        };
278        void set_statistics ( const vec w0, const epdf &epdf0 ) {
279                //est.set_statistics ( w0, epdf0 );
280        };
281/*      void set_statistics ( const eEmp &epdf0 ) {
282                bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
283                est = epdf0;
284        };*/
285        //!@}
286
287        //! bayes compute weights of the
288        virtual void bayes_weights();
289        //! important part of particle filtering - decide if it is time to perform resampling
290        virtual bool do_resampling() {
291                double eff = 1.0 / ( w * w );
292                return eff < ( res_threshold*n );
293        }
294        void bayes ( const vec &yt, const vec &cond );
295        //!access function
296        vec& _lls() {
297                return lls;
298        }
299        //!access function
300        RESAMPLING_METHOD _resmethod() const {
301                return resmethod;
302        }
303        //! return correctly typed posterior (covariant return)
304        const pf_mix& posterior() const {
305                return est;
306        }
307
308        /*! configuration structure for basic PF
309        \code
310        parameter_pdf   = pdf_class;         // parameter evolution pdf
311        observation_pdf = pdf_class;         // observation pdf
312        prior           = epdf_class;         // prior probability density
313        --- optional ---
314        n               = 10;                 // number of particles
315        resmethod       = 'systematic', or 'multinomial', or 'stratified'
316                                                                                  // resampling method
317        res_threshold   = 0.5;                // resample when active particles drop below 50%
318        \endcode
319        */
320        void from_setting ( const Setting &set ) {
321                BM::from_setting ( set );
322                UI::get ( log_level, set, "log_level", UI::optional );
323               
324                shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);
325               
326                shared_ptr<epdf> pri = UI::build<epdf> ( set, "prior", UI::compulsory );
327                n =0;
328                UI::get(n,set,"n",UI::optional);;
329                if (n>0){
330                        particles.set_length(n);
331                        for(int i=0;i<n;i++){particles(i)=bm0->_copy();}
332                        w = ones(n)/n;
333                }
334                set_prior(pri.get());
335                // set resampling method
336                resmethod_from_set ( set );
337                //set drv
338
339                rvc = bm0->_rvc();
340                dimc = bm0->dimensionc();
341                BM::set_rv(bm0->_rv());
342                yrv=bm0->_yrv();
343                dimy = bm0->dimensiony();
344        }
345       
346        void log_register ( bdm::logger& L, const string& prefix ){
347                BM::log_register(L,prefix);
348                if (log_level[logweights]){
349                        L.add_vector( log_level, logweights, RV ( particles.length()), prefix); 
350                }
351                if (log_level[logmeans]){
352                        for (int i=0; i<particles.length(); i++){
353                                L.add_vector( log_level, logmeans, RV ( particles(i)->dimension() ), prefix , i);
354                        }
355                }
356                if (log_level[logvars]){
357                        for (int i=0; i<particles.length(); i++){
358                                L.add_vector( log_level, logvars, RV ( particles(i)->dimension() ), prefix , i);
359                        }
360                }
361        };
362        void log_write ( ) const {
363                BM::log_write();
364                if (log_level[logweights]){
365                        log_level.store( logweights, w); 
366                }
367                if (log_level[logmeans]){
368                        for (int i=0; i<particles.length(); i++){
369                                log_level.store( logmeans, particles(i)->posterior().mean(), i);
370                        }
371                }
372                if (log_level[logvars]){
373                        for (int i=0; i<particles.length(); i++){
374                                log_level.store( logvars, particles(i)->posterior().variance(), i);
375                        }
376                }
377               
378        }
379       
380        void set_prior(const epdf *pri){
381                const emix_base *emi=dynamic_cast<const emix_base*>(pri);
382                if (emi) {
383                        bdm_assert(particles.length()>0, "initial particle is not assigned");
384                        n = emi->_w().length();
385                        int old_n = particles.length();
386                        if (n!=old_n){
387                                particles.set_length(n,true);
388                        } 
389                        for(int i=old_n;i<n;i++){particles(i)=particles(0)->_copy();}
390                       
391                        for (int i =0; i<n; i++){
392                                particles(i)->set_prior(emi->_com(i));
393                        }
394                } else {
395                        // try to find "n"
396                        bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
397                        for (int i =0; i<n; i++){
398                                particles(i)->set_prior(pri);
399                        }
400                       
401                }
402        }
403        //! auxiliary function reading parameter 'resmethod' from configuration file
404        void resmethod_from_set ( const Setting &set ) {
405                string resmeth;
406                if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
407                        if ( resmeth == "systematic" ) {
408                                resmethod = SYSTEMATIC;
409                        } else  {
410                                if ( resmeth == "multinomial" ) {
411                                        resmethod = MULTINOMIAL;
412                                } else {
413                                        if ( resmeth == "stratified" ) {
414                                                resmethod = STRATIFIED;
415                                        } else {
416                                                bdm_error ( "Unknown resampling method" );
417                                        }
418                                }
419                        }
420                } else {
421                        resmethod = SYSTEMATIC;
422                };
423                if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
424                        res_threshold = 0.9;
425                }
426                //validate();
427        }
428
429        void validate() {
430                BM::validate();
431                est.validate();
432                bdm_assert ( n>0, "empty particle pool" );
433                n = w.length();
434                lls = zeros ( n );
435
436                if ( particles(0)->_rv()._dsize() > 0 ) {
437                        bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "Mismatch of RV and dimension of posterior" );
438                }
439        }
440        //! resample posterior density (from outside - see MPF)
441        void resample ( ) {
442                ivec ind = zeros_i ( n );
443                bdm::resample(w,ind,resmethod);
444                // copy the internals according to ind
445                for (int i = 0; i < n; i++ ) {
446                        if ( ind ( i ) != i ) {
447                                particles( i ) = particles( ind ( i ) )->_copy();
448                        }
449                        w ( i ) = 1.0 / n;
450                }
451        }
452        //! access function
453        Array<BM*>& _particles() {
454                return particles;
455        }
456
457};
458UIREGISTER ( PF );
459
460/*! Marginalized particle for state-space models with unknown parameters of residues distribution
461
462\f[
463\begin{eqnarray}
464x_t = g(x_{t-1}) + v_t,\\
465z_t = h(x_{t-1}) + w_t,\\
466\end{eqnarray}
467\f]
468
469This particle is a only a shell creating the residues calling internal estimator of their parameters. The internal estimator can be of any compatible type, e.g. ARX for Gaussian residues with unknown mean and variance.
470
471*/
472class NoiseParticle : public MarginalizedParticleBase{
473        protected:
474                //! function transforming xt, ut -> x_t+1
475                shared_ptr<fnc> g; // pdf for non-linear part
476                //! function transforming xt,ut -> yt
477                shared_ptr<fnc> h; // pdf for non-linear part
478
479                RV rvv;
480               
481                RV rvx;
482                RV rvxc;
483                RV rvyc;
484               
485                //!link from condition to f
486                datalink_part cond2g;
487                //!link from condition to h
488                datalink_part cond2h;
489                //!link from xt to f
490                datalink_part x2g;
491                //!link from xt to h
492                datalink_part x2h;
493               
494        public:
495                BM* _copy() const{return new NoiseParticle();};
496                void bayes(const vec &dt, const vec &cond){
497                        epdf* pred_vw=bm->epredictor();
498                        shared_ptr<epdf> pred_v = pred_vw->marginal(rvv);
499                       
500                        vec vt=pred_v->sample();
501                       
502                        //new sample
503                        vec &xtm=est_emp.point;
504                        vec g_args(g->dimensionc());
505                        x2g.filldown(xtm,g_args);
506                        cond2g.filldown(cond,g_args);
507                        vec xt = g->eval(g_args) + vt;
508                       
509                        // residue of observation
510                        vec h_args(h->dimensionc());
511                        x2h.filldown(xt,h_args);
512                        cond2h.filldown(cond,h_args);
513                        vec wt = dt-h->eval(h_args);
514                        // the vector [v_t,w_t] is now complete
515                        bm->bayes(concat(vt,wt));
516                        ll=bm->_ll();
517                }
518                void from_setting(const Setting &set){
519                        MarginalizedParticleBase::from_setting(set); //reads bm, yrv,rvc, bm_rv, etc...
520                               
521                        UI::get(g,set,"g",UI::compulsory);
522                        UI::get(h,set,"h",UI::compulsory);
523                        RV rvx;
524                        UI::get(rvx,set,"rvx",UI::compulsory);
525                        est_emp.set_rv(rvx);
526                       
527                        RV rvxc;
528                        UI::get(rvxc,set,"rvxc",UI::compulsory);
529                        RV rvyc;
530                        UI::get(rvyc,set,"rvyc",UI::compulsory);
531                }
532                void validate(){
533                        MarginalizedParticleBase::validate();
534                        //check dimensions
535                        rvc = rvxc;
536                        rvc.add( rvyc);
537                        rvc.subt(rvx);
538               
539                        //establish datalinks
540                        x2g.set_connection(rvxc, rvx.copy_t(-1));
541                        cond2g.set_connection(rvxc, rvc);
542                       
543                        x2h.set_connection(rvyc, rvx);
544                        cond2h.set_connection(rvyc, rvc);
545                }
546};
547UIREGISTER(NoiseParticle);
548
549
550
551}
552#endif // KF_H
553
Note: See TracBrowser for help on using the browser.