root/library/bdm/stat/emix.h @ 461

Revision 461, 12.1 kB (checked in by vbarta, 15 years ago)

mpdf (& its dependencies) reformat: now using shared_ptr, moved virtual method bodies to .cpp

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995 
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41        //!Denominator in the form of epdf
42        epdf* den;
43        //!flag for destructor
44        bool destroynom;
45        //!datalink between conditional and nom
46        datalink_m2e dl;
47public:
48        //!Default constructor. By default, the given epdf is not copied!
49        //! It is assumed that this function will be used only temporarily.
50        mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
51                // adjust rv and rvc
52                rvc = nom0->_rv().subt ( rv );
53                dimc = rvc._dsize();
54                set_ep(shared_ptr<epdf>(new epdf));
55                e()->set_parameters(rv._dsize());
56                e()->set_rv(rv);
57               
58                //prepare data structures
59                if ( copy ) {it_error ( "todo" ); destroynom=true; }
60                else { nom = nom0; destroynom = false; }
61                it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
62               
63                // build denominator
64                den = nom->marginal ( rvc );
65                dl.set_connection(rv,rvc,nom0->_rv());
66        };
67        double evallogcond ( const vec &val, const vec &cond ) {
68                double tmp;
69                vec nom_val ( e()->dimension() + dimc );
70                dl.pushup_cond ( nom_val,val,cond );
71                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
72                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
73                return tmp;
74        }
75        //! Object takes ownership of nom and will destroy it
76        void ownnom() {destroynom=true;}
77        //! Default destructor
78        ~mratio() {delete den; if ( destroynom ) {delete nom;}}
79};
80
81/*!
82* \brief Mixture of epdfs
83
84Density function:
85\f[
86f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
87\f]
88where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
89
90*/
91class emix : public epdf {
92protected:
93        //! weights of the components
94        vec w;
95        //! Component (epdfs)
96        Array<epdf*> Coms;
97        //!Flag if owning Coms
98        bool destroyComs;
99public:
100        //!Default constructor
101        emix ( ) : epdf ( ) {};
102        //! Set weights \c w and components \c Coms
103        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
104        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
105
106        vec sample() const;
107        vec mean() const {
108                int i; vec mu = zeros ( dim );
109                for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
110                return mu;
111        }
112        vec variance() const {
113                //non-central moment
114                vec mom2 = zeros ( dim );
115                for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
116                //central moment
117                return mom2-pow ( mean(),2 );
118        }
119        double evallog ( const vec &val ) const {
120                int i;
121                double sum = 0.0;
122                for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
123                if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
124                double tmp=log ( sum );
125                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
126                return tmp;
127        };
128        vec evallog_m ( const mat &Val ) const {
129                vec x=zeros ( Val.cols() );
130                for ( int i = 0; i < w.length(); i++ ) {
131                        x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
132                }
133                return log ( x );
134        };
135        //! Auxiliary function that returns pdflog for each component
136        mat evallog_M ( const mat &Val ) const {
137                mat X ( w.length(), Val.cols() );
138                for ( int i = 0; i < w.length(); i++ ) {
139                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
140                }
141                return X;
142        };
143
144        emix* marginal ( const RV &rv ) const;
145        mratio* condition ( const RV &rv ) const; //why not mratio!!
146
147//Access methods
148        //! returns a pointer to the internal mean value. Use with Care!
149        vec& _w() {return w;}
150        virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
151        //! Auxiliary function for taking ownership of the Coms()
152        void ownComs() {destroyComs=true;}
153
154        //!access function
155        epdf* _Coms ( int i ) {return Coms ( i );}
156        void set_rv(const RV &rv){
157                epdf::set_rv(rv);
158                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
159        }
160};
161
162
163/*!
164* \brief Mixture of egiws
165
166*/
167class egiwmix : public egiw {
168protected:
169        //! weights of the components
170        vec w;
171        //! Component (epdfs)
172        Array<egiw*> Coms;
173        //!Flag if owning Coms
174        bool destroyComs;
175public:
176        //!Default constructor
177        egiwmix ( ) : egiw ( ) {};
178
179        //! Set weights \c w and components \c Coms
180        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
181        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
182
183        //!return expected value
184        vec mean() const;
185
186        //!return a sample from the density
187        vec sample() const;
188
189        //!return the expected variance
190        vec variance() const;
191
192        // TODO!!! Defined to follow ANSI and/or for future development
193        void mean_mat ( mat &M, mat&R ) const {};
194        double evallog_nn ( const vec &val ) const {return 0;};
195        double lognc () const {return 0;};
196        emix* marginal ( const RV &rv ) const;
197
198//Access methods
199        //! returns a pointer to the internal mean value. Use with Care!
200        vec& _w() {return w;}
201        virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
202        //! Auxiliary function for taking ownership of the Coms()
203        void ownComs() {destroyComs=true;}
204
205        //!access function
206        egiw* _Coms ( int i ) {return Coms ( i );}
207
208        void set_rv(const RV &rv){
209                egiw::set_rv(rv);
210                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
211        }
212
213        //! Approximation of a GiW mix by a single GiW pdf
214        egiw* approx();
215};
216
217/*! \brief Chain rule decomposition of epdf
218
219Probability density in the form of Chain-rule decomposition:
220\[
221f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
222\]
223Note that
224*/
225class mprod: public compositepdf, public mpdf {
226protected:
227        //! pointers to epdfs - shortcut to mpdfs().posterior()
228        Array<epdf*> epdfs;
229        //! Data link for each mpdfs
230        Array<datalink_m2m*> dls;
231
232        //! dummy ep
233        shared_ptr<epdf> dummy;
234
235public:
236        /*!\brief Constructor from list of mFacs,
237        */
238        mprod():dummy(new epdf()) { }
239        mprod (Array<mpdf*> mFacs):
240            dummy(new epdf()) {
241            set_elements(mFacs);
242        }
243
244        void set_elements(Array<mpdf*> mFacs , bool own=false) {
245               
246                compositepdf::set_elements(mFacs,own);
247                dls.set_size(mFacs.length());
248                epdfs.set_size(mFacs.length());
249                               
250                set_ep(dummy);
251                RV rv=getrv ( true );
252                set_rv ( rv );
253                dummy->set_parameters(rv._dsize());
254                setrvc(e()->_rv(), rvc);
255                // rv and rvc established = > we can link them with mpdfs
256                for ( int i = 0;i < mpdfs.length(); i++ ) {
257                        dls ( i ) = new datalink_m2m;
258                        dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
259                }
260
261                for ( int i=0; i<mpdfs.length(); i++ ) {
262                        epdfs(i) = mpdfs(i)->e();
263                }
264        };
265
266        double evallogcond ( const vec &val, const vec &cond ) {
267                int i;
268                double res = 0.0;
269                for ( i = mpdfs.length() - 1;i >= 0;i-- ) {
270                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
271                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
272                                                }
273                                                // add logarithms
274                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
275                        res += mpdfs ( i )->evallogcond (
276                                   dls ( i )->pushdown ( val ),
277                                   dls ( i )->get_cond ( val, cond )
278                               );
279                }
280                return res;
281        }
282        vec evallogcond_m(const mat &Dt, const vec &cond) {
283                vec tmp(Dt.cols());
284                for(int i=0;i<Dt.cols(); i++){
285                        tmp(i) = evallogcond(Dt.get_col(i),cond);
286                }
287                return tmp;
288        };
289        vec evallogcond_m(const Array<vec> &Dt, const vec &cond) {
290                vec tmp(Dt.length());
291                for(int i=0;i<Dt.length(); i++){
292                        tmp(i) = evallogcond(Dt(i),cond);
293                }
294                return tmp;             
295        };
296
297
298        //TODO smarter...
299        vec samplecond ( const vec &cond ) {
300            //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
301            vec smp= std::numeric_limits<double>::infinity() * ones ( e()->dimension() );
302                vec smpi;
303                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
304                for ( int i = ( mpdfs.length() - 1 );i >= 0;i-- ) {
305                        if ( mpdfs ( i )->dimensionc() ) {
306                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); // smp is val here!!
307                        }
308                        smpi = epdfs ( i )->sample();
309                        // copy contribution of this pdf into smp
310                        dls ( i )->pushup ( smp, smpi );
311                        // add ith likelihood contribution
312                }
313                return smp;
314        }
315        mat samplecond ( const vec &cond,  int N ) {
316                mat Smp ( dimension(),N );
317                for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
318                return Smp;
319        }
320
321        ~mprod() {};
322        //! Load from structure with elements:
323        //!  \code
324        //! { class='mprod';
325        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
326        //! }
327        //! \endcode
328        //!@}
329        void from_setting(const Setting &set){
330                Array<mpdf*> Atmp; //temporary Array
331                UI::get(Atmp,set, "mpdfs");
332                set_elements(Atmp,true);
333        }
334       
335};
336UIREGISTER(mprod);
337
338//! Product of independent epdfs. For dependent pdfs, use mprod.
339class eprod: public epdf {
340protected:
341        //! Components (epdfs)
342        Array<const epdf*> epdfs;
343        //! Array of indeces
344        Array<datalink*> dls;
345public:
346        eprod () : epdfs ( 0 ),dls ( 0 ) {};
347        void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
348                epdfs=epdfs0;//.set_length ( epdfs0.length() );
349                dls.set_length ( epdfs.length() );
350
351                bool independent=true;
352                if ( named ) {
353                        for ( int i=0;i<epdfs.length();i++ ) {
354                                independent=rv.add ( epdfs ( i )->_rv() );
355                                it_assert_debug ( independent==true, "eprod:: given components are not independent." );
356                        }
357                        dim=rv._dsize();
358                }
359                else {
360                        dim =0; for ( int i=0;i<epdfs.length();i++ ) {
361                                dim+=epdfs ( i )->dimension();
362                        }
363                }
364                //
365                int cumdim=0;
366                int dimi=0;
367                int i;
368                for ( i=0;i<epdfs.length();i++ ) {
369                        dls ( i ) = new datalink;
370                        if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
371                        else {
372                                dimi = epdfs ( i )->dimension();
373                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
374                                cumdim+=dimi;
375                        }
376                }
377        }
378
379        vec mean() const {
380                vec tmp ( dim );
381                for ( int i=0;i<epdfs.length();i++ ) {
382                        vec pom = epdfs ( i )->mean();
383                        dls ( i )->pushup ( tmp, pom );
384                }
385                return tmp;
386        }
387        vec variance() const {
388                vec tmp ( dim ); //second moment
389                for ( int i=0;i<epdfs.length();i++ ) {
390                        vec pom = epdfs ( i )->mean();
391                        dls ( i )->pushup ( tmp, pow ( pom,2 ) );
392                }
393                return tmp-pow ( mean(),2 );
394        }
395        vec sample() const {
396                vec tmp ( dim );
397                for ( int i=0;i<epdfs.length();i++ ) {
398                        vec pom = epdfs ( i )->sample();
399                        dls ( i )->pushup ( tmp, pom );
400                }
401                return tmp;
402        }
403        double evallog ( const vec &val ) const {
404                double tmp=0;
405                for ( int i=0;i<epdfs.length();i++ ) {
406                        tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
407                }
408                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
409                return tmp;
410        }
411        //!access function
412        const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
413
414        //!Destructor
415        ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
416};
417
418
419/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
420
421*/
422class mmix : public mpdf {
423protected:
424        //! Component (epdfs)
425        Array<mpdf*> Coms;
426
427        //!Internal epdf
428        shared_ptr<emix> iepdf;
429
430public:
431        //!Default constructor
432        mmix():iepdf(new emix()) {
433            set_ep(iepdf);
434        }
435
436        //! Set weights \c w and components \c R
437        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
438                Array<epdf*> Eps ( Coms.length() );
439
440                for ( int i = 0;i < Coms.length();i++ ) {
441                        Eps(i) = Coms(i)->e();
442                }
443
444                iepdf->set_parameters(w, Eps);
445        }
446
447        void condition ( const vec &cond ) {
448                for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
449        };
450};
451
452}
453#endif //MX_H
Note: See TracBrowser for help on using the browser.