root/bdm/stat/emix.h @ 378

Revision 378, 11.2 kB (checked in by smidl, 15 years ago)

details and compositepdf changes

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MX_H
14#define MX_H
15
16#define LOG2  0.69314718055995 
17
18#include "libEF.h"
19//#include <std>
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41        //!Denominator in the form of epdf
42        epdf* den;
43        //!flag for destructor
44        bool destroynom;
45        //!datalink between conditional and nom
46        datalink_m2e dl;
47public:
48        //!Default constructor. By default, the given epdf is not copied!
49        //! It is assumed that this function will be used only temporarily.
50        mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
51                // adjust rv and rvc
52                rvc = nom0->_rv().subt ( rv );
53                dimc = rvc._dsize();
54                ep = new epdf;
55                ep->set_parameters(rv._dsize());
56                ep->set_rv(rv);
57               
58                //prepare data structures
59                if ( copy ) {it_error ( "todo" ); destroynom=true; }
60                else { nom = nom0; destroynom = false; }
61                it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
62               
63                // build denominator
64                den = nom->marginal ( rvc );
65                dl.set_connection(rv,rvc,nom0->_rv());
66        };
67        double evallogcond ( const vec &val, const vec &cond ) {
68                double tmp;
69                vec nom_val ( ep->dimension() + dimc );
70                dl.pushup_cond ( nom_val,val,cond );
71                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
72                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
73                return tmp;
74        }
75        //! Object takes ownership of nom and will destroy it
76        void ownnom() {destroynom=true;}
77        //! Default destructor
78        ~mratio() {delete den; if ( destroynom ) {delete nom;}}
79};
80
81/*!
82* \brief Mixture of epdfs
83
84Density function:
85\f[
86f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
87\f]
88where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
89
90*/
91class emix : public epdf {
92protected:
93        //! weights of the components
94        vec w;
95        //! Component (epdfs)
96        Array<epdf*> Coms;
97        //!Flag if owning Coms
98        bool destroyComs;
99public:
100        //!Default constructor
101        emix ( ) : epdf ( ) {};
102        //! Set weights \c w and components \c Coms
103        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
104        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
105
106        vec sample() const;
107        vec mean() const {
108                int i; vec mu = zeros ( dim );
109                for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
110                return mu;
111        }
112        vec variance() const {
113                //non-central moment
114                vec mom2 = zeros ( dim );
115                for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
116                //central moment
117                return mom2-pow ( mean(),2 );
118        }
119        double evallog ( const vec &val ) const {
120                int i;
121                double sum = 0.0;
122                for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
123                if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
124                double tmp=log ( sum );
125                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
126                return tmp;
127        };
128        vec evallog_m ( const mat &Val ) const {
129                vec x=zeros ( Val.cols() );
130                for ( int i = 0; i < w.length(); i++ ) {
131                        x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
132                }
133                return log ( x );
134        };
135        //! Auxiliary function that returns pdflog for each component
136        mat evallog_M ( const mat &Val ) const {
137                mat X ( w.length(), Val.cols() );
138                for ( int i = 0; i < w.length(); i++ ) {
139                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
140                }
141                return X;
142        };
143
144        emix* marginal ( const RV &rv ) const;
145        mratio* condition ( const RV &rv ) const; //why not mratio!!
146
147//Access methods
148        //! returns a pointer to the internal mean value. Use with Care!
149        vec& _w() {return w;}
150        virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
151        //! Auxiliary function for taking ownership of the Coms()
152        void ownComs() {destroyComs=true;}
153
154        //!access function
155        epdf* _Coms ( int i ) {return Coms ( i );}
156        void set_rv(const RV &rv){
157                epdf::set_rv(rv);
158                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
159        }
160};
161
162
163/*!
164* \brief Mixture of egiws
165
166*/
167class egiwmix : public egiw {
168protected:
169        //! weights of the components
170        vec w;
171        //! Component (epdfs)
172        Array<egiw*> Coms;
173        //!Flag if owning Coms
174        bool destroyComs;
175public:
176        //!Default constructor
177        egiwmix ( ) : egiw ( ) {};
178
179        //! Set weights \c w and components \c Coms
180        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
181        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
182
183        //!return expected value
184        vec mean() const;
185
186        //!return a sample from the density
187        vec sample() const;
188
189        //!return the expected variance
190        vec variance() const;
191
192        // TODO!!! Defined to follow ANSI and/or for future development
193        void mean_mat ( mat &M, mat&R ) const {};
194        double evallog_nn ( const vec &val ) const {return 0;};
195        double lognc () const {return 0;};
196        emix* marginal ( const RV &rv ) const;
197
198//Access methods
199        //! returns a pointer to the internal mean value. Use with Care!
200        vec& _w() {return w;}
201        virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
202        //! Auxiliary function for taking ownership of the Coms()
203        void ownComs() {destroyComs=true;}
204
205        //!access function
206        egiw* _Coms ( int i ) {return Coms ( i );}
207
208        void set_rv(const RV &rv){
209                egiw::set_rv(rv);
210                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
211        }
212
213        //! Approximation of a GiW mix by a single GiW pdf
214        egiw* approx();
215};
216
217/*! \brief Chain rule decomposition of epdf
218
219Probability density in the form of Chain-rule decomposition:
220\[
221f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
222\]
223Note that
224*/
225class mprod: public compositepdf, public mpdf {
226protected:
227        //! pointers to epdfs - shortcut to mpdfs().posterior()
228        Array<epdf*> epdfs;
229        //! Data link for each mpdfs
230        Array<datalink_m2m*> dls;
231        //! dummy ep
232        epdf dummy;
233public:
234        /*!\brief Constructor from list of mFacs,
235        */
236        mprod (){};
237        mprod (Array<mpdf*> mFacs ){set_elements( mFacs );};
238        void set_elements(Array<mpdf*> mFacs ) {
239               
240                set_elements(mFacs);
241               
242                ep=&dummy;
243                RV rv=getrv ( true );
244                set_rv ( rv );dummy.set_parameters ( rv._dsize() );
245                setrvc ( ep->_rv(),rvc );
246                // rv and rvc established = > we can link them with mpdfs
247                for ( int i = 0;i < mpdfs.length(); i++ ) {
248                        dls ( i ) = new datalink_m2m;
249                        dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
250                }
251
252                for ( int i=0; i<mpdfs.length(); i++ ) {
253                        epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
254                }
255        };
256
257        double evallogcond ( const vec &val, const vec &cond ) {
258                int i;
259                double res = 0.0;
260                for ( i = mpdfs.length() - 1;i >= 0;i-- ) {
261                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
262                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
263                                                }
264                                                // add logarithms
265                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
266                        res += mpdfs ( i )->evallogcond (
267                                   dls ( i )->pushdown ( val ),
268                                   dls ( i )->get_cond ( val, cond )
269                               );
270                }
271                return res;
272        }
273        //TODO smarter...
274        vec samplecond ( const vec &cond ) {
275                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
276                vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
277                vec smpi;
278                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
279                for ( int i = ( mpdfs.length() - 1 );i >= 0;i-- ) {
280                        if ( mpdfs ( i )->dimensionc() ) {
281                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); // smp is val here!!
282                        }
283                        smpi = epdfs ( i )->sample();
284                        // copy contribution of this pdf into smp
285                        dls ( i )->pushup ( smp, smpi );
286                        // add ith likelihood contribution
287                }
288                return smp;
289        }
290        mat samplecond ( const vec &cond,  int N ) {
291                mat Smp ( dimension(),N );
292                for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
293                return Smp;
294        }
295
296        ~mprod() {};
297};
298
299//! Product of independent epdfs. For dependent pdfs, use mprod.
300class eprod: public epdf {
301protected:
302        //! Components (epdfs)
303        Array<const epdf*> epdfs;
304        //! Array of indeces
305        Array<datalink*> dls;
306public:
307        eprod () : epdfs ( 0 ),dls ( 0 ) {};
308        void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
309                epdfs=epdfs0;//.set_length ( epdfs0.length() );
310                dls.set_length ( epdfs.length() );
311
312                bool independent=true;
313                if ( named ) {
314                        for ( int i=0;i<epdfs.length();i++ ) {
315                                independent=rv.add ( epdfs ( i )->_rv() );
316                                it_assert_debug ( independent==true, "eprod:: given components are not independent." );
317                        }
318                        dim=rv._dsize();
319                }
320                else {
321                        dim =0; for ( int i=0;i<epdfs.length();i++ ) {
322                                dim+=epdfs ( i )->dimension();
323                        }
324                }
325                //
326                int cumdim=0;
327                int dimi=0;
328                int i;
329                for ( i=0;i<epdfs.length();i++ ) {
330                        dls ( i ) = new datalink;
331                        if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
332                        else {
333                                dimi = epdfs ( i )->dimension();
334                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
335                                cumdim+=dimi;
336                        }
337                }
338        }
339
340        vec mean() const {
341                vec tmp ( dim );
342                for ( int i=0;i<epdfs.length();i++ ) {
343                        vec pom = epdfs ( i )->mean();
344                        dls ( i )->pushup ( tmp, pom );
345                }
346                return tmp;
347        }
348        vec variance() const {
349                vec tmp ( dim ); //second moment
350                for ( int i=0;i<epdfs.length();i++ ) {
351                        vec pom = epdfs ( i )->mean();
352                        dls ( i )->pushup ( tmp, pow ( pom,2 ) );
353                }
354                return tmp-pow ( mean(),2 );
355        }
356        vec sample() const {
357                vec tmp ( dim );
358                for ( int i=0;i<epdfs.length();i++ ) {
359                        vec pom = epdfs ( i )->sample();
360                        dls ( i )->pushup ( tmp, pom );
361                }
362                return tmp;
363        }
364        double evallog ( const vec &val ) const {
365                double tmp=0;
366                for ( int i=0;i<epdfs.length();i++ ) {
367                        tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
368                }
369                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
370                return tmp;
371        }
372        //!access function
373        const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
374
375        //!Destructor
376        ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
377};
378
379
380/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
381
382*/
383class mmix : public mpdf {
384protected:
385        //! Component (epdfs)
386        Array<mpdf*> Coms;
387        //!Internal epdf
388        emix Epdf;
389public:
390        //!Default constructor
391        mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
392        //! Set weights \c w and components \c R
393        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
394                Array<epdf*> Eps ( Coms.length() );
395
396                for ( int i = 0;i < Coms.length();i++ ) {
397                        Eps ( i ) = & ( Coms ( i )->_epdf() );
398                }
399                Epdf.set_parameters ( w, Eps );
400        };
401
402        void condition ( const vec &cond ) {
403                for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
404        };
405};
406
407}
408#endif //MX_H
Note: See TracBrowser for help on using the browser.