root/library/bdm/estim/particles.h @ 653

Revision 653, 10.5 kB (checked in by smidl, 15 years ago)

corrections in Kalman and particles

  • Property svn:eol-style set to native
RevLine 
[8]1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[384]13#ifndef PARTICLES_H
14#define PARTICLES_H
[8]15
[262]16
[384]17#include "../stat/exp_family.h"
[8]18
[270]19namespace bdm {
[8]20
21/*!
[32]22* \brief Trivial particle filter with proposal density equal to parameter evolution model.
[8]23
[32]24Posterior density is represented by a weighted empirical density (\c eEmp ).
[8]25*/
[32]26
27class PF : public BM {
[8]28protected:
[32]29        //!number of particles;
30        int n;
31        //!posterior density
32        eEmp est;
33        //! pointer into \c eEmp
34        vec &_w;
35        //! pointer into \c eEmp
36        Array<vec> &_samples;
37        //! Parameter evolution model
[638]38        shared_ptr<mpdf> par;
[32]39        //! Observation model
[638]40        shared_ptr<mpdf> obs;
41        //! internal structure storing loglikelihood of predictions
42        vec lls;
43       
[283]44        //! which resampling method will be used
45        RESAMPLING_METHOD resmethod;
[638]46        //! resampling threshold; in this case its meaning is minimum ratio of active particles
47        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
48        double res_threshold;
49       
[281]50        //! \name Options
51        //!@{
52
53        //! Log all samples
54        bool opt_L_smp;
[283]55        //! Log all samples
56        bool opt_L_wei;
[281]57        //!@}
58
[8]59public:
[270]60        //! \name Constructors
61        //!@{
[477]62        PF ( ) : est(), _w ( est._w() ), _samples ( est._samples() ), opt_L_smp ( false ), opt_L_wei ( false ) {
63                LIDs.set_size ( 5 );
64        };
[638]65       
66        void set_parameters (int n0, double res_th0=0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
[477]67                n = n0;
[638]68                res_threshold = res_th0;
[477]69                resmethod = rm;
70        };
[638]71        void set_model ( shared_ptr<mpdf> par0, shared_ptr<mpdf> obs0) {
72                par = par0;
73                obs = obs0;
74                // set values for posterior
75                est.set_rv(par->_rv());
76        };
[488]77        void set_statistics ( const vec w0, const epdf &epdf0 ) {
[477]78                est.set_statistics ( w0, epdf0 );
79        };
[638]80        void set_statistics ( const eEmp &epdf0 ) {
81                bdm_assert_debug(epdf0._rv().equal(par->_rv()),"Incompatibel input");
82                est=epdf0;
83        };
[270]84        //!@}
[33]85        //! Set posterior density by sampling from epdf0
[638]86        //! Extends original BM::set_options by two more options:
87        //! \li logweights - meaning that all weightes will be logged
88        //! \li logsamples - all samples will be also logged
[283]89        void set_options ( const string &opt ) {
[477]90                BM::set_options ( opt );
91                opt_L_wei = ( opt.find ( "logweights" ) != string::npos );
92                opt_L_smp = ( opt.find ( "logsamples" ) != string::npos );
[281]93        }
[638]94        //! bayes I - generate samples and add their weights to lls
95        virtual void bayes_gensmp();
96        //! bayes II - compute weights of the
97        virtual void bayes_weights();
98        //! important part of particle filtering - decide if it is time to perform resampling
99        virtual bool do_resampling(){   
100                double eff = 1.0 / ( _w * _w );
101                return eff < ( res_threshold*n );
102        }
[32]103        void bayes ( const vec &dt );
[225]104        //!access function
[638]105        vec& __w() { return _w; }
106        //!access function
107        vec& _lls() { return lls; }
108        RESAMPLING_METHOD _resmethod() const { return resmethod; }
109        //!access function
110        const eEmp& posterior() const {return est;}
111       
112        /*! configuration structure for basic PF
113        \code
114        parameter_pdf   = mpdf_class;         // parameter evolution pdf
115        observation_pdf = mpdf_class;         // observation pdf
116        prior           = epdf_class;         // prior probability density
117        --- optional ---
118        n               = 10;                 // number of particles
119        resmethod       = 'systematic', or 'multinomial', or 'stratified'
120                                                                                  // resampling method
121        res_threshold   = 0.5;                // resample when active particles drop below 50%
122        \endcode
123        */
124        void from_setting(const Setting &set){
125                par = UI::build<mpdf>(set,"parameter_pdf",UI::compulsory);
126                obs = UI::build<mpdf>(set,"observation_pdf",UI::compulsory);
127               
128                prior_from_set(set);
129                resmethod_from_set(set);
130                // set resampling method
131                //set drv
132                //find potential input - what remains in rvc when we subtract rv
133                RV u = par->_rvc().remove_time().subt( par->_rv() ); 
134                //find potential input - what remains in rvc when we subtract x_t
135                RV obs_u = obs->_rvc().remove_time().subt( par->_rv() ); 
136               
137                u.add(obs_u); // join both u, and check if they do not overlap
138               
139                set_drv(concat(obs->_rv(),u) );
[477]140        }
[638]141        //! auxiliary function reading parameter 'resmethod' from configuration file
142        void resmethod_from_set(const Setting &set){
143                string resmeth;
144                if (UI::get(resmeth,set,"resmethod",UI::optional)){
145                        if (resmeth=="systematic") {
146                                resmethod= SYSTEMATIC;
147                        } else  {
148                                if (resmeth=="multinomial"){
149                                        resmethod=MULTINOMIAL;
150                                } else {
151                                        if (resmeth=="stratified"){
152                                                resmethod= STRATIFIED;
153                                        } else {
154                                                bdm_error("Unknown resampling method");
155                                        }
156                                }
157                        }
158                } else {
159                        resmethod=SYSTEMATIC;
160                };
161                if(!UI::get(res_threshold, set, "res_threshold", UI::optional)){
162                        res_threshold=0.5;
163                }
164        }
165        //! load prior information from set and set internal structures accordingly
166        void prior_from_set(const Setting & set){
167                shared_ptr<epdf> pri = UI::build<epdf>(set,"prior",UI::compulsory);
168               
169                eEmp *test_emp=dynamic_cast<eEmp*>(&(*pri));
170                if (test_emp) { // given pdf is sampled
171                        est=*test_emp;
172                } else {
173                        int n;
174                        if (!UI::get(n,set,"n",UI::optional)){n=10;}
175                        // sample from prior
176                        set_statistics(ones(n)/n, *pri);
177                }
178                //validate();
179        }
180       
181        void validate(){
182                n=_w.length();
183                lls=zeros(n);
184                if (par->_rv()._dsize()>0) {
185                        bdm_assert(par->_rv()._dsize()==est.dimension(),"Mismatch of RV and dimension of posterior" );
186                }
187        }
188        //! resample posterior density (from outside - see MPF)
189        void resample(ivec &ind){
190                est.resample(ind,resmethod);
191        }
[653]192        Array<vec>& __samples(){return _samples;}
[8]193};
[638]194UIREGISTER(PF);
[8]195
196/*!
[32]197\brief Marginalized Particle filter
[8]198
[638]199A composition of particle filter with exact (or almost exact) bayesian models (BMs).
200The Bayesian models provide marginalized predictive density. Internaly this is achieved by virtual class MPFmpdf.
[8]201*/
202
[638]203class MPF : public BM  {
[653]204        protected:
205        shared_ptr<PF> pf;
[638]206        Array<BM*> BMs;
[32]207
208        //! internal class for MPDF providing composition of eEmp with external components
209
[477]210        class mpfepdf : public epdf  {
[638]211                shared_ptr<PF> &pf;
212                Array<BM*> &BMs;
[8]213        public:
[638]214                mpfepdf (shared_ptr<PF> &pf0, Array<BM*> &BMs0): epdf(), pf(pf0), BMs(BMs0) { };
215                //! a variant of set parameters - this time, parameters are read from BMs and pf
216                void read_parameters(){
217                        rv = concat(pf->posterior()._rv(), BMs(0)->posterior()._rv());
218                        dim = pf->posterior().dimension() + BMs(0)->posterior().dimension();
219                        bdm_assert_debug(dim == rv._dsize(), "Wrong name ");
[283]220                }
[32]221                vec mean() const {
[638]222                        const vec &w = pf->posterior()._w();
223                        vec pom = zeros ( BMs(0)->posterior ().dimension() );
224                        //compute mean of BMs
225                        for ( int i = 0; i < w.length(); i++ ) {
226                                pom += BMs ( i )->posterior().mean() * w ( i );
[477]227                        }
[638]228                        return concat ( pf->posterior().mean(), pom );
[32]229                }
[229]230                vec variance() const {
[638]231                        const vec &w = pf->posterior()._w();
232                       
233                        vec pom = zeros ( BMs(0)->posterior ().dimension() );
234                        vec pom2 = zeros ( BMs(0)->posterior ().dimension() );
235                        vec mea;
236                       
237                        for ( int i = 0; i < w.length(); i++ ) {
238                                // save current mean
239                                mea = BMs ( i )->posterior().mean();
240                                pom += mea * w ( i );
241                                //compute variance
242                                pom2 += ( BMs ( i )->posterior().variance() + pow ( mea, 2 ) ) * w ( i );
[270]243                        }
[638]244                        return concat ( pf->posterior().variance(), pom2 - pow ( pom, 2 ) );
[229]245                }
[638]246               
[477]247                void qbounds ( vec &lb, vec &ub, double perc = 0.95 ) const {
[283]248                        //bounds on particles
249                        vec lbp;
250                        vec ubp;
[638]251                        pf->posterior().qbounds ( lbp, ubp );
[32]252
[283]253                        //bounds on Components
[638]254                        int dimC = BMs ( 0 )->posterior().dimension();
[283]255                        int j;
256                        // temporary
[477]257                        vec lbc ( dimC );
258                        vec ubc ( dimC );
[283]259                        // minima and maxima
[477]260                        vec Lbc ( dimC );
261                        vec Ubc ( dimC );
[283]262                        Lbc = std::numeric_limits<double>::infinity();
263                        Ubc = -std::numeric_limits<double>::infinity();
264
[638]265                        for ( int i = 0; i < BMs.length(); i++ ) {
[283]266                                // check Coms
[638]267                                BMs ( i )->posterior().qbounds ( lbc, ubc );
268                                //save either minima or maxima
[477]269                                for ( j = 0; j < dimC; j++ ) {
270                                        if ( lbc ( j ) < Lbc ( j ) ) {
271                                                Lbc ( j ) = lbc ( j );
272                                        }
273                                        if ( ubc ( j ) > Ubc ( j ) ) {
274                                                Ubc ( j ) = ubc ( j );
275                                        }
[283]276                                }
277                        }
[477]278                        lb = concat ( lbp, Lbc );
279                        ub = concat ( ubp, Ubc );
[283]280                }
281
[477]282                vec sample() const {
[565]283                        bdm_error ( "Not implemented" );
284                        return vec();
[477]285                }
[32]286
[477]287                double evallog ( const vec &val ) const {
[565]288                        bdm_error ( "not implemented" );
[477]289                        return 0.0;
290                }
[32]291        };
292
[281]293        //! Density joining PF.est with conditional parts
[32]294        mpfepdf jest;
295
[281]296        //! Log means of BMs
297        bool opt_L_mea;
[283]298
[32]299public:
300        //! Default constructor.
[638]301        MPF () :  jest (pf,BMs) {};
302        void set_parameters ( shared_ptr<mpdf> par0, shared_ptr<mpdf> obs0, int n0, RESAMPLING_METHOD rm = SYSTEMATIC ) {
303                pf->set_model ( par0, obs0); 
304                pf->set_parameters(n0, rm );
[283]305                BMs.set_length ( n0 );
306        }
[638]307        void set_BM ( const BM &BMcond0 ) {
[32]308
[638]309                int n=pf->__w().length();
310                BMs.set_length(n);
[283]311                // copy
[638]312                //BMcond0 .condition ( pf->posterior()._sample ( 0 ) );
[477]313                for ( int i = 0; i < n; i++ ) {
[638]314                        BMs ( i ) = BMcond0._copy_();
315                        BMs ( i )->condition ( pf->posterior()._sample ( i ) );
[477]316                }
[32]317        };
318
319        void bayes ( const vec &dt );
[477]320        const epdf& posterior() const {
321                return jest;
322        }
[638]323        //! Extends options understood by BM::set_options by option
324        //! \li logmeans - meaning
[283]325        void set_options ( const string &opt ) {
[638]326                BM::set_options(opt);
[477]327                opt_L_mea = ( opt.find ( "logmeans" ) != string::npos );
[32]328        }
[283]329
[225]330        //!Access function
[536]331        const BM* _BM ( int i ) {
[477]332                return BMs ( i );
333        }
[638]334       
335        /*! configuration structure for basic PF
336        \code
337        BM              = BM_class;           // Bayesian filtr for analytical part of the model
338        parameter_pdf   = mpdf_class;         // transitional pdf for non-parametric part of the model
339        prior           = epdf_class;         // prior probability density
340        --- optional ---
341        n               = 10;                 // number of particles
342        resmethod       = 'systematic', or 'multinomial', or 'stratified'
343                                                                                  // resampling method
344        \endcode
345        */     
346        void from_setting(const Setting &set){
347                shared_ptr<mpdf> par = UI::build<mpdf>(set,"parameter_pdf",UI::compulsory);
348                shared_ptr<mpdf> obs= new mpdf(); // not used!!
[8]349
[638]350                pf = new PF;
351                // rpior must be set before BM
352                pf->prior_from_set(set);
353                pf->resmethod_from_set(set);
354                pf->set_model(par,obs);
355               
356                shared_ptr<BM> BM0 =UI::build<BM>(set,"BM",UI::compulsory);
357                set_BM(*BM0);
358               
359                string opt;
360                if (UI::get(opt,set,"options",UI::optional)){
361                        set_options(opt);
362                }
363                //set drv
364                //find potential input - what remains in rvc when we subtract rv
365                RV u = par->_rvc().remove_time().subt( par->_rv() );           
366                set_drv(concat(BM0->_drv(),u) );
367                validate();
[32]368        }
[638]369        void validate(){
370                try{
371                pf->validate();
372                } catch (std::exception &e){
373                        throw UIException("Error in PF part of MPF:");
374                }
375                jest.read_parameters();
[32]376        }
[638]377       
378};
379UIREGISTER(MPF);
[32]380
381}
[8]382#endif // KF_H
383
Note: See TracBrowser for help on using the browser.