root/library/bdm/stat/emix.h @ 979

Revision 979, 11.2 kB (checked in by smidl, 14 years ago)

estimator returns the array of posterior estimators as second argument

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public pdf {
38protected:
39        //! Nominator in the form of pdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49public:
50        //!Default constructor. By default, the given epdf is not copied!
51        //! It is assumed that this function will be used only temporarily.
52        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
53                // adjust rv and rvc
54
55                set_rv ( rv ); 
56                dim = rv._dsize();
57
58                rvc = nom0->_rv().subt ( rv );
59                dimc = rvc._dsize();
60
61                //prepare data structures
62                if ( copy ) {
63                        bdm_error ( "todo" );
64                        // destroynom = true;
65                } else {
66                        nom = nom0;
67                        destroynom = false;
68                }
69                bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
70
71                // build denominator
72                den = nom->marginal ( rvc );
73                dl.set_connection ( rv, rvc, nom0->_rv() );
74        };
75
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83
84        //! Returns a sample from the density conditioned on \c cond, \f$x \sim epdf(rv|cond)\f$. \param cond is numeric value of \c rv
85        virtual vec samplecond ( const vec &cond ) NOT_IMPLEMENTED(0);
86
87        //! Object takes ownership of nom and will destroy it
88        void ownnom() {
89                destroynom = true;
90        }
91        //! Default destructor
92        ~mratio() {
93                if ( destroynom ) {
94                        delete nom;
95                }
96        }
97
98
99private:
100        // not implemented
101        mratio ( const mratio & );
102        mratio &operator= ( const mratio & );
103};
104
105class emix; //forward
106
107/*! Base class (Interface) for mixtures
108*/
109class emix_base : public epdf {
110        protected:
111        //! reference to vector of weights
112        vec &w;
113        //! function returning ith component
114        virtual const epdf * component(const int &i) const=0;
115       
116        virtual int no_coms() const=0;
117       
118        public:
119               
120                emix_base(vec &w0): w(w0){}
121               
122                void validate ();
123               
124                vec sample() const;
125               
126                vec mean() const;
127               
128                vec variance() const;
129               
130                double evallog ( const vec &val ) const;
131               
132                vec evallog_mat ( const mat &Val ) const;
133               
134                //! Auxiliary function that returns pdflog for each component
135                mat evallog_coms ( const mat &Val ) const;
136               
137                shared_ptr<epdf> marginal ( const RV &rv ) const;
138                //! Update already existing marginal density  \c target
139                void marginal ( const RV &rv, emix &target ) const;
140                shared_ptr<pdf> condition ( const RV &rv ) const;
141               
142                //Access methods       
143                //! returns a reference to the internal weights. Use with Care!
144                vec& _w() {
145                        return w;
146                }
147               
148                const vec& _w() const {
149                        return w;
150                }
151                //!access
152                const epdf* _com(int i) const {return component(i);}
153               
154};
155
156/*!
157* \brief Mixture of epdfs
158
159Density function:
160\f[
161f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
162\f]
163where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
164
165*/
166class emix : public emix_base {
167protected:
168        //! weights of the components
169        vec weights;
170
171        //! Component (epdfs)
172        Array<shared_ptr<epdf> > Coms;
173
174public:
175        //! Default constructor
176        emix ( ) : emix_base ( weights) { }
177       
178        const epdf* component(const int &i) const {return Coms(i).get();}
179        void validate();
180       
181
182        int no_coms() const {return Coms.length(); }
183
184        //! Load from structure with elements:
185        //!  \code
186        //! { class='emix';
187        //!   pdfs = (..., ...);     // list of pdfs in the mixture
188        //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
189        //! }
190        //! \endcode
191        //!@}
192        void from_setting ( const Setting &set );
193        void to_setting  (Setting  &set) const;
194
195        void set_rv ( const RV &rv ) {
196                epdf::set_rv ( rv );
197                for ( int i = 0; i < no_coms(); i++ ) {
198                        Coms( i )->set_rv ( rv );
199                }
200        }
201       
202        Array<shared_ptr<epdf> >& _Coms ( ) {
203                        return Coms;
204                }
205};
206SHAREDPTR ( emix );
207UIREGISTER ( emix );
208
209
210/*! \brief Chain rule decomposition of epdf
211
212Probability density in the form of Chain-rule decomposition:
213\[
214f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
215\]
216Note that
217*/
218class mprod: public pdf {
219private:
220        Array<shared_ptr<pdf> > pdfs;
221
222        //! Data link for each pdfs
223        Array<shared_ptr<datalink_m2m> > dls;
224
225public:
226        //! \brief Default constructor
227        mprod() { }
228
229        /*!\brief Constructor from list of mFacs
230        */
231        mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
232                set_elements ( mFacs );
233        }
234        //! Set internal \c pdfs from given values
235        void set_elements ( const Array<shared_ptr<pdf> > &mFacs );
236
237        double evallogcond ( const vec &val, const vec &cond );
238
239        vec evallogcond_mat ( const mat &Dt, const vec &cond );
240
241        vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond );
242
243        //TODO smarter...
244        vec samplecond ( const vec &cond ) {
245                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
246                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
247                vec smpi;
248                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
249                for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
250                        // generate contribution of this pdf
251                        smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) );
252                        // copy contribution of this pdf into smp
253                        dls ( i )->pushup ( smp, smpi );
254                }
255                return smp;
256        }
257
258        //! Load from structure with elements:
259        //!  \code
260        //! { class='mprod';
261        //!   pdfs = (..., ...);     // list of pdfs in the order of chain rule
262        //! }
263        //! \endcode
264        //!@}
265        void from_setting ( const Setting &set ) ;
266        void    to_setting  (Setting  &set) const;
267
268
269};
270UIREGISTER ( mprod );
271SHAREDPTR ( mprod );
272
273
274//! Product of independent epdfs. For dependent pdfs, use mprod.
275class eprod_base: public epdf {
276protected:
277        //! Array of indices
278        Array<datalink*> dls;
279        //! interface for a factor
280public:
281        virtual const epdf* factor(int i) const NOT_IMPLEMENTED(NULL); 
282        //!number of factors
283        virtual const int no_factors() const NOT_IMPLEMENTED(0);
284        //! Default constructor
285        eprod_base () :  dls (0) {};
286        //! Set internal
287        vec mean() const;
288
289        vec variance() const;
290
291        vec sample() const;
292
293        double evallog ( const vec &val ) const;
294
295        //!Destructor
296        ~eprod_base() {
297                for ( int i = 0; i < dls.length(); i++ ) {
298                        delete dls ( i );
299                }
300        }
301        void validate() {
302                epdf::validate();
303                dls.set_length ( no_factors() );
304               
305                bool independent = true;
306                dim = 0;
307                for ( int i = 0; i < no_factors(); i++ ) {
308                        independent = rv.add ( factor ( i )->_rv() );
309                        dim += factor ( i )->dimension();
310                        bdm_assert_debug ( independent, "eprod:: given components are not independent." );
311                };
312               
313                //
314                int cumdim = 0;
315                int dimi = 0;
316                int i;
317                for ( i = 0; i < no_factors(); i++ ) {
318                        dls ( i ) = new datalink;
319                        if ( isnamed() ) { // rvs are complete
320                                dls ( i )->set_connection ( factor ( i )->_rv() , rv );
321                        } else { //rvs are not reliable
322                                dimi = factor ( i )->dimension();
323                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
324                                cumdim += dimi;
325                        }
326                }
327       
328        }
329};
330
331class eprod: public eprod_base{
332        protected:
333                Array<shared_ptr<epdf> > factors;
334        public:
335                const epdf* factor(int i) const {return factors(i).get();}
336                const int no_factors() const {return factors.length();}
337        void set_parameters ( const Array<shared_ptr<epdf> > &epdfs0) {
338                factors = epdfs0;
339        }
340        void from_setting(const Setting &set){
341                UI::get(factors,set,"pdfs",UI::compulsory);
342        }
343};
344UIREGISTER(eprod);
345
346//! similar to eprod but used only internally -- factors are external pointers
347class eprod_internal: public eprod_base{
348        protected:
349                Array<epdf* > factors;
350                const epdf* factor(int i) const {return factors(i);}
351                const int no_factors() const {return factors.length();}
352        public:
353                void set_parameters ( const Array<epdf *> &epdfs0) {
354                        factors = epdfs0;
355                }
356};
357
358/*! \brief Mixture of pdfs with constant weights, all pdfs are of equal RV and RVC
359
360*/
361class mmix : public pdf {
362protected:
363        //! Component (pdfs)
364        Array<shared_ptr<pdf> > Coms;
365        //!weights of the components
366        vec w;
367public:
368        //!Default constructor
369        mmix() : Coms ( 0 ) { };
370
371        double evallogcond ( const vec &dt, const vec &cond ) {
372                double ll = 0.0;
373                for ( int i = 0; i < Coms.length(); i++ ) {
374                        ll += Coms ( i )->evallogcond ( dt, cond );
375                }
376                return ll;
377        }
378
379        vec samplecond ( const vec &cond );
380
381        //! Load from structure with elements:
382        //!  \code
383        //! { class='mmix';
384        //!   pdfs = (..., ...);     // list of pdfs in the mixture
385        //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
386        //! }
387        //! \endcode
388        //!@}
389        void from_setting ( const Setting &set );
390        void    to_setting  (Setting  &set) const;
391        virtual void validate();
392};
393SHAREDPTR ( mmix );
394UIREGISTER ( mmix );
395
396
397//! Base class for all BM running as parallel update of internal BMs
398
399class ProdBMBase : public BM {
400        protected :
401                Array<vec_from_vec> bm_yt;
402                Array<vec_from_2vec> bm_cond;
403                class eprod_bm : public eprod_base {
404                        ProdBMBase & pb;
405                        public :
406                        eprod_bm(ProdBMBase &pb0): pb(pb0){}
407                        const epdf* factor(int i ) const {return &(pb.bm(i)->posterior());}
408                        const int no_factors() const {return pb.no_bms();}
409                } est;
410        public:
411                ProdBMBase():est(*this){}
412                virtual BM* bm(int i) NOT_IMPLEMENTED(NULL);
413                virtual int no_bms() const {return 0;}
414                const epdf& posterior() const {return est;}
415                void set_prior(const epdf *pri){
416                        const eprod_base* ep=dynamic_cast<const eprod_base*>(pri);
417                        if (ep){
418                                bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+
419                                        num2str(no_bms()) + "BMs");
420                                for (int i=0; i<no_bms(); i++){
421                                        bm(i)->set_prior(ep->factor(i));
422                                }
423                        }
424                }
425               
426                void validate() {
427                        est.validate();
428                        BM::validate();
429                        // set links
430                        bm_yt.set_length(no_bms());
431                        bm_cond.set_length(no_bms());
432                       
433                        //
434                       
435                        for (int i=0; i<no_bms(); i++){
436                                yrv.add(bm(i)->_yrv());
437                                rvc.add(bm(i)->_rvc());
438                        }
439                        rvc=rvc.subt(yrv);
440                       
441                        dimy = yrv._dsize();
442                        dimc = rvc._dsize();
443                       
444                        for (int i=0; i<no_bms(); i++){
445                                bm_yt(i).connect(bm(i)->_yrv(), yrv);
446                                bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc);
447                        }
448                }
449                void bayes(const vec &dt, const vec &cond){
450                        ll=0;
451                        for(int i=0;i<no_bms(); i++){
452                                bm_yt(i).update(dt);
453                                bm_cond(i).update(dt,cond);
454                                bm(i)->bayes(bm_yt(i), bm_cond(i));
455                        }
456                }
457               
458};
459
460class ProdBM: public ProdBMBase{
461        protected:
462                Array<shared_ptr<BM> > BMs;
463        public:
464                virtual BM* bm(int i) {return BMs(i).get();}
465                virtual int no_bms() const {return BMs.length();}
466                void from_setting(const Setting &set){
467                        BM::from_setting(set);
468                        UI::get(BMs,set,"BMs");
469                }
470                void to_setting(Setting &set) const{
471                        BM::to_setting(set);
472                        UI::save(BMs,set,"BMs");
473                }
474               
475};
476UIREGISTER(ProdBM);
477
478}
479#endif //MX_H
Note: See TracBrowser for help on using the browser.