root/bdm/stat/emix.h @ 361

Revision 361, 11.1 kB (checked in by smidl, 15 years ago)
  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MX_H
14#define MX_H
15
16#define LOG2  0.69314718055995 
17
18#include "libEF.h"
19//#include <std>
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41        //!Denominator in the form of epdf
42        epdf* den;
43        //!flag for destructor
44        bool destroynom;
45        //!datalink between conditional and nom
46        datalink_m2e dl;
47public:
48        //!Default constructor. By default, the given epdf is not copied!
49        //! It is assumed that this function will be used only temporarily.
50        mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
51                // adjust rv and rvc
52                rvc = nom0->_rv().subt ( rv );
53                dimc = rvc._dsize();
54                ep = new epdf;
55                ep->set_parameters(rv._dsize());
56                ep->set_rv(rv);
57               
58                //prepare data structures
59                if ( copy ) {it_error ( "todo" ); destroynom=true; }
60                else { nom = nom0; destroynom = false; }
61                it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
62               
63                // build denominator
64                den = nom->marginal ( rvc );
65                dl.set_connection(rv,rvc,nom0->_rv());
66        };
67        double evallogcond ( const vec &val, const vec &cond ) {
68                double tmp;
69                vec nom_val ( ep->dimension() + dimc );
70                dl.pushup_cond ( nom_val,val,cond );
71                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
72                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
73                return tmp;
74        }
75        //! Object takes ownership of nom and will destroy it
76        void ownnom() {destroynom=true;}
77        //! Default destructor
78        ~mratio() {delete den; if ( destroynom ) {delete nom;}}
79};
80
81/*!
82* \brief Mixture of epdfs
83
84Density function:
85\f[
86f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
87\f]
88where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
89
90*/
91class emix : public epdf {
92protected:
93        //! weights of the components
94        vec w;
95        //! Component (epdfs)
96        Array<epdf*> Coms;
97        //!Flag if owning Coms
98        bool destroyComs;
99public:
100        //!Default constructor
101        emix ( ) : epdf ( ) {};
102        //! Set weights \c w and components \c Coms
103        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
104        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
105
106        vec sample() const;
107        vec mean() const {
108                int i; vec mu = zeros ( dim );
109                for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
110                return mu;
111        }
112        vec variance() const {
113                //non-central moment
114                vec mom2 = zeros ( dim );
115                for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
116                //central moment
117                return mom2-pow ( mean(),2 );
118        }
119        double evallog ( const vec &val ) const {
120                int i;
121                double sum = 0.0;
122                for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
123                if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
124                double tmp=log ( sum );
125                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
126                return tmp;
127        };
128        vec evallog_m ( const mat &Val ) const {
129                vec x=zeros ( Val.cols() );
130                for ( int i = 0; i < w.length(); i++ ) {
131                        x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
132                }
133                return log ( x );
134        };
135        //! Auxiliary function that returns pdflog for each component
136        mat evallog_M ( const mat &Val ) const {
137                mat X ( w.length(), Val.cols() );
138                for ( int i = 0; i < w.length(); i++ ) {
139                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
140                }
141                return X;
142        };
143
144        emix* marginal ( const RV &rv ) const;
145        mratio* condition ( const RV &rv ) const; //why not mratio!!
146
147//Access methods
148        //! returns a pointer to the internal mean value. Use with Care!
149        vec& _w() {return w;}
150        virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
151        //! Auxiliary function for taking ownership of the Coms()
152        void ownComs() {destroyComs=true;}
153
154        //!access function
155        epdf* _Coms ( int i ) {return Coms ( i );}
156        void set_rv(const RV &rv){
157                epdf::set_rv(rv);
158                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
159        }
160};
161
162
163/*!
164* \brief Mixture of egiws
165
166*/
167class egiwmix : public egiw {
168protected:
169        //! weights of the components
170        vec w;
171        //! Component (epdfs)
172        Array<egiw*> Coms;
173        //!Flag if owning Coms
174        bool destroyComs;
175public:
176        //!Default constructor
177        egiwmix ( ) : egiw ( ) {};
178
179        //! Set weights \c w and components \c Coms
180        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
181        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
182
183        //!return expected value
184        vec mean() const;
185
186        //!return a sample from the density
187        vec sample() const;
188
189        //!return the expected variance
190        vec variance() const;
191
192        // TODO!!! Defined to follow ANSI and/or for future development
193        void mean_mat ( mat &M, mat&R ) const {};
194        double evallog_nn ( const vec &val ) const {return 0;};
195        double lognc () const {return 0;};
196        emix* marginal ( const RV &rv ) const;
197
198//Access methods
199        //! returns a pointer to the internal mean value. Use with Care!
200        vec& _w() {return w;}
201        virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
202        //! Auxiliary function for taking ownership of the Coms()
203        void ownComs() {destroyComs=true;}
204
205        //!access function
206        egiw* _Coms ( int i ) {return Coms ( i );}
207
208        void set_rv(const RV &rv){
209                egiw::set_rv(rv);
210                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
211        }
212
213        //! Approximation of a GiW mix by a single GiW pdf
214        egiw* approx();
215};
216
217/*! \brief Chain rule decomposition of epdf
218
219Probability density in the form of Chain-rule decomposition:
220\[
221f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
222\]
223Note that
224*/
225class mprod: public compositepdf, public mpdf {
226protected:
227        //! pointers to epdfs - shortcut to mpdfs().posterior()
228        Array<epdf*> epdfs;
229        //! Data link for each mpdfs
230        Array<datalink_m2m*> dls;
231        //! dummy ep
232        epdf dummy;
233public:
234        /*!\brief Constructor from list of mFacs,
235        */
236        mprod ( Array<mpdf*> mFacs ) : compositepdf ( mFacs ), mpdf (), epdfs ( n ), dls ( n ) {
237                ep=&dummy;
238                RV rv=getrv ( true );
239                set_rv ( rv );dummy.set_parameters ( rv._dsize() );
240                setrvc ( ep->_rv(),rvc );
241                // rv and rvc established = > we can link them with mpdfs
242                for ( int i = 0;i < n;i++ ) {
243                        dls ( i ) = new datalink_m2m;
244                        dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
245                }
246
247                for ( int i=0;i<n;i++ ) {
248                        epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
249                }
250        };
251
252        double evallogcond ( const vec &val, const vec &cond ) {
253                int i;
254                double res = 0.0;
255                for ( i = n - 1;i >= 0;i-- ) {
256                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
257                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
258                                                }
259                                                // add logarithms
260                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
261                        res += mpdfs ( i )->evallogcond (
262                                   dls ( i )->pushdown ( val ),
263                                   dls ( i )->get_cond ( val, cond )
264                               );
265                }
266                return res;
267        }
268        //TODO smarter...
269        vec samplecond ( const vec &cond ) {
270                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
271                vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
272                vec smpi;
273                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
274                for ( int i = ( n - 1 );i >= 0;i-- ) {
275                        if ( mpdfs ( i )->dimensionc() ) {
276                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); // smp is val here!!
277                        }
278                        smpi = epdfs ( i )->sample();
279                        // copy contribution of this pdf into smp
280                        dls ( i )->pushup ( smp, smpi );
281                        // add ith likelihood contribution
282                }
283                return smp;
284        }
285        mat samplecond ( const vec &cond,  int N ) {
286                mat Smp ( dimension(),N );
287                for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
288                return Smp;
289        }
290
291        ~mprod() {};
292};
293
294//! Product of independent epdfs. For dependent pdfs, use mprod.
295class eprod: public epdf {
296protected:
297        //! Components (epdfs)
298        Array<const epdf*> epdfs;
299        //! Array of indeces
300        Array<datalink*> dls;
301public:
302        eprod () : epdfs ( 0 ),dls ( 0 ) {};
303        void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
304                epdfs=epdfs0;//.set_length ( epdfs0.length() );
305                dls.set_length ( epdfs.length() );
306
307                bool independent=true;
308                if ( named ) {
309                        for ( int i=0;i<epdfs.length();i++ ) {
310                                independent=rv.add ( epdfs ( i )->_rv() );
311                                it_assert_debug ( independent==true, "eprod:: given components are not independent." );
312                        }
313                        dim=rv._dsize();
314                }
315                else {
316                        dim =0; for ( int i=0;i<epdfs.length();i++ ) {
317                                dim+=epdfs ( i )->dimension();
318                        }
319                }
320                //
321                int cumdim=0;
322                int dimi=0;
323                int i;
324                for ( i=0;i<epdfs.length();i++ ) {
325                        dls ( i ) = new datalink;
326                        if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
327                        else {
328                                dimi = epdfs ( i )->dimension();
329                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
330                                cumdim+=dimi;
331                        }
332                }
333        }
334
335        vec mean() const {
336                vec tmp ( dim );
337                for ( int i=0;i<epdfs.length();i++ ) {
338                        vec pom = epdfs ( i )->mean();
339                        dls ( i )->pushup ( tmp, pom );
340                }
341                return tmp;
342        }
343        vec variance() const {
344                vec tmp ( dim ); //second moment
345                for ( int i=0;i<epdfs.length();i++ ) {
346                        vec pom = epdfs ( i )->mean();
347                        dls ( i )->pushup ( tmp, pow ( pom,2 ) );
348                }
349                return tmp-pow ( mean(),2 );
350        }
351        vec sample() const {
352                vec tmp ( dim );
353                for ( int i=0;i<epdfs.length();i++ ) {
354                        vec pom = epdfs ( i )->sample();
355                        dls ( i )->pushup ( tmp, pom );
356                }
357                return tmp;
358        }
359        double evallog ( const vec &val ) const {
360                double tmp=0;
361                for ( int i=0;i<epdfs.length();i++ ) {
362                        tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
363                }
364                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
365                return tmp;
366        }
367        //!access function
368        const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
369
370        //!Destructor
371        ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
372};
373
374
375/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
376
377*/
378class mmix : public mpdf {
379protected:
380        //! Component (epdfs)
381        Array<mpdf*> Coms;
382        //!Internal epdf
383        emix Epdf;
384public:
385        //!Default constructor
386        mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
387        //! Set weights \c w and components \c R
388        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
389                Array<epdf*> Eps ( Coms.length() );
390
391                for ( int i = 0;i < Coms.length();i++ ) {
392                        Eps ( i ) = & ( Coms ( i )->_epdf() );
393                }
394                Epdf.set_parameters ( w, Eps );
395        };
396
397        void condition ( const vec &cond ) {
398                for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
399        };
400};
401
402}
403#endif //MX_H
Note: See TracBrowser for help on using the browser.