root/library/bdm/estim/particles.h @ 1176

Revision 1170, 26.8 kB (checked in by smidl, 14 years ago)

New noise particle + memory leak fix

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../estim/arx_ext.h"
18#include "../stat/emix.h"
19
20namespace bdm {
21
22//! \brief Abstract class for Marginalized Particles
23class MarginalizedParticleBase : public BM {
24protected:
25    //! discrte particle
26    dirac est_emp;
27    //! internal Bayes Model
28    shared_ptr<BM> bm;
29
30    //! \brief Internal class for custom posterior - product of empirical and exact part
31    class eprod_2:public eprod_base {
32    protected:
33        MarginalizedParticleBase &mp;
34    public:
35        eprod_2(MarginalizedParticleBase &m):mp(m) {}
36        const epdf* factor(int i) const {
37            return (i==0) ? &mp.bm->posterior() : &mp.est_emp;
38        }
39        const int no_factors() const {
40            return 2;
41        }
42    } est;
43
44public:
45    MarginalizedParticleBase():est(*this) {};
46    MarginalizedParticleBase(const MarginalizedParticleBase &m2):BM(m2),est(*this) {
47        bm = m2.bm->_copy();
48        est_emp = m2.est_emp;
49        est.validate();
50        validate();
51    };
52    void bayes(const vec &dt, const vec &cond) NOT_IMPLEMENTED_VOID;
53
54    const eprod_2& posterior() const {
55        return est;
56    }
57
58    void set_prior(const epdf *pdf0) {
59        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
60        if (ep) { // full prior
61            bdm_assert(ep->no_factors()==2,"Incompatible prod");
62            bm->set_prior(ep->factor(0));
63            est_emp.set_point(ep->factor(1)->sample());
64        } else {
65            // assume prior is only for emp;
66            est_emp.set_point(pdf0->sample());
67        }
68    }
69
70
71    /*! Create object from the following structure
72
73    \code
74    class = "MarginalizedParticleBase";
75    bm = configuration of bdm::BM;          % any offspring of BM, bdm::BM::from_setting
76    --- inherited fields ---
77    bdm::BM::from_setting
78    \endcode
79    */
80    void from_setting(const Setting &set) {
81        BM::from_setting ( set );
82        bm = UI::build<BM> ( set, "bm", UI::compulsory );
83    }
84    void validate() {
85        BM::validate();
86        //est.validate(); --pdfs not known
87        bdm_assert(bm,"Internal BM is not given");
88    }
89};
90
91//! \brief Particle with marginalized subspace, used in PF
92class MarginalizedParticle : public MarginalizedParticleBase {
93protected:
94    //! pdf with for transitional par
95    shared_ptr<pdf> par; // pdf for non-linear part
96    //! link from this to bm
97    shared_ptr<datalink_part> cond2bm;
98    //! link from cond to par
99    shared_ptr<datalink_part> cond2par;
100    //! link from emp 2 par
101    shared_ptr<datalink_part> emp2bm;
102    //! link from emp 2 par
103    shared_ptr<datalink_part> emp2par;
104
105public:
106    BM* _copy() const {
107        return new MarginalizedParticle(*this);
108    };
109    void bayes(const vec &dt, const vec &cond) {
110        vec par_cond(par->dimensionc());
111        cond2par->filldown(cond,par_cond); // copy ut
112        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1
113
114        //sample new particle
115        est_emp.set_point(par->samplecond(par_cond));
116        //if (evalll)
117        vec bm_cond(bm->dimensionc());
118        cond2bm->filldown(cond, bm_cond);// set e.g. ut
119        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
120        bm->bayes(dt, bm_cond);
121        ll=bm->_ll();
122    }
123
124    /*! Create object from the following structure
125
126    \code
127    class = "MarginalizedParticle";
128    parameter_pdf = configuration of bdm::epdf;          % any offspring of epdf, bdm::epdf::from_setting
129    --- inherited fields ---
130    bdm::MarginalizedParticleBase::from_setting
131    \endcode
132    */   
133    void from_setting(const Setting &set) {
134        MarginalizedParticleBase::from_setting ( set );
135        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
136    }
137
138    void to_setting(Setting &set)const {
139        MarginalizedParticleBase::to_setting(set);
140        UI::save(par,set,"parameter_pdf");
141                UI::save(bm,set,"bm");
142    }
143    void validate() {
144        MarginalizedParticleBase::validate();
145        est_emp.set_rv(par->_rv());
146        if (est_emp.point.length()!=par->dimension())
147            est_emp.set_point(zeros(par->dimension()));
148        est.validate();
149
150        yrv = bm->_yrv();
151        dimy = bm->dimensiony();
152        set_rv( concat(bm->_rv(), par->_rv()));
153        set_dim( par->dimension()+bm->dimension());
154
155        rvc = par->_rvc();
156        rvc.add(bm->_rvc());
157        rvc=rvc.subt(par->_rv());
158        rvc=rvc.subt(par->_rv().copy_t(-1));
159        rvc=rvc.subt(bm->_rv().copy_t(-1)); //
160
161        cond2bm=new datalink_part;
162        cond2par=new datalink_part;
163        emp2bm  =new datalink_part;
164        emp2par =new datalink_part;
165        cond2bm->set_connection(bm->_rvc(), rvc);
166        cond2par->set_connection(par->_rvc(), rvc);
167        emp2bm->set_connection(bm->_rvc(), par->_rv());
168        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));
169
170        dimc = rvc._dsize();
171    };
172};
173UIREGISTER(MarginalizedParticle);
174
175//! Internal class which is used in PF
176class BootstrapParticle : public BM {
177    dirac est;
178    shared_ptr<pdf> par;
179    shared_ptr<pdf> obs;
180    shared_ptr<datalink_part> cond2par;
181    shared_ptr<datalink_part> cond2obs;
182    shared_ptr<datalink_part> xt2obs;
183    shared_ptr<datalink_part> xtm2par;
184public:
185    BM* _copy() const {
186        return new BootstrapParticle(*this);
187    };
188    void bayes(const vec &dt, const vec &cond) {
189        vec par_cond(par->dimensionc());
190        cond2par->filldown(cond,par_cond); // copy ut
191        xtm2par->filldown(est._point(),par_cond); // copy xt-1
192
193        //sample new particle
194        est.set_point(par->samplecond(par_cond));
195        //if (evalll)
196        vec obs_cond(obs->dimensionc());
197        cond2obs->filldown(cond, obs_cond);// set e.g. ut
198        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
199        ll=obs->evallogcond(dt,obs_cond);
200    }
201    const dirac& posterior() const {
202        return est;
203    }
204
205    void set_prior(const epdf *pdf0) {
206        est.set_point(pdf0->sample());
207    }
208
209    /*! Create object from the following structure
210    \code
211    class = "BootstrapParticle";
212    parameter_pdf = configuration of bdm::epdf;      % any offspring of epdf, bdm::epdf::from_setting
213    observation_pdf = configuration of bdm::epdf;    % any offspring of epdf, bdm::epdf::from_setting
214    --- inherited fields ---
215    bdm::BM::from_setting
216    \endcode
217    */
218    void from_setting(const Setting &set) {
219        BM::from_setting ( set );
220        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
221        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
222    }
223
224    void validate() {
225        yrv = obs->_rv();
226        dimy = obs->dimension();
227        set_rv( par->_rv());
228        set_dim( par->dimension());
229
230        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
231        rvc.add(obs->_rvc()); //
232
233        cond2obs=new datalink_part;
234        cond2par=new datalink_part;
235        xt2obs  =new datalink_part;
236        xtm2par =new datalink_part;
237        cond2obs->set_connection(obs->_rvc(), rvc);
238        cond2par->set_connection(par->_rvc(), rvc);
239        xt2obs->set_connection(obs->_rvc(), _rv());
240        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));
241
242        dimc = rvc._dsize();
243    };
244};
245UIREGISTER(BootstrapParticle);
246
247
248/*!
249* \brief Trivial particle filter with proposal density equal to parameter evolution model.
250
251Posterior density is represented by a weighted empirical density (\c eEmp ).
252*/
253
254class PF : public BM {
255    //! \var log_level_enums logweights
256    //! all weightes will be logged
257
258    //! \var log_level_enums logmeans
259    //! means of particles will be logged
260    LOG_LEVEL(PF,logweights,logmeans,logvars);
261
262    class pf_mix: public emix_base {
263        Array<BM*> &bms;
264    public:
265        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0) {}
266        const epdf* component(const int &i)const {
267            return &(bms(i)->posterior());
268        }
269        int no_coms() const {
270            return bms.length();
271        }
272    };
273protected:
274    //!number of particles;
275    int n;
276    //!posterior density
277    pf_mix est;
278    //! weights;
279    vec w;
280    //! particles
281    Array<BM*> particles;
282    //! internal structure storing loglikelihood of predictions
283    vec lls;
284
285    //! which resampling method will be used
286    RESAMPLING_METHOD resmethod;
287    //! resampling threshold; in this case its meaning is minimum ratio of active particles
288    //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
289    double res_threshold;
290
291    //! \name Options
292    //!@{
293    //!@}
294
295public:
296    //! \name Constructors
297    //!@{
298    PF ( ) : est(w,particles) { };
299
300    void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
301        n = n0;
302        res_threshold = res_th0;
303        resmethod = rm;
304    };
305    void set_model ( const BM *particle0, const epdf *prior) {
306        if (n>0) {
307            particles.set_length(n);
308            for (int i=0; i<n; i++) {
309                particles(i) = particle0->_copy();
310                particles(i)->set_prior(prior);
311            }
312        }
313        // set values for posterior
314        est.set_rv ( particle0->posterior()._rv() );
315    };
316    void set_statistics ( const vec w0, const epdf &epdf0 ) {
317        //est.set_statistics ( w0, epdf0 );
318    };
319    /*    void set_statistics ( const eEmp &epdf0 ) {
320            bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
321            est = epdf0;
322        };*/
323    //!@}
324
325    //! bayes compute weights of the
326    virtual void bayes_weights();
327    //! important part of particle filtering - decide if it is time to perform resampling
328    virtual bool do_resampling() {
329        double eff = 1.0 / ( w * w );
330        return eff < ( res_threshold*n );
331    }
332    void bayes ( const vec &yt, const vec &cond );
333    //!access function
334    vec& _lls() {
335        return lls;
336    }
337    //!access function
338    RESAMPLING_METHOD _resmethod() const {
339        return resmethod;
340    }
341    //! return correctly typed posterior (covariant return)
342    const pf_mix& posterior() const {
343        return est;
344    }
345
346    /*! configuration structure for basic PF
347    \code
348    particle        = bdm::BootstrapParticle;       % one bayes rule for each point in the empirical support
349      - or -        = bdm::MarginalizedParticle;    % (in case of Marginalized Particle filtering
350    prior           = epdf_class;                   % prior probability density on the empirical variable
351    --- optional ---
352    n               = 10;                           % number of particles
353    resmethod       = 'systematic', or 'multinomial', or 'stratified'
354                                                    % resampling method
355    res_threshold   = 0.5;                          % resample when active particles drop below 50%
356    \endcode
357    */
358    void from_setting ( const Setting &set ) {
359        BM::from_setting ( set );
360        UI::get ( log_level, set, "log_level", UI::optional );
361
362        shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);
363
364        n =0;
365        UI::get(n,set,"n",UI::optional);;
366        if (n>0) {
367            particles.set_length(n);
368            for(int i=0; i<n; i++) {
369                particles(i)=bm0->_copy();
370            }
371            w = ones(n)/n;
372        }
373        shared_ptr<epdf> pri = UI::build<epdf>(set,"prior");
374        set_prior(pri.get());
375        // set resampling method
376        resmethod_from_set ( set );
377        //set drv
378
379        rvc = bm0->_rvc();
380        dimc = bm0->dimensionc();
381        BM::set_rv(bm0->_rv());
382        yrv=bm0->_yrv();
383        dimy = bm0->dimensiony();
384    }
385   
386    void to_setting(Setting &set) const{
387                BM::to_setting(set);
388                UI::save(particles, set,"particles");
389                UI::save(w,set,"w");
390        }
391
392    void log_register ( bdm::logger& L, const string& prefix ) {
393        BM::log_register(L,prefix);
394        if (log_level[logweights]) {
395            L.add_vector( log_level, logweights, RV ( particles.length()), prefix);
396        }
397        if (log_level[logmeans]) {
398            for (int i=0; i<particles.length(); i++) {
399                L.add_vector( log_level, logmeans, RV ( particles(i)->dimension() ), prefix , i);
400            }
401        }
402        if (log_level[logvars]) {
403            for (int i=0; i<particles.length(); i++) {
404                L.add_vector( log_level, logvars, RV ( particles(i)->dimension() ), prefix , i);
405            }
406        }
407    };
408    void log_write ( ) const {
409        BM::log_write();
410        if (log_level[logweights]) {
411            log_level.store( logweights, w);
412        }
413        if (log_level[logmeans]) {
414            for (int i=0; i<particles.length(); i++) {
415                log_level.store( logmeans, particles(i)->posterior().mean(), i);
416            }
417        }
418        if (log_level[logvars]) {
419            for (int i=0; i<particles.length(); i++) {
420                log_level.store( logvars, particles(i)->posterior().variance(), i);
421            }
422        }
423
424    }
425
426    void set_prior(const epdf *pri) {
427        const emix_base *emi=dynamic_cast<const emix_base*>(pri);
428        if (emi) {
429            bdm_assert(particles.length()>0, "initial particle is not assigned");
430            n = emi->_w().length();
431            int old_n = particles.length();
432            if (n!=old_n) {
433                particles.set_length(n,true);
434            }
435            for(int i=old_n; i<n; i++) {
436                particles(i)=particles(0)->_copy();
437            }
438
439            for (int i =0; i<n; i++) {
440                particles(i)->set_prior(emi->_com(i));
441            }
442        } else {
443            // try to find "n"
444            bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
445            for (int i =0; i<n; i++) {
446                particles(i)->set_prior(pri);
447            }
448
449        }
450    }
451    //! auxiliary function reading parameter 'resmethod' from configuration file
452    void resmethod_from_set ( const Setting &set ) {
453        string resmeth;
454        if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
455            if ( resmeth == "systematic" ) {
456                resmethod = SYSTEMATIC;
457            } else  {
458                if ( resmeth == "multinomial" ) {
459                    resmethod = MULTINOMIAL;
460                } else {
461                    if ( resmeth == "stratified" ) {
462                        resmethod = STRATIFIED;
463                    } else {
464                        bdm_error ( "Unknown resampling method" );
465                    }
466                }
467            }
468        } else {
469            resmethod = SYSTEMATIC;
470        };
471        if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
472            res_threshold = 0.9;
473        }
474        //validate();
475    }
476
477    void validate() {
478        BM::validate();
479        est.validate();
480        bdm_assert ( n>0, "empty particle pool" );
481        n = w.length();
482        lls = zeros ( n );
483
484        if ( particles(0)->_rv()._dsize() > 0 ) {
485            bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "MPF:: Mismatch of RV " +particles(0)->_rv().to_string() +
486                          " of size (" +num2str(particles(0)->_rv()._dsize())+") and dimension of posterior ("+num2str(est.dimension()) + ")" );
487        }
488    }
489    //! resample posterior density (from outside - see MPF)
490    void resample ( ) {
491        ivec ind = zeros_i ( n );
492        bdm::resample(w,ind,resmethod);
493        // copy the internals according to ind
494        for (int i = 0; i < n; i++ ) {
495            if ( ind ( i ) != i ) {
496                delete particles(i);
497                particles( i ) = particles( ind ( i ) )->_copy();
498            }
499            w ( i ) = 1.0 / n;
500        }
501    }
502    //! access function
503    Array<BM*>& _particles() {
504        return particles;
505    }
506    ~PF() {
507        for (int i=0; i<particles.length(); i++) {
508            delete particles(i);
509        }
510    }
511
512};
513UIREGISTER ( PF );
514
515/*! Marginalized particle for state-space models with unknown parameters of distribuution of residues on \f$v_t\f$.
516
517\f{eqnarray*}{
518    x_t &=& g(x_{t-1}) + v_t,\\
519    y_t &\sim &fy(x_t),
520    \f}
521
522    This particle is a only a shell creating the residues calling internal estimator of their parameters. The internal estimator can be of any compatible type, e.g. ARX for Gaussian residues with unknown mean and variance.
523
524    */
525class NoiseParticleX : public MarginalizedParticleBase {
526protected:
527    //! function transforming xt, ut -> x_t+1
528    shared_ptr<fnc> g; // pdf for non-linear part
529    //! function transforming xt,ut -> yt
530    shared_ptr<pdf> fy; // pdf for non-linear part
531
532    RV rvx;
533    RV rvxc;
534    RV rvyc;
535
536    //!link from condition to f
537    datalink_part cond2g;
538    //!link from condition to h
539    datalink_part cond2fy;
540    //!link from xt to f
541    datalink_part x2g;
542    //!link from xt to h
543    datalink_part x2fy;
544
545public:
546    BM* _copy() const {
547        return new NoiseParticleX(*this);
548    };
549    void bayes(const vec &dt, const vec &cond) {
550        //shared_ptr<epdf> pred_v=bm->epredictor();
551
552        //vec vt=pred_v->sample();
553                vec vt = bm->samplepred();
554
555        //new sample
556        vec &xtm=est_emp.point;
557        vec g_args(g->dimensionc());
558        x2g.filldown(xtm,g_args);
559        cond2g.filldown(cond,g_args);
560        vec xt = g->eval(g_args) + vt;
561        est_emp.point=xt;
562
563        // the vector [v_t] updates bm,
564        bm->bayes(vt);
565
566        // residue of observation
567        vec fy_args(fy->dimensionc());
568        x2fy.filldown(xt,fy_args);
569        cond2fy.filldown(cond,fy_args);
570
571        ll= fy->evallogcond(dt,fy_args);
572    }
573    void from_setting(const Setting &set) {
574        MarginalizedParticleBase::from_setting(set); //reads bm, yrv,rvc, bm_rv, etc...
575
576        g=UI::build<fnc>(set,"g",UI::compulsory);
577        fy=UI::build<pdf>(set,"fy",UI::compulsory);
578        UI::get(rvx,set,"rvx",UI::compulsory);
579        est_emp.set_rv(rvx);
580
581        UI::get(rvxc,set,"rvxc",UI::compulsory);
582        UI::get(rvyc,set,"rvyc",UI::compulsory);
583
584    }
585    void to_setting (Setting &set) const {
586                MarginalizedParticleBase::to_setting(set); //reads bm, yrv,rvc, bm_rv, etc...
587                UI::save(g,set,"g");
588                UI::save(fy,set,"fy");
589                UI::save(bm,set,"bm");
590        }
591    void validate() {
592        MarginalizedParticleBase::validate();
593
594        dimy = fy->dimension();
595        bm->set_yrv(rvx);
596
597        est_emp.set_rv(rvx);
598        est_emp.set_dim(rvx._dsize());
599        est.validate();
600        //
601        //check dimensions
602        rvc = rvxc.subt(rvx.copy_t(-1));
603        rvc.add( rvyc);
604        rvc=rvc.subt(rvx);
605
606        bdm_assert(g->dimension()==rvx._dsize(),"rvx is not described");
607        bdm_assert(g->dimensionc()==rvxc._dsize(),"rvxc is not described");
608        bdm_assert(fy->dimensionc()==rvyc._dsize(),"rvyc is not described");
609
610        bdm_assert(bm->dimensiony()==g->dimension(),
611                   "Incompatible noise estimator of dimension " +
612                   num2str(bm->dimensiony()) + " does not match dimension of g , " +
613                   num2str(g->dimension()));
614
615        dimc = rvc._dsize();
616
617        //establish datalinks
618        x2g.set_connection(rvxc, rvx.copy_t(-1));
619        cond2g.set_connection(rvxc, rvc);
620
621        x2fy.set_connection(rvyc, rvx);
622        cond2fy.set_connection(rvyc, rvc);
623    }
624};
625UIREGISTER(NoiseParticleX);
626
627/*! Marginalized particle for state-space models with unknown parameters of distribuution of residues on \f$v_t\f$ and \f$ w_t \f$.
628
629\f{eqnarray*}{
630        x_t &=& g(x_{t-1}) + v_t,\\
631        y_t &= &h(x_t)+w_t,
632        \f}
633       
634        This particle is a only a shell creating the residues calling internal estimator of their parameters. The internal estimator can be of any compatible type, e.g. ARX for Gaussian residues with unknown mean and variance.
635       
636        */
637class NoiseParticleXY : public BM {
638        //! discrte particle
639        dirac est_emp;
640        //! internal Bayes Model
641        shared_ptr<BM> bmx;
642        shared_ptr<BM> bmy;
643       
644        //! \brief Internal class for custom posterior - product of empirical and exact part
645        class eprod_3:public eprod_base {
646                protected:
647                        NoiseParticleXY &mp;
648                public:
649                        eprod_3(NoiseParticleXY &m):mp(m) {}
650                        const epdf* factor(int i) const {
651                                if (i==0) return &mp.bmx->posterior() ;
652                                if(i==1) return &mp.bmy->posterior();
653                                return &mp.est_emp;
654                        }
655                        const int no_factors() const {
656                                return 3;
657                        }
658        } est;
659
660        protected:
661                //! function transforming xt, ut -> x_t+1
662                shared_ptr<fnc> g; // pdf for non-linear part
663                //! function transforming xt,ut -> yt
664                shared_ptr<fnc> h; // pdf for non-linear part
665               
666                RV rvx;
667                RV rvxc;
668                RV rvyc;
669               
670                //!link from condition to f
671                datalink_part cond2g;
672                //!link from condition to h
673                datalink_part cond2h;
674                //!link from xt to f
675                datalink_part x2g;
676                //!link from xt to h
677                datalink_part x2h;
678               
679        public:
680                NoiseParticleXY():est(*this) {};
681                NoiseParticleXY(const NoiseParticleXY &m2):BM(m2),est(*this),h(m2.h),g(m2.g), rvx(m2.rvx),rvxc(m2.rvxc),rvyc(m2.rvyc) {
682                        bmx = m2.bmx->_copy();
683                        bmy = m2.bmy->_copy();
684                        est_emp = m2.est_emp;
685                        //est.validate();
686                        validate();
687                };
688               
689                const eprod_3& posterior() const {
690                        return est;
691                }
692               
693                void set_prior(const epdf *pdf0) {
694                        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
695                        if (ep) { // full prior
696                                bdm_assert(ep->no_factors()==2,"Incompatible prod");
697                                bmx->set_prior(ep->factor(0));
698                                bmy->set_prior(ep->factor(1));
699                                est_emp.set_point(ep->factor(2)->sample());
700                        } else {
701                                // assume prior is only for emp;
702                                est_emp.set_point(pdf0->sample());
703                        }
704                }
705                               
706                BM* _copy() const {
707                        return new NoiseParticleXY(*this);
708                };
709               
710                void bayes(const vec &dt, const vec &cond) {
711                        //shared_ptr<epdf> pred_v=bm->epredictor();
712                       
713                        //vec vt=pred_v->sample();
714                        vec vt = bmx->samplepred();
715                       
716                        //new sample
717                        vec &xtm=est_emp.point;
718                        vec g_args(g->dimensionc());
719                        x2g.filldown(xtm,g_args);
720                        cond2g.filldown(cond,g_args);
721                        vec xt = g->eval(g_args) + vt;
722                        est_emp.point=xt;
723                       
724                        // the vector [v_t] updates bm,
725                        bmx->bayes(vt);
726                       
727                        // residue of observation
728                        vec h_args(h->dimensionc());
729                        x2h.filldown(xt,h_args);
730                        cond2h.filldown(cond,h_args);
731                       
732                        bmy->bayes(h->eval(h_args)-dt);
733                        ll= bmy->_ll();
734                }
735                void from_setting(const Setting &set) {
736                        BM::from_setting(set); //reads bm, yrv,rvc, bm_rv, etc...
737                        bmx = UI::build<BM>(set,"bmx",UI::compulsory);
738                       
739                        bmy = UI::build<BM>(set,"bmy",UI::compulsory);
740                        g=UI::build<fnc>(set,"g",UI::compulsory);
741                        h=UI::build<fnc>(set,"h",UI::compulsory);
742                        UI::get(rvx,set,"rvx",UI::compulsory);
743                        est_emp.set_rv(rvx);
744                       
745                        UI::get(rvxc,set,"rvxc",UI::compulsory);
746                        UI::get(rvyc,set,"rvyc",UI::compulsory);
747                       
748                }
749                void to_setting (Setting &set) const {
750                        BM::to_setting(set); //reads bm, yrv,rvc, bm_rv, etc...
751                        UI::save(g,set,"g");
752                        UI::save(h,set,"h");
753                        UI::save(bmx,set,"bmx");
754                        UI::save(bmy,set,"bmy");
755                }
756                void validate() {
757                        BM::validate();
758                       
759                        dimy = h->dimension();
760//                      bmx->set_yrv(rvx);
761//                      bmy->
762                       
763                        est_emp.set_rv(rvx);
764                        est_emp.set_dim(rvx._dsize());
765                        est.validate();
766                        //
767                        //check dimensions
768                        rvc = rvxc.subt(rvx.copy_t(-1));
769                        rvc.add( rvyc);
770                        rvc=rvc.subt(rvx);
771                       
772                        bdm_assert(g->dimension()==rvx._dsize(),"rvx is not described");
773                        bdm_assert(g->dimensionc()==rvxc._dsize(),"rvxc is not described");
774                        bdm_assert(h->dimensionc()==rvyc._dsize(),"rvyc is not described");
775                       
776                        bdm_assert(bmx->dimensiony()==g->dimension(),
777                                           "Incompatible noise estimator of dimension " +
778                                           num2str(bmx->dimensiony()) + " does not match dimension of g , " +
779                                           num2str(g->dimension())
780                                           );
781                                           
782                        dimc = rvc._dsize();
783                       
784                        //establish datalinks
785                        x2g.set_connection(rvxc, rvx.copy_t(-1));
786                        cond2g.set_connection(rvxc, rvc);
787                       
788                        x2h.set_connection(rvyc, rvx);
789                        cond2h.set_connection(rvyc, rvc);
790                }               
791               
792};
793UIREGISTER(NoiseParticleXY);
794
795/*! Marginalized particle for state-space models with unknown parameters of residues distribution
796
797\f{eqnarray*}{
798    x_t &=& g(x_{t-1}) + v_t,\\
799    z_t &= &h(x_{t-1}) + w_t,
800    \f}
801
802    This particle is a only a shell creating the residues calling internal estimator of their parameters. The internal estimator can be of any compatible type, e.g. ARX for Gaussian residues with unknown mean and variance.
803
804    */
805class NoiseParticle : public MarginalizedParticleBase {
806protected:
807    //! function transforming xt, ut -> x_t+1
808    shared_ptr<fnc> g; // pdf for non-linear part
809    //! function transforming xt,ut -> yt
810    shared_ptr<fnc> h; // pdf for non-linear part
811
812    RV rvx;
813    RV rvxc;
814    RV rvyc;
815
816    //!link from condition to f
817    datalink_part cond2g;
818    //!link from condition to h
819    datalink_part cond2h;
820    //!link from xt to f
821    datalink_part x2g;
822    //!link from xt to h
823    datalink_part x2h;
824
825public:
826    BM* _copy() const {
827        return new NoiseParticle(*this);
828    };
829    void bayes(const vec &dt, const vec &cond) {
830        shared_ptr<epdf> pred_vw=bm->epredictor();
831        shared_ptr<epdf> pred_v = pred_vw->marginal(rvx);
832
833        vec vt=pred_v->sample();
834
835        //new sample
836        vec &xtm=est_emp.point;
837        vec g_args(g->dimensionc());
838        x2g.filldown(xtm,g_args);
839        cond2g.filldown(cond,g_args);
840        vec xt = g->eval(g_args) + vt;
841        est_emp.point=xt;
842
843        // residue of observation
844        vec h_args(h->dimensionc());
845        x2h.filldown(xt,h_args);
846        cond2h.filldown(cond,h_args);
847        vec wt = dt-h->eval(h_args);
848        // the vector [v_t,w_t] is now complete
849        bm->bayes(concat(vt,wt));
850        ll=bm->_ll();
851    }
852    void from_setting(const Setting &set) {
853        MarginalizedParticleBase::from_setting(set); //reads bm, yrv,rvc, bm_rv, etc...
854
855        UI::get(g,set,"g",UI::compulsory);
856        UI::get(h,set,"h",UI::compulsory);
857        UI::get(rvx,set,"rvx",UI::compulsory);
858        est_emp.set_rv(rvx);
859
860        UI::get(rvxc,set,"rvxc",UI::compulsory);
861        UI::get(rvyc,set,"rvyc",UI::compulsory);
862
863    }
864    void validate() {
865        MarginalizedParticleBase::validate();
866
867        dimy = h->dimension();
868        bm->set_yrv(concat(rvx,yrv));
869
870        est_emp.set_rv(rvx);
871        est_emp.set_dim(rvx._dsize());
872        est.validate();
873        //
874        //check dimensions
875        rvc = rvxc.subt(rvx.copy_t(-1));
876        rvc.add( rvyc);
877        rvc=rvc.subt(rvx);
878
879        bdm_assert(g->dimension()==rvx._dsize(),"rvx is not described");
880        bdm_assert(g->dimensionc()==rvxc._dsize(),"rvxc is not described");
881        bdm_assert(h->dimension()==rvyc._dsize(),"rvyc is not described");
882
883        bdm_assert(bm->dimensiony()==g->dimension()+h->dimension(),
884                   "Incompatible noise estimator of dimension " +
885                   num2str(bm->dimensiony()) + " does not match dimension of g and h, " +
886                   num2str(g->dimension())+" and "+ num2str(h->dimension()) );
887
888        dimc = rvc._dsize();
889
890        //establish datalinks
891        x2g.set_connection(rvxc, rvx.copy_t(-1));
892        cond2g.set_connection(rvxc, rvc);
893
894        x2h.set_connection(rvyc, rvx);
895        cond2h.set_connection(rvyc, rvc);
896    }
897};
898UIREGISTER(NoiseParticle);
899
900
901}
902#endif // KF_H
903
Note: See TracBrowser for help on using the browser.