root/library/bdm/estim/particles.h @ 638

Revision 638, 10.5 kB (checked in by smidl, 15 years ago)

redesign of PF and MPF to be more flexible and share more code

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef PARTICLES_H
14#define PARTICLES_H
15
16
17#include "../stat/exp_family.h"
18
19namespace bdm {
20
21/*!
22* \brief Trivial particle filter with proposal density equal to parameter evolution model.
23
24Posterior density is represented by a weighted empirical density (\c eEmp ).
25*/
26
27class PF : public BM {
28protected:
29        //!number of particles;
30        int n;
31        //!posterior density
32        eEmp est;
33        //! pointer into \c eEmp
34        vec &_w;
35        //! pointer into \c eEmp
36        Array<vec> &_samples;
37        //! Parameter evolution model
38        shared_ptr<mpdf> par;
39        //! Observation model
40        shared_ptr<mpdf> obs;
41        //! internal structure storing loglikelihood of predictions
42        vec lls;
43       
44        //! which resampling method will be used
45        RESAMPLING_METHOD resmethod;
46        //! resampling threshold; in this case its meaning is minimum ratio of active particles
47        //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
48        double res_threshold;
49       
50        //! \name Options
51        //!@{
52
53        //! Log all samples
54        bool opt_L_smp;
55        //! Log all samples
56        bool opt_L_wei;
57        //!@}
58
59public:
60        //! \name Constructors
61        //!@{
62        PF ( ) : est(), _w ( est._w() ), _samples ( est._samples() ), opt_L_smp ( false ), opt_L_wei ( false ) {
63                LIDs.set_size ( 5 );
64        };
65       
66        void set_parameters (int n0, double res_th0=0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
67                n = n0;
68                res_threshold = res_th0;
69                resmethod = rm;
70        };
71        void set_model ( shared_ptr<mpdf> par0, shared_ptr<mpdf> obs0) {
72                par = par0;
73                obs = obs0;
74                // set values for posterior
75                est.set_rv(par->_rv());
76        };
77        void set_statistics ( const vec w0, const epdf &epdf0 ) {
78                est.set_statistics ( w0, epdf0 );
79        };
80        void set_statistics ( const eEmp &epdf0 ) {
81                bdm_assert_debug(epdf0._rv().equal(par->_rv()),"Incompatibel input");
82                est=epdf0;
83        };
84        //!@}
85        //! Set posterior density by sampling from epdf0
86        //! Extends original BM::set_options by two more options:
87        //! \li logweights - meaning that all weightes will be logged
88        //! \li logsamples - all samples will be also logged
89        void set_options ( const string &opt ) {
90                BM::set_options ( opt );
91                opt_L_wei = ( opt.find ( "logweights" ) != string::npos );
92                opt_L_smp = ( opt.find ( "logsamples" ) != string::npos );
93        }
94        //! bayes I - generate samples and add their weights to lls
95        virtual void bayes_gensmp();
96        //! bayes II - compute weights of the
97        virtual void bayes_weights();
98        //! important part of particle filtering - decide if it is time to perform resampling
99        virtual bool do_resampling(){   
100                double eff = 1.0 / ( _w * _w );
101                return eff < ( res_threshold*n );
102        }
103        void bayes ( const vec &dt );
104        //!access function
105        vec& __w() { return _w; }
106        //!access function
107        vec& _lls() { return lls; }
108        RESAMPLING_METHOD _resmethod() const { return resmethod; }
109        //!access function
110        const eEmp& posterior() const {return est;}
111       
112        /*! configuration structure for basic PF
113        \code
114        parameter_pdf   = mpdf_class;         // parameter evolution pdf
115        observation_pdf = mpdf_class;         // observation pdf
116        prior           = epdf_class;         // prior probability density
117        --- optional ---
118        n               = 10;                 // number of particles
119        resmethod       = 'systematic', or 'multinomial', or 'stratified'
120                                                                                  // resampling method
121        res_threshold   = 0.5;                // resample when active particles drop below 50%
122        \endcode
123        */
124        void from_setting(const Setting &set){
125                par = UI::build<mpdf>(set,"parameter_pdf",UI::compulsory);
126                obs = UI::build<mpdf>(set,"observation_pdf",UI::compulsory);
127               
128                prior_from_set(set);
129                resmethod_from_set(set);
130                // set resampling method
131                //set drv
132                //find potential input - what remains in rvc when we subtract rv
133                RV u = par->_rvc().remove_time().subt( par->_rv() ); 
134                //find potential input - what remains in rvc when we subtract x_t
135                RV obs_u = obs->_rvc().remove_time().subt( par->_rv() ); 
136               
137                u.add(obs_u); // join both u, and check if they do not overlap
138               
139                set_drv(concat(obs->_rv(),u) );
140        }
141        //! auxiliary function reading parameter 'resmethod' from configuration file
142        void resmethod_from_set(const Setting &set){
143                string resmeth;
144                if (UI::get(resmeth,set,"resmethod",UI::optional)){
145                        if (resmeth=="systematic") {
146                                resmethod= SYSTEMATIC;
147                        } else  {
148                                if (resmeth=="multinomial"){
149                                        resmethod=MULTINOMIAL;
150                                } else {
151                                        if (resmeth=="stratified"){
152                                                resmethod= STRATIFIED;
153                                        } else {
154                                                bdm_error("Unknown resampling method");
155                                        }
156                                }
157                        }
158                } else {
159                        resmethod=SYSTEMATIC;
160                };
161                if(!UI::get(res_threshold, set, "res_threshold", UI::optional)){
162                        res_threshold=0.5;
163                }
164        }
165        //! load prior information from set and set internal structures accordingly
166        void prior_from_set(const Setting & set){
167                shared_ptr<epdf> pri = UI::build<epdf>(set,"prior",UI::compulsory);
168               
169                eEmp *test_emp=dynamic_cast<eEmp*>(&(*pri));
170                if (test_emp) { // given pdf is sampled
171                        est=*test_emp;
172                } else {
173                        int n;
174                        if (!UI::get(n,set,"n",UI::optional)){n=10;}
175                        // sample from prior
176                        set_statistics(ones(n)/n, *pri);
177                }
178                //validate();
179        }
180       
181        void validate(){
182                n=_w.length();
183                lls=zeros(n);
184                if (par->_rv()._dsize()>0) {
185                        bdm_assert(par->_rv()._dsize()==est.dimension(),"Mismatch of RV and dimension of posterior" );
186                }
187        }
188        //! resample posterior density (from outside - see MPF)
189        void resample(ivec &ind){
190                est.resample(ind,resmethod);
191        }
192};
193UIREGISTER(PF);
194
195/*!
196\brief Marginalized Particle filter
197
198A composition of particle filter with exact (or almost exact) bayesian models (BMs).
199The Bayesian models provide marginalized predictive density. Internaly this is achieved by virtual class MPFmpdf.
200*/
201
202class MPF : public BM  {
203    shared_ptr<PF> pf;
204        Array<BM*> BMs;
205
206        //! internal class for MPDF providing composition of eEmp with external components
207
208        class mpfepdf : public epdf  {
209                shared_ptr<PF> &pf;
210                Array<BM*> &BMs;
211        public:
212                mpfepdf (shared_ptr<PF> &pf0, Array<BM*> &BMs0): epdf(), pf(pf0), BMs(BMs0) { };
213                //! a variant of set parameters - this time, parameters are read from BMs and pf
214                void read_parameters(){
215                        rv = concat(pf->posterior()._rv(), BMs(0)->posterior()._rv());
216                        dim = pf->posterior().dimension() + BMs(0)->posterior().dimension();
217                        bdm_assert_debug(dim == rv._dsize(), "Wrong name ");
218                }
219                vec mean() const {
220                        const vec &w = pf->posterior()._w();
221                        vec pom = zeros ( BMs(0)->posterior ().dimension() );
222                        //compute mean of BMs
223                        for ( int i = 0; i < w.length(); i++ ) {
224                                pom += BMs ( i )->posterior().mean() * w ( i );
225                        }
226                        return concat ( pf->posterior().mean(), pom );
227                }
228                vec variance() const {
229                        const vec &w = pf->posterior()._w();
230                       
231                        vec pom = zeros ( BMs(0)->posterior ().dimension() );
232                        vec pom2 = zeros ( BMs(0)->posterior ().dimension() );
233                        vec mea;
234                       
235                        for ( int i = 0; i < w.length(); i++ ) {
236                                // save current mean
237                                mea = BMs ( i )->posterior().mean();
238                                pom += mea * w ( i );
239                                //compute variance
240                                pom2 += ( BMs ( i )->posterior().variance() + pow ( mea, 2 ) ) * w ( i );
241                        }
242                        return concat ( pf->posterior().variance(), pom2 - pow ( pom, 2 ) );
243                }
244               
245                void qbounds ( vec &lb, vec &ub, double perc = 0.95 ) const {
246                        //bounds on particles
247                        vec lbp;
248                        vec ubp;
249                        pf->posterior().qbounds ( lbp, ubp );
250
251                        //bounds on Components
252                        int dimC = BMs ( 0 )->posterior().dimension();
253                        int j;
254                        // temporary
255                        vec lbc ( dimC );
256                        vec ubc ( dimC );
257                        // minima and maxima
258                        vec Lbc ( dimC );
259                        vec Ubc ( dimC );
260                        Lbc = std::numeric_limits<double>::infinity();
261                        Ubc = -std::numeric_limits<double>::infinity();
262
263                        for ( int i = 0; i < BMs.length(); i++ ) {
264                                // check Coms
265                                BMs ( i )->posterior().qbounds ( lbc, ubc );
266                                //save either minima or maxima
267                                for ( j = 0; j < dimC; j++ ) {
268                                        if ( lbc ( j ) < Lbc ( j ) ) {
269                                                Lbc ( j ) = lbc ( j );
270                                        }
271                                        if ( ubc ( j ) > Ubc ( j ) ) {
272                                                Ubc ( j ) = ubc ( j );
273                                        }
274                                }
275                        }
276                        lb = concat ( lbp, Lbc );
277                        ub = concat ( ubp, Ubc );
278                }
279
280                vec sample() const {
281                        bdm_error ( "Not implemented" );
282                        return vec();
283                }
284
285                double evallog ( const vec &val ) const {
286                        bdm_error ( "not implemented" );
287                        return 0.0;
288                }
289        };
290
291        //! Density joining PF.est with conditional parts
292        mpfepdf jest;
293
294        //! Log means of BMs
295        bool opt_L_mea;
296
297public:
298        //! Default constructor.
299        MPF () :  jest (pf,BMs) {};
300        void set_parameters ( shared_ptr<mpdf> par0, shared_ptr<mpdf> obs0, int n0, RESAMPLING_METHOD rm = SYSTEMATIC ) {
301                pf->set_model ( par0, obs0); 
302                pf->set_parameters(n0, rm );
303                BMs.set_length ( n0 );
304        }
305        void set_BM ( const BM &BMcond0 ) {
306
307                int n=pf->__w().length();
308                BMs.set_length(n);
309                // copy
310                //BMcond0 .condition ( pf->posterior()._sample ( 0 ) );
311                for ( int i = 0; i < n; i++ ) {
312                        BMs ( i ) = BMcond0._copy_();
313                        BMs ( i )->condition ( pf->posterior()._sample ( i ) );
314                }
315        };
316
317        void bayes ( const vec &dt );
318        const epdf& posterior() const {
319                return jest;
320        }
321        //! Extends options understood by BM::set_options by option
322        //! \li logmeans - meaning
323        void set_options ( const string &opt ) {
324                BM::set_options(opt);
325                opt_L_mea = ( opt.find ( "logmeans" ) != string::npos );
326        }
327
328        //!Access function
329        const BM* _BM ( int i ) {
330                return BMs ( i );
331        }
332       
333        /*! configuration structure for basic PF
334        \code
335        BM              = BM_class;           // Bayesian filtr for analytical part of the model
336        parameter_pdf   = mpdf_class;         // transitional pdf for non-parametric part of the model
337        prior           = epdf_class;         // prior probability density
338        --- optional ---
339        n               = 10;                 // number of particles
340        resmethod       = 'systematic', or 'multinomial', or 'stratified'
341                                                                                  // resampling method
342        \endcode
343        */     
344        void from_setting(const Setting &set){
345                shared_ptr<mpdf> par = UI::build<mpdf>(set,"parameter_pdf",UI::compulsory);
346                shared_ptr<mpdf> obs= new mpdf(); // not used!!
347
348                pf = new PF;
349                // rpior must be set before BM
350                pf->prior_from_set(set);
351                pf->resmethod_from_set(set);
352                pf->set_model(par,obs);
353               
354                shared_ptr<BM> BM0 =UI::build<BM>(set,"BM",UI::compulsory);
355                set_BM(*BM0);
356               
357                string opt;
358                if (UI::get(opt,set,"options",UI::optional)){
359                        set_options(opt);
360                }
361                //set drv
362                //find potential input - what remains in rvc when we subtract rv
363                RV u = par->_rvc().remove_time().subt( par->_rv() );           
364                set_drv(concat(BM0->_drv(),u) );
365                validate();
366        }
367        void validate(){
368                try{
369                pf->validate();
370                } catch (std::exception &e){
371                        throw UIException("Error in PF part of MPF:");
372                }
373                jest.read_parameters();
374        }
375       
376};
377UIREGISTER(MPF);
378
379}
380#endif // KF_H
381
Note: See TracBrowser for help on using the browser.