root/library/bdm/estim/particles.h @ 974

Revision 974, 15.7 kB (checked in by smidl, 14 years ago)

Noise Particle

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../estim/arx_ext.h"
18#include "../stat/emix.h"
19
20namespace bdm {
21       
22//! class used in PF
23class MarginalizedParticleBase : public BM{
24        protected:
25        //! discrte particle
26        dirac est_emp;
27        //! internal Bayes Model
28        shared_ptr<BM> bm; 
29       
30        //! custom posterior - product of empirical and exact part
31        class eprod_2:public eprod_base {
32                protected:
33                MarginalizedParticleBase &mp;
34                public:
35                eprod_2(MarginalizedParticleBase &m):mp(m){}
36                const epdf* factor(int i) const {return (i==0) ? &mp.bm->posterior() : &mp.est_emp;}
37                const int no_factors() const {return 2;}
38        } est;
39       
40        public:
41                MarginalizedParticleBase():est(*this){};
42                MarginalizedParticleBase(const MarginalizedParticleBase &m2):est(*this){
43                        bm = m2.bm->_copy();
44                        est_emp = m2.est_emp;
45                        est.validate();;
46                        validate();
47                };
48                void bayes(const vec &dt, const vec &cond) NOT_IMPLEMENTED_VOID;
49               
50                const eprod_2& posterior() const {return est;}
51       
52                void set_prior(const epdf *pdf0){
53                        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
54                        if (ep){ // full prior
55                                bdm_assert(ep->no_factors()==2,"Incompatible prod");
56                                bm->set_prior(ep->factor(0));
57                                est_emp.set_point(ep->factor(1)->sample());
58                        } else {
59                                // assume prior is only for emp;
60                                est_emp.set_point(pdf0->sample());
61                        }
62                }
63                void from_setting(const Setting &set){
64                        BM::from_setting ( set ); 
65                        bm = UI::build<BM> ( set, "bm", UI::compulsory );
66                }
67                void validate() {
68                        BM::validate();
69                        //est.validate(); --pdfs not known
70                        bdm_assert(bm,"Internal BM is not given");
71                }
72};
73
74class MarginalizedParticle : public MarginalizedParticleBase{
75        protected:
76                //! pdf with for transitional par
77                shared_ptr<pdf> par; // pdf for non-linear part
78                //! link from this to bm
79                shared_ptr<datalink_part> cond2bm;
80                //! link from cond to par
81                shared_ptr<datalink_part> cond2par;
82                //! link from emp 2 par
83                shared_ptr<datalink_part> emp2bm;
84                //! link from emp 2 par
85                shared_ptr<datalink_part> emp2par;
86               
87        public:
88                BM* _copy() const{return new MarginalizedParticle(*this);};
89                void bayes(const vec &dt, const vec &cond){
90                        vec par_cond(par->dimensionc());
91                        cond2par->filldown(cond,par_cond); // copy ut
92                        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1
93                       
94                        //sample new particle
95                        est_emp.set_point(par->samplecond(par_cond));
96                        //if (evalll)
97                        vec bm_cond(bm->dimensionc());
98                        cond2bm->filldown(cond, bm_cond);// set e.g. ut
99                        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
100                        bm->bayes(dt, bm_cond);
101                        ll=bm->_ll();
102                }
103                               
104                /*! parse structure
105                \code
106                class = "MarginalizedParticle";
107                parameter_pdf = {class = 'epdf_offspring', ...};
108                bm = {class = 'bm_offspring',...};
109                \endcode
110                If rvs are set, then it checks for compatibility.
111                */
112                void from_setting(const Setting &set){
113                        MarginalizedParticleBase::from_setting ( set );
114                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
115                }
116               
117                void to_setting(Setting &set){
118                        MarginalizedParticleBase::to_setting(set);
119                        UI::save(par,set,"parameter_pdf");
120                }
121                void validate(){
122                        MarginalizedParticleBase::validate();
123                        est_emp.set_rv(par->_rv());
124                        if (est_emp.point.length()!=par->dimension())
125                                est_emp.set_point(zeros(par->dimension()));
126                        est.validate();
127                       
128                        yrv = bm->_yrv();
129                        dimy = bm->dimensiony();
130                        set_rv( concat(bm->_rv(), par->_rv()));
131                        set_dim( par->dimension()+bm->dimension());
132                       
133                        rvc = par->_rvc();
134                        rvc.add(bm->_rvc());
135                        rvc=rvc.subt(par->_rv());
136                        rvc=rvc.subt(par->_rv().copy_t(-1));
137                        rvc=rvc.subt(bm->_rv().copy_t(-1)); //
138                       
139                        cond2bm=new datalink_part;
140                        cond2par=new datalink_part;
141                        emp2bm  =new datalink_part;
142                        emp2par =new datalink_part;
143                        cond2bm->set_connection(bm->_rvc(), rvc);
144                        cond2par->set_connection(par->_rvc(), rvc);
145                        emp2bm->set_connection(bm->_rvc(), par->_rv());
146                        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));
147                       
148                        dimc = rvc._dsize();
149                };
150};
151UIREGISTER(MarginalizedParticle);
152
153//! class used in PF
154class BootstrapParticle : public BM{
155        dirac est;
156        shared_ptr<pdf> par;
157        shared_ptr<pdf> obs;
158        shared_ptr<datalink_part> cond2par;
159        shared_ptr<datalink_part> cond2obs;
160        shared_ptr<datalink_part> xt2obs;
161        shared_ptr<datalink_part> xtm2par;
162        public:
163                BM* _copy() const{return new BootstrapParticle(*this);};
164                void bayes(const vec &dt, const vec &cond){
165                        vec par_cond(par->dimensionc());
166                        cond2par->filldown(cond,par_cond); // copy ut
167                        xtm2par->filldown(est._point(),par_cond); // copy xt-1
168                       
169                        //sample new particle
170                        est.set_point(par->samplecond(par_cond));
171                        //if (evalll)
172                        vec obs_cond(obs->dimensionc());
173                        cond2obs->filldown(cond, obs_cond);// set e.g. ut
174                        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
175                        ll=obs->evallogcond(dt,obs_cond);
176                }
177                const dirac& posterior() const {return est;}
178               
179                void set_prior(const epdf *pdf0){est.set_point(pdf0->sample());}
180               
181                /*! parse structure
182                \code
183                class = "BootstrapParticle";
184                parameter_pdf = {class = 'epdf_offspring', ...};
185                observation_pdf = {class = 'epdf_offspring',...};
186                \endcode
187                If rvs are set, then it checks for compatibility.
188                */
189                void from_setting(const Setting &set){
190                        BM::from_setting ( set );
191                        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
192                        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
193                }
194                void validate(){
195                        yrv = obs->_rv();
196                        dimy = obs->dimension();
197                        set_rv( par->_rv());
198                        set_dim( par->dimension());
199                       
200                        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
201                        rvc.add(obs->_rvc()); //
202                       
203                        cond2obs=new datalink_part;
204                        cond2par=new datalink_part;
205                        xt2obs  =new datalink_part;
206                        xtm2par =new datalink_part;
207                        cond2obs->set_connection(obs->_rvc(), rvc);
208                        cond2par->set_connection(par->_rvc(), rvc);
209                        xt2obs->set_connection(obs->_rvc(), _rv());
210                        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));
211                       
212                        dimc = rvc._dsize();
213                };
214};
215UIREGISTER(BootstrapParticle);
216
217
218/*!
219* \brief Trivial particle filter with proposal density equal to parameter evolution model.
220
221Posterior density is represented by a weighted empirical density (\c eEmp ).
222*/
223
224class PF : public BM {
225        //! \var log_level_enums weights
226        //! all weightes will be logged
227
228        //! \var log_level_enums menas
229        //! means of particles will be logged
230        LOG_LEVEL(PF,logweights,logmeans,logvars);
231       
232        class pf_mix: public emix_base{
233                Array<BM*> &bms;
234                public:
235                        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0){}
236                        const epdf* component(const int &i)const{return &(bms(i)->posterior());}
237                        int no_coms() const {return bms.length();}
238        };
239protected:
240        //!number of particles;
241        int n;
242        //!posterior density
243        pf_mix est;
244        //! weights;
245        vec w;
246        //! particles
247        Array<BM*> particles;
248        //! internal structure storing loglikelihood of predictions
249        vec lls;
250
251        //! which resampling method will be used
252        RESAMPLING_METHOD resmethod;
253        //! resampling threshold; in this case its meaning is minimum ratio of active particles
254        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
255        double res_threshold;
256
257        //! \name Options
258        //!@{
259        //!@}
260
261public:
262        //! \name Constructors
263        //!@{
264        PF ( ) : est(w,particles) { };
265
266        void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
267                n = n0;
268                res_threshold = res_th0;
269                resmethod = rm;
270        };
271        void set_model ( const BM *particle0, const epdf *prior) {
272                if (n>0){
273                        particles.set_length(n);
274                        for (int i=0; i<n;i++){
275                                particles(i) = particle0->_copy();
276                                particles(i)->set_prior(prior);
277                        }
278                }
279                // set values for posterior
280                est.set_rv ( particle0->posterior()._rv() );
281        };
282        void set_statistics ( const vec w0, const epdf &epdf0 ) {
283                //est.set_statistics ( w0, epdf0 );
284        };
285/*      void set_statistics ( const eEmp &epdf0 ) {
286                bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
287                est = epdf0;
288        };*/
289        //!@}
290
291        //! bayes compute weights of the
292        virtual void bayes_weights();
293        //! important part of particle filtering - decide if it is time to perform resampling
294        virtual bool do_resampling() {
295                double eff = 1.0 / ( w * w );
296                return eff < ( res_threshold*n );
297        }
298        void bayes ( const vec &yt, const vec &cond );
299        //!access function
300        vec& _lls() {
301                return lls;
302        }
303        //!access function
304        RESAMPLING_METHOD _resmethod() const {
305                return resmethod;
306        }
307        //! return correctly typed posterior (covariant return)
308        const pf_mix& posterior() const {
309                return est;
310        }
311
312        /*! configuration structure for basic PF
313        \code
314        parameter_pdf   = pdf_class;         // parameter evolution pdf
315        observation_pdf = pdf_class;         // observation pdf
316        prior           = epdf_class;         // prior probability density
317        --- optional ---
318        n               = 10;                 // number of particles
319        resmethod       = 'systematic', or 'multinomial', or 'stratified'
320                                                                                  // resampling method
321        res_threshold   = 0.5;                // resample when active particles drop below 50%
322        \endcode
323        */
324        void from_setting ( const Setting &set ) {
325                BM::from_setting ( set );
326                UI::get ( log_level, set, "log_level", UI::optional );
327               
328                shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);
329               
330                n =0;
331                UI::get(n,set,"n",UI::optional);;
332                if (n>0){
333                        particles.set_length(n);
334                        for(int i=0;i<n;i++){particles(i)=bm0->_copy();}
335                        w = ones(n)/n;
336                }
337                // set resampling method
338                resmethod_from_set ( set );
339                //set drv
340
341                rvc = bm0->_rvc();
342                dimc = bm0->dimensionc();
343                BM::set_rv(bm0->_rv());
344                yrv=bm0->_yrv();
345                dimy = bm0->dimensiony();
346        }
347       
348        void log_register ( bdm::logger& L, const string& prefix ){
349                BM::log_register(L,prefix);
350                if (log_level[logweights]){
351                        L.add_vector( log_level, logweights, RV ( particles.length()), prefix); 
352                }
353                if (log_level[logmeans]){
354                        for (int i=0; i<particles.length(); i++){
355                                L.add_vector( log_level, logmeans, RV ( particles(i)->dimension() ), prefix , i);
356                        }
357                }
358                if (log_level[logvars]){
359                        for (int i=0; i<particles.length(); i++){
360                                L.add_vector( log_level, logvars, RV ( particles(i)->dimension() ), prefix , i);
361                        }
362                }
363        };
364        void log_write ( ) const {
365                BM::log_write();
366                if (log_level[logweights]){
367                        log_level.store( logweights, w); 
368                }
369                if (log_level[logmeans]){
370                        for (int i=0; i<particles.length(); i++){
371                                log_level.store( logmeans, particles(i)->posterior().mean(), i);
372                        }
373                }
374                if (log_level[logvars]){
375                        for (int i=0; i<particles.length(); i++){
376                                log_level.store( logvars, particles(i)->posterior().variance(), i);
377                        }
378                }
379               
380        }
381       
382        void set_prior(const epdf *pri){
383                const emix_base *emi=dynamic_cast<const emix_base*>(pri);
384                if (emi) {
385                        bdm_assert(particles.length()>0, "initial particle is not assigned");
386                        n = emi->_w().length();
387                        int old_n = particles.length();
388                        if (n!=old_n){
389                                particles.set_length(n,true);
390                        } 
391                        for(int i=old_n;i<n;i++){particles(i)=particles(0)->_copy();}
392                       
393                        for (int i =0; i<n; i++){
394                                particles(i)->set_prior(emi->_com(i));
395                        }
396                } else {
397                        // try to find "n"
398                        bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
399                        for (int i =0; i<n; i++){
400                                particles(i)->set_prior(pri);
401                        }
402                       
403                }
404        }
405        //! auxiliary function reading parameter 'resmethod' from configuration file
406        void resmethod_from_set ( const Setting &set ) {
407                string resmeth;
408                if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
409                        if ( resmeth == "systematic" ) {
410                                resmethod = SYSTEMATIC;
411                        } else  {
412                                if ( resmeth == "multinomial" ) {
413                                        resmethod = MULTINOMIAL;
414                                } else {
415                                        if ( resmeth == "stratified" ) {
416                                                resmethod = STRATIFIED;
417                                        } else {
418                                                bdm_error ( "Unknown resampling method" );
419                                        }
420                                }
421                        }
422                } else {
423                        resmethod = SYSTEMATIC;
424                };
425                if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
426                        res_threshold = 0.9;
427                }
428                //validate();
429        }
430
431        void validate() {
432                BM::validate();
433                est.validate();
434                bdm_assert ( n>0, "empty particle pool" );
435                n = w.length();
436                lls = zeros ( n );
437
438                if ( particles(0)->_rv()._dsize() > 0 ) {
439                        bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "Mismatch of RV " +particles(0)->_rv().to_string() + 
440                        " of size (" +num2str(particles(0)->_rv()._dsize())+"and dimension of posterior ("+num2str(est.dimension()) + ")" );
441                }
442        }
443        //! resample posterior density (from outside - see MPF)
444        void resample ( ) {
445                ivec ind = zeros_i ( n );
446                bdm::resample(w,ind,resmethod);
447                // copy the internals according to ind
448                for (int i = 0; i < n; i++ ) {
449                        if ( ind ( i ) != i ) {
450                                delete particles(i);
451                                particles( i ) = particles( ind ( i ) )->_copy();
452                        }
453                        w ( i ) = 1.0 / n;
454                }
455        }
456        //! access function
457        Array<BM*>& _particles() {
458                return particles;
459        }
460        ~PF(){
461                for (int i=0; i<particles.length(); i++){delete particles(i);}
462        }
463
464};
465UIREGISTER ( PF );
466
467/*! Marginalized particle for state-space models with unknown parameters of residues distribution
468
469\f[
470\begin{eqnarray}
471x_t = g(x_{t-1}) + v_t,\\
472z_t = h(x_{t-1}) + w_t,\\
473\end{eqnarray}
474\f]
475
476This particle is a only a shell creating the residues calling internal estimator of their parameters. The internal estimator can be of any compatible type, e.g. ARX for Gaussian residues with unknown mean and variance.
477
478*/
479class NoiseParticle : public MarginalizedParticleBase{
480        protected:
481                //! function transforming xt, ut -> x_t+1
482                shared_ptr<fnc> g; // pdf for non-linear part
483                //! function transforming xt,ut -> yt
484                shared_ptr<fnc> h; // pdf for non-linear part
485               
486                RV rvx;
487                RV rvxc;
488                RV rvyc;
489               
490                //!link from condition to f
491                datalink_part cond2g;
492                //!link from condition to h
493                datalink_part cond2h;
494                //!link from xt to f
495                datalink_part x2g;
496                //!link from xt to h
497                datalink_part x2h;
498               
499        public:
500                BM* _copy() const{return new NoiseParticle(*this);};
501                void bayes(const vec &dt, const vec &cond){
502                        shared_ptr<epdf> pred_vw=bm->epredictor();
503                        shared_ptr<epdf> pred_v = pred_vw->marginal(rvx);
504                       
505                        vec vt=pred_v->sample();
506                       
507                        //new sample
508                        vec &xtm=est_emp.point;
509                        vec g_args(g->dimensionc());
510                        x2g.filldown(xtm,g_args);
511                        cond2g.filldown(cond,g_args);
512                        vec xt = g->eval(g_args) + vt;
513                        est_emp.point=xt;
514                       
515                        // residue of observation
516                        vec h_args(h->dimensionc());
517                        x2h.filldown(xt,h_args);
518                        cond2h.filldown(cond,h_args);
519                        vec wt = dt-h->eval(h_args);
520                        // the vector [v_t,w_t] is now complete
521                        bm->bayes(concat(vt,wt));
522                        ll=bm->_ll();
523                }
524                void from_setting(const Setting &set){
525                        MarginalizedParticleBase::from_setting(set); //reads bm, yrv,rvc, bm_rv, etc...
526                               
527                        UI::get(g,set,"g",UI::compulsory);
528                        UI::get(h,set,"h",UI::compulsory);
529                        UI::get(rvx,set,"rvx",UI::compulsory);
530                        est_emp.set_rv(rvx);
531                       
532                        UI::get(rvxc,set,"rvxc",UI::compulsory);
533                        UI::get(rvyc,set,"rvyc",UI::compulsory);
534                }
535                void validate(){
536                        MarginalizedParticleBase::validate();
537                       
538                        dimy = h->dimension();
539                        bm->set_yrv(concat(rvx,yrv));
540                       
541                        est_emp.set_rv(rvx);
542                        est_emp.set_dim(rvx._dsize());
543                        est.validate();
544                        //
545                        //check dimensions
546                        rvc = rvxc.subt(rvx.copy_t(-1));
547                        rvc.add( rvyc);
548                        rvc=rvc.subt(rvx);
549                       
550                        bdm_assert(g->dimension()==rvx._dsize(),"rvx is not described");
551                        bdm_assert(g->dimensionc()==rvxc._dsize(),"rvxc is not described");
552                        bdm_assert(h->dimension()==rvyc._dsize(),"rvyc is not described");
553                       
554                        bdm_assert(bm->dimensiony()==g->dimension()+h->dimension(), 
555                                           "Incompatible noise estimator of dimension " +
556                                                num2str(bm->dimensiony()) + " does not match dimension of g and h, " + 
557                                                num2str(g->dimension())+" and "+ num2str(h->dimension()) );
558                                               
559                        dimc = rvc._dsize();
560                                       
561                        //establish datalinks
562                        x2g.set_connection(rvxc, rvx.copy_t(-1));
563                        cond2g.set_connection(rvxc, rvc);
564                       
565                        x2h.set_connection(rvyc, rvx);
566                        cond2h.set_connection(rvyc, rvc);
567                }
568};
569UIREGISTER(NoiseParticle);
570
571
572
573}
574#endif // KF_H
575
Note: See TracBrowser for help on using the browser.