root/bdm/stat/emix.h @ 339

Revision 333, 11.1 kB (checked in by dedecius, 15 years ago)

Implementation of GiW mixtures and their approximation by a single mixture

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MX_H
14#define MX_H
15
16#define LOG2  0.69314718055995 
17
18#include "libBM.h"
19#include "libEF.h"
20//#include <std>
21
22namespace bdm {
23
24//this comes first because it is used inside emix!
25
26/*! \brief Class representing ratio of two densities
27which arise e.g. by applying the Bayes rule.
28It represents density in the form:
29\f[
30f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
31\f]
32where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
33
34In particular this type of arise by conditioning of a mixture model.
35
36At present the only supported operation is evallogcond().
37 */
38class mratio: public mpdf {
39protected:
40        //! Nominator in the form of mpdf
41        const epdf* nom;
42        //!Denominator in the form of epdf
43        epdf* den;
44        //!flag for destructor
45        bool destroynom;
46        //!datalink between conditional and nom
47        datalink_m2e dl;
48public:
49        //!Default constructor. By default, the given epdf is not copied!
50        //! It is assumed that this function will be used only temporarily.
51        mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
52                // adjust rv and rvc
53                rvc = nom0->_rv().subt ( rv );
54                dimc = rvc._dsize();
55                ep = new epdf;
56                ep->set_parameters(rv._dsize());
57                ep->set_rv(rv);
58               
59                //prepare data structures
60                if ( copy ) {it_error ( "todo" ); destroynom=true; }
61                else { nom = nom0; destroynom = false; }
62                it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
63               
64                // build denominator
65                den = nom->marginal ( rvc );
66                dl.set_connection(rv,rvc,nom0->_rv());
67        };
68        double evallogcond ( const vec &val, const vec &cond ) {
69                double tmp;
70                vec nom_val ( ep->dimension() + dimc );
71                dl.pushup_cond ( nom_val,val,cond );
72                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
73                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
74                return tmp;
75        }
76        //! Object takes ownership of nom and will destroy it
77        void ownnom() {destroynom=true;}
78        //! Default destructor
79        ~mratio() {delete den; if ( destroynom ) {delete nom;}}
80};
81
82/*!
83* \brief Mixture of epdfs
84
85Density function:
86\f[
87f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
88\f]
89where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
90
91*/
92class emix : public epdf {
93protected:
94        //! weights of the components
95        vec w;
96        //! Component (epdfs)
97        Array<epdf*> Coms;
98        //!Flag if owning Coms
99        bool destroyComs;
100public:
101        //!Default constructor
102        emix ( ) : epdf ( ) {};
103        //! Set weights \c w and components \c Coms
104        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
105        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
106
107        vec sample() const;
108        vec mean() const {
109                int i; vec mu = zeros ( dim );
110                for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
111                return mu;
112        }
113        vec variance() const {
114                //non-central moment
115                vec mom2 = zeros ( dim );
116                for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
117                //central moment
118                return mom2-pow ( mean(),2 );
119        }
120        double evallog ( const vec &val ) const {
121                int i;
122                double sum = 0.0;
123                for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
124                if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
125                double tmp=log ( sum );
126                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
127                return tmp;
128        };
129        vec evallog_m ( const mat &Val ) const {
130                vec x=zeros ( Val.cols() );
131                for ( int i = 0; i < w.length(); i++ ) {
132                        x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
133                }
134                return log ( x );
135        };
136        //! Auxiliary function that returns pdflog for each component
137        mat evallog_M ( const mat &Val ) const {
138                mat X ( w.length(), Val.cols() );
139                for ( int i = 0; i < w.length(); i++ ) {
140                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
141                }
142                return X;
143        };
144
145        emix* marginal ( const RV &rv ) const;
146        mratio* condition ( const RV &rv ) const; //why not mratio!!
147
148//Access methods
149        //! returns a pointer to the internal mean value. Use with Care!
150        vec& _w() {return w;}
151        virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
152        //! Auxiliary function for taking ownership of the Coms()
153        void ownComs() {destroyComs=true;}
154
155        //!access function
156        epdf* _Coms ( int i ) {return Coms ( i );}
157        void set_rv(const RV &rv){
158                epdf::set_rv(rv);
159                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
160        }
161};
162
163
164/*!
165* \brief Mixture of egiws
166
167*/
168class egiwmix : public egiw {
169protected:
170        //! weights of the components
171        vec w;
172        //! Component (epdfs)
173        Array<egiw*> Coms;
174        //!Flag if owning Coms
175        bool destroyComs;
176public:
177        //!Default constructor
178        egiwmix ( ) : egiw ( ) {};
179
180        //! Set weights \c w and components \c Coms
181        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
182        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
183
184        //!return expected value
185        vec mean() const;
186
187        //!return a sample from the density
188        vec sample() const;
189
190        //!return the expected variance
191        vec variance() const;
192
193        // TODO!!! Defined to follow ANSI and/or for future development
194        void mean_mat ( mat &M, mat&R ) const {};
195        double evallog_nn ( const vec &val ) const {return 0;};
196        double lognc () const {return 0;};
197        emix* marginal ( const RV &rv ) const;
198
199//Access methods
200        //! returns a pointer to the internal mean value. Use with Care!
201        vec& _w() {return w;}
202        virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
203        //! Auxiliary function for taking ownership of the Coms()
204        void ownComs() {destroyComs=true;}
205
206        //!access function
207        egiw* _Coms ( int i ) {return Coms ( i );}
208
209        void set_rv(const RV &rv){
210                egiw::set_rv(rv);
211                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
212        }
213
214        //! Approximation of a GiW mix by a single GiW pdf
215        egiw* approx();
216};
217
218/*! \brief Chain rule decomposition of epdf
219
220Probability density in the form of Chain-rule decomposition:
221\[
222f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
223\]
224Note that
225*/
226class mprod: public compositepdf, public mpdf {
227protected:
228        //! pointers to epdfs - shortcut to mpdfs().posterior()
229        Array<epdf*> epdfs;
230        //! Data link for each mpdfs
231        Array<datalink_m2m*> dls;
232        //! dummy ep
233        epdf dummy;
234public:
235        /*!\brief Constructor from list of mFacs,
236        */
237        mprod ( Array<mpdf*> mFacs ) : compositepdf ( mFacs ), mpdf (), epdfs ( n ), dls ( n ) {
238                ep=&dummy;
239                RV rv=getrv ( true );
240                set_rv ( rv );dummy.set_parameters ( rv._dsize() );
241                setrvc ( ep->_rv(),rvc );
242                // rv and rvc established = > we can link them with mpdfs
243                for ( int i = 0;i < n;i++ ) {
244                        dls ( i ) = new datalink_m2m;
245                        dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
246                }
247
248                for ( int i=0;i<n;i++ ) {
249                        epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
250                }
251        };
252
253        double evallogcond ( const vec &val, const vec &cond ) {
254                int i;
255                double res = 0.0;
256                for ( i = n - 1;i >= 0;i-- ) {
257                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
258                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
259                                                }
260                                                // add logarithms
261                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
262                        res += mpdfs ( i )->evallogcond (
263                                   dls ( i )->pushdown ( val ),
264                                   dls ( i )->get_cond ( val, cond )
265                               );
266                }
267                return res;
268        }
269        //TODO smarter...
270        vec samplecond ( const vec &cond ) {
271                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
272                vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
273                vec smpi;
274                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
275                for ( int i = ( n - 1 );i >= 0;i-- ) {
276                        if ( mpdfs ( i )->dimensionc() ) {
277                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); // smp is val here!!
278                        }
279                        smpi = epdfs ( i )->sample();
280                        // copy contribution of this pdf into smp
281                        dls ( i )->pushup ( smp, smpi );
282                        // add ith likelihood contribution
283                }
284                return smp;
285        }
286        mat samplecond ( const vec &cond,  int N ) {
287                mat Smp ( dimension(),N );
288                for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
289                return Smp;
290        }
291
292        ~mprod() {};
293};
294
295//! Product of independent epdfs. For dependent pdfs, use mprod.
296class eprod: public epdf {
297protected:
298        //! Components (epdfs)
299        Array<const epdf*> epdfs;
300        //! Array of indeces
301        Array<datalink*> dls;
302public:
303        eprod () : epdfs ( 0 ),dls ( 0 ) {};
304        void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
305                epdfs=epdfs0;//.set_length ( epdfs0.length() );
306                dls.set_length ( epdfs.length() );
307
308                bool independent=true;
309                if ( named ) {
310                        for ( int i=0;i<epdfs.length();i++ ) {
311                                independent=rv.add ( epdfs ( i )->_rv() );
312                                it_assert_debug ( independent==true, "eprod:: given components are not independent." );
313                        }
314                        dim=rv._dsize();
315                }
316                else {
317                        dim =0; for ( int i=0;i<epdfs.length();i++ ) {
318                                dim+=epdfs ( i )->dimension();
319                        }
320                }
321                //
322                int cumdim=0;
323                int dimi=0;
324                int i;
325                for ( i=0;i<epdfs.length();i++ ) {
326                        dls ( i ) = new datalink;
327                        if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
328                        else {
329                                dimi = epdfs ( i )->dimension();
330                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
331                                cumdim+=dimi;
332                        }
333                }
334        }
335
336        vec mean() const {
337                vec tmp ( dim );
338                for ( int i=0;i<epdfs.length();i++ ) {
339                        vec pom = epdfs ( i )->mean();
340                        dls ( i )->pushup ( tmp, pom );
341                }
342                return tmp;
343        }
344        vec variance() const {
345                vec tmp ( dim ); //second moment
346                for ( int i=0;i<epdfs.length();i++ ) {
347                        vec pom = epdfs ( i )->mean();
348                        dls ( i )->pushup ( tmp, pow ( pom,2 ) );
349                }
350                return tmp-pow ( mean(),2 );
351        }
352        vec sample() const {
353                vec tmp ( dim );
354                for ( int i=0;i<epdfs.length();i++ ) {
355                        vec pom = epdfs ( i )->sample();
356                        dls ( i )->pushup ( tmp, pom );
357                }
358                return tmp;
359        }
360        double evallog ( const vec &val ) const {
361                double tmp=0;
362                for ( int i=0;i<epdfs.length();i++ ) {
363                        tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
364                }
365                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
366                return tmp;
367        }
368        //!access function
369        const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
370
371        //!Destructor
372        ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
373};
374
375
376/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
377
378*/
379class mmix : public mpdf {
380protected:
381        //! Component (epdfs)
382        Array<mpdf*> Coms;
383        //!Internal epdf
384        emix Epdf;
385public:
386        //!Default constructor
387        mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
388        //! Set weights \c w and components \c R
389        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
390                Array<epdf*> Eps ( Coms.length() );
391
392                for ( int i = 0;i < Coms.length();i++ ) {
393                        Eps ( i ) = & ( Coms ( i )->_epdf() );
394                }
395                Epdf.set_parameters ( w, Eps );
396        };
397
398        void condition ( const vec &cond ) {
399                for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
400        };
401};
402
403}
404#endif //MX_H
Note: See TracBrowser for help on using the browser.