root/library/bdm/stat/mixtures.h @ 384

Revision 384, 11.2 kB (checked in by mido, 15 years ago)

possibly broken?

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef STAT_MIXTURES_H
14#define STAT_MIXTURES_H
15
16#define LOG2  0.69314718055995 
17
18#include "exp_family.h"
19
20namespace bdm {
21
22//this comes first because it is used inside emix!
23
24/*! \brief Class representing ratio of two densities
25which arise e.g. by applying the Bayes rule.
26It represents density in the form:
27\f[
28f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
29\f]
30where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
31
32In particular this type of arise by conditioning of a mixture model.
33
34At present the only supported operation is evallogcond().
35 */
36class mratio: public mpdf {
37protected:
38        //! Nominator in the form of mpdf
39        const epdf* nom;
40        //!Denominator in the form of epdf
41        epdf* den;
42        //!flag for destructor
43        bool destroynom;
44        //!datalink between conditional and nom
45        datalink_m2e dl;
46public:
47        //!Default constructor. By default, the given epdf is not copied!
48        //! It is assumed that this function will be used only temporarily.
49        mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
50                // adjust rv and rvc
51                rvc = nom0->_rv().subt ( rv );
52                dimc = rvc._dsize();
53                ep = new epdf;
54                ep->set_parameters(rv._dsize());
55                ep->set_rv(rv);
56               
57                //prepare data structures
58                if ( copy ) {it_error ( "todo" ); destroynom=true; }
59                else { nom = nom0; destroynom = false; }
60                it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
61               
62                // build denominator
63                den = nom->marginal ( rvc );
64                dl.set_connection(rv,rvc,nom0->_rv());
65        };
66        double evallogcond ( const vec &val, const vec &cond ) {
67                double tmp;
68                vec nom_val ( ep->dimension() + dimc );
69                dl.pushup_cond ( nom_val,val,cond );
70                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
71                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
72                return tmp;
73        }
74        //! Object takes ownership of nom and will destroy it
75        void ownnom() {destroynom=true;}
76        //! Default destructor
77        ~mratio() {delete den; if ( destroynom ) {delete nom;}}
78};
79
80/*!
81* \brief Mixture of epdfs
82
83Density function:
84\f[
85f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
86\f]
87where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
88
89*/
90class emix : public epdf {
91protected:
92        //! weights of the components
93        vec w;
94        //! Component (epdfs)
95        Array<epdf*> Coms;
96        //!Flag if owning Coms
97        bool destroyComs;
98public:
99        //!Default constructor
100        emix ( ) : epdf ( ) {};
101        //! Set weights \c w and components \c Coms
102        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
103        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
104
105        vec sample() const;
106        vec mean() const {
107                int i; vec mu = zeros ( dim );
108                for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
109                return mu;
110        }
111        vec variance() const {
112                //non-central moment
113                vec mom2 = zeros ( dim );
114                for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
115                //central moment
116                return mom2-pow ( mean(),2 );
117        }
118        double evallog ( const vec &val ) const {
119                int i;
120                double sum = 0.0;
121                for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
122                if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
123                double tmp=log ( sum );
124                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
125                return tmp;
126        };
127        vec evallog_m ( const mat &Val ) const {
128                vec x=zeros ( Val.cols() );
129                for ( int i = 0; i < w.length(); i++ ) {
130                        x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
131                }
132                return log ( x );
133        };
134        //! Auxiliary function that returns pdflog for each component
135        mat evallog_M ( const mat &Val ) const {
136                mat X ( w.length(), Val.cols() );
137                for ( int i = 0; i < w.length(); i++ ) {
138                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
139                }
140                return X;
141        };
142
143        emix* marginal ( const RV &rv ) const;
144        mratio* condition ( const RV &rv ) const; //why not mratio!!
145
146//Access methods
147        //! returns a pointer to the internal mean value. Use with Care!
148        vec& _w() {return w;}
149        virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
150        //! Auxiliary function for taking ownership of the Coms()
151        void ownComs() {destroyComs=true;}
152
153        //!access function
154        epdf* _Coms ( int i ) {return Coms ( i );}
155        void set_rv(const RV &rv){
156                epdf::set_rv(rv);
157                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
158        }
159};
160
161
162/*!
163* \brief Mixture of egiws
164
165*/
166class egiwmix : public egiw {
167protected:
168        //! weights of the components
169        vec w;
170        //! Component (epdfs)
171        Array<egiw*> Coms;
172        //!Flag if owning Coms
173        bool destroyComs;
174public:
175        //!Default constructor
176        egiwmix ( ) : egiw ( ) {};
177
178        //! Set weights \c w and components \c Coms
179        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
180        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
181
182        //!return expected value
183        vec mean() const;
184
185        //!return a sample from the density
186        vec sample() const;
187
188        //!return the expected variance
189        vec variance() const;
190
191        // TODO!!! Defined to follow ANSI and/or for future development
192        void mean_mat ( mat &M, mat&R ) const {};
193        double evallog_nn ( const vec &val ) const {return 0;};
194        double lognc () const {return 0;};
195        emix* marginal ( const RV &rv ) const;
196
197//Access methods
198        //! returns a pointer to the internal mean value. Use with Care!
199        vec& _w() {return w;}
200        virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
201        //! Auxiliary function for taking ownership of the Coms()
202        void ownComs() {destroyComs=true;}
203
204        //!access function
205        egiw* _Coms ( int i ) {return Coms ( i );}
206
207        void set_rv(const RV &rv){
208                egiw::set_rv(rv);
209                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
210        }
211
212        //! Approximation of a GiW mix by a single GiW pdf
213        egiw* approx();
214};
215
216/*! \brief Chain rule decomposition of epdf
217
218Probability density in the form of Chain-rule decomposition:
219\[
220f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
221\]
222Note that
223*/
224class mprod: public compositepdf, public mpdf {
225protected:
226        //! pointers to epdfs - shortcut to mpdfs().posterior()
227        Array<epdf*> epdfs;
228        //! Data link for each mpdfs
229        Array<datalink_m2m*> dls;
230        //! dummy ep
231        epdf dummy;
232public:
233        /*!\brief Constructor from list of mFacs,
234        */
235        mprod (){};
236        mprod (Array<mpdf*> mFacs ){set_elements( mFacs );};
237        void set_elements(Array<mpdf*> mFacs ) {
238               
239                set_elements(mFacs);
240               
241                ep=&dummy;
242                RV rv=getrv ( true );
243                set_rv ( rv );dummy.set_parameters ( rv._dsize() );
244                setrvc ( ep->_rv(),rvc );
245                // rv and rvc established = > we can link them with mpdfs
246                for ( int i = 0;i < mpdfs.length(); i++ ) {
247                        dls ( i ) = new datalink_m2m;
248                        dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
249                }
250
251                for ( int i=0; i<mpdfs.length(); i++ ) {
252                        epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
253                }
254        };
255
256        double evallogcond ( const vec &val, const vec &cond ) {
257                int i;
258                double res = 0.0;
259                for ( i = mpdfs.length() - 1;i >= 0;i-- ) {
260                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
261                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
262                                                }
263                                                // add logarithms
264                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
265                        res += mpdfs ( i )->evallogcond (
266                                   dls ( i )->pushdown ( val ),
267                                   dls ( i )->get_cond ( val, cond )
268                               );
269                }
270                return res;
271        }
272        //TODO smarter...
273        vec samplecond ( const vec &cond ) {
274                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
275                vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
276                vec smpi;
277                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
278                for ( int i = ( mpdfs.length() - 1 );i >= 0;i-- ) {
279                        if ( mpdfs ( i )->dimensionc() ) {
280                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); // smp is val here!!
281                        }
282                        smpi = epdfs ( i )->sample();
283                        // copy contribution of this pdf into smp
284                        dls ( i )->pushup ( smp, smpi );
285                        // add ith likelihood contribution
286                }
287                return smp;
288        }
289        mat samplecond ( const vec &cond,  int N ) {
290                mat Smp ( dimension(),N );
291                for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
292                return Smp;
293        }
294
295        ~mprod() {};
296};
297
298//! Product of independent epdfs. For dependent pdfs, use mprod.
299class eprod: public epdf {
300protected:
301        //! Components (epdfs)
302        Array<const epdf*> epdfs;
303        //! Array of indeces
304        Array<datalink*> dls;
305public:
306        eprod () : epdfs ( 0 ),dls ( 0 ) {};
307        void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
308                epdfs=epdfs0;//.set_length ( epdfs0.length() );
309                dls.set_length ( epdfs.length() );
310
311                bool independent=true;
312                if ( named ) {
313                        for ( int i=0;i<epdfs.length();i++ ) {
314                                independent=rv.add ( epdfs ( i )->_rv() );
315                                it_assert_debug ( independent==true, "eprod:: given components are not independent." );
316                        }
317                        dim=rv._dsize();
318                }
319                else {
320                        dim =0; for ( int i=0;i<epdfs.length();i++ ) {
321                                dim+=epdfs ( i )->dimension();
322                        }
323                }
324                //
325                int cumdim=0;
326                int dimi=0;
327                int i;
328                for ( i=0;i<epdfs.length();i++ ) {
329                        dls ( i ) = new datalink;
330                        if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
331                        else {
332                                dimi = epdfs ( i )->dimension();
333                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
334                                cumdim+=dimi;
335                        }
336                }
337        }
338
339        vec mean() const {
340                vec tmp ( dim );
341                for ( int i=0;i<epdfs.length();i++ ) {
342                        vec pom = epdfs ( i )->mean();
343                        dls ( i )->pushup ( tmp, pom );
344                }
345                return tmp;
346        }
347        vec variance() const {
348                vec tmp ( dim ); //second moment
349                for ( int i=0;i<epdfs.length();i++ ) {
350                        vec pom = epdfs ( i )->mean();
351                        dls ( i )->pushup ( tmp, pow ( pom,2 ) );
352                }
353                return tmp-pow ( mean(),2 );
354        }
355        vec sample() const {
356                vec tmp ( dim );
357                for ( int i=0;i<epdfs.length();i++ ) {
358                        vec pom = epdfs ( i )->sample();
359                        dls ( i )->pushup ( tmp, pom );
360                }
361                return tmp;
362        }
363        double evallog ( const vec &val ) const {
364                double tmp=0;
365                for ( int i=0;i<epdfs.length();i++ ) {
366                        tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
367                }
368                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
369                return tmp;
370        }
371        //!access function
372        const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
373
374        //!Destructor
375        ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
376};
377
378
379/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
380
381*/
382class mmix : public mpdf {
383protected:
384        //! Component (epdfs)
385        Array<mpdf*> Coms;
386        //!Internal epdf
387        emix Epdf;
388public:
389        //!Default constructor
390        mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
391        //! Set weights \c w and components \c R
392        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
393                Array<epdf*> Eps ( Coms.length() );
394
395                for ( int i = 0;i < Coms.length();i++ ) {
396                        Eps ( i ) = & ( Coms ( i )->_epdf() );
397                }
398                Epdf.set_parameters ( w, Eps );
399        };
400
401        void condition ( const vec &cond ) {
402                for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
403        };
404};
405
406}
407#endif //MX_H
Note: See TracBrowser for help on using the browser.