root/library/bdm/stat/emix.h @ 423

Revision 395, 12.0 kB (checked in by smidl, 15 years ago)

merging works for merger_mx

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995 
17
18#include "exp_family.h"
19
20namespace bdm {
21
22//this comes first because it is used inside emix!
23
24/*! \brief Class representing ratio of two densities
25which arise e.g. by applying the Bayes rule.
26It represents density in the form:
27\f[
28f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
29\f]
30where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
31
32In particular this type of arise by conditioning of a mixture model.
33
34At present the only supported operation is evallogcond().
35 */
36class mratio: public mpdf {
37protected:
38        //! Nominator in the form of mpdf
39        const epdf* nom;
40        //!Denominator in the form of epdf
41        epdf* den;
42        //!flag for destructor
43        bool destroynom;
44        //!datalink between conditional and nom
45        datalink_m2e dl;
46public:
47        //!Default constructor. By default, the given epdf is not copied!
48        //! It is assumed that this function will be used only temporarily.
49        mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
50                // adjust rv and rvc
51                rvc = nom0->_rv().subt ( rv );
52                dimc = rvc._dsize();
53                ep = new epdf;
54                ep->set_parameters(rv._dsize());
55                ep->set_rv(rv);
56               
57                //prepare data structures
58                if ( copy ) {it_error ( "todo" ); destroynom=true; }
59                else { nom = nom0; destroynom = false; }
60                it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
61               
62                // build denominator
63                den = nom->marginal ( rvc );
64                dl.set_connection(rv,rvc,nom0->_rv());
65        };
66        double evallogcond ( const vec &val, const vec &cond ) {
67                double tmp;
68                vec nom_val ( ep->dimension() + dimc );
69                dl.pushup_cond ( nom_val,val,cond );
70                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
71                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
72                return tmp;
73        }
74        //! Object takes ownership of nom and will destroy it
75        void ownnom() {destroynom=true;}
76        //! Default destructor
77        ~mratio() {delete den; if ( destroynom ) {delete nom;}}
78};
79
80/*!
81* \brief Mixture of epdfs
82
83Density function:
84\f[
85f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
86\f]
87where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
88
89*/
90class emix : public epdf {
91protected:
92        //! weights of the components
93        vec w;
94        //! Component (epdfs)
95        Array<epdf*> Coms;
96        //!Flag if owning Coms
97        bool destroyComs;
98public:
99        //!Default constructor
100        emix ( ) : epdf ( ) {};
101        //! Set weights \c w and components \c Coms
102        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
103        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
104
105        vec sample() const;
106        vec mean() const {
107                int i; vec mu = zeros ( dim );
108                for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
109                return mu;
110        }
111        vec variance() const {
112                //non-central moment
113                vec mom2 = zeros ( dim );
114                for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
115                //central moment
116                return mom2-pow ( mean(),2 );
117        }
118        double evallog ( const vec &val ) const {
119                int i;
120                double sum = 0.0;
121                for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
122                if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
123                double tmp=log ( sum );
124                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
125                return tmp;
126        };
127        vec evallog_m ( const mat &Val ) const {
128                vec x=zeros ( Val.cols() );
129                for ( int i = 0; i < w.length(); i++ ) {
130                        x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
131                }
132                return log ( x );
133        };
134        //! Auxiliary function that returns pdflog for each component
135        mat evallog_M ( const mat &Val ) const {
136                mat X ( w.length(), Val.cols() );
137                for ( int i = 0; i < w.length(); i++ ) {
138                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
139                }
140                return X;
141        };
142
143        emix* marginal ( const RV &rv ) const;
144        mratio* condition ( const RV &rv ) const; //why not mratio!!
145
146//Access methods
147        //! returns a pointer to the internal mean value. Use with Care!
148        vec& _w() {return w;}
149        virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
150        //! Auxiliary function for taking ownership of the Coms()
151        void ownComs() {destroyComs=true;}
152
153        //!access function
154        epdf* _Coms ( int i ) {return Coms ( i );}
155        void set_rv(const RV &rv){
156                epdf::set_rv(rv);
157                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
158        }
159};
160
161
162/*!
163* \brief Mixture of egiws
164
165*/
166class egiwmix : public egiw {
167protected:
168        //! weights of the components
169        vec w;
170        //! Component (epdfs)
171        Array<egiw*> Coms;
172        //!Flag if owning Coms
173        bool destroyComs;
174public:
175        //!Default constructor
176        egiwmix ( ) : egiw ( ) {};
177
178        //! Set weights \c w and components \c Coms
179        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
180        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
181
182        //!return expected value
183        vec mean() const;
184
185        //!return a sample from the density
186        vec sample() const;
187
188        //!return the expected variance
189        vec variance() const;
190
191        // TODO!!! Defined to follow ANSI and/or for future development
192        void mean_mat ( mat &M, mat&R ) const {};
193        double evallog_nn ( const vec &val ) const {return 0;};
194        double lognc () const {return 0;};
195        emix* marginal ( const RV &rv ) const;
196
197//Access methods
198        //! returns a pointer to the internal mean value. Use with Care!
199        vec& _w() {return w;}
200        virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
201        //! Auxiliary function for taking ownership of the Coms()
202        void ownComs() {destroyComs=true;}
203
204        //!access function
205        egiw* _Coms ( int i ) {return Coms ( i );}
206
207        void set_rv(const RV &rv){
208                egiw::set_rv(rv);
209                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
210        }
211
212        //! Approximation of a GiW mix by a single GiW pdf
213        egiw* approx();
214};
215
216/*! \brief Chain rule decomposition of epdf
217
218Probability density in the form of Chain-rule decomposition:
219\[
220f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
221\]
222Note that
223*/
224class mprod: public compositepdf, public mpdf {
225protected:
226        //! pointers to epdfs - shortcut to mpdfs().posterior()
227        Array<epdf*> epdfs;
228        //! Data link for each mpdfs
229        Array<datalink_m2m*> dls;
230        //! dummy ep
231        epdf dummy;
232public:
233        /*!\brief Constructor from list of mFacs,
234        */
235        mprod (){};
236        mprod (Array<mpdf*> mFacs ){set_elements( mFacs );};
237        void set_elements(Array<mpdf*> mFacs , bool own=false) {
238               
239                compositepdf::set_elements(mFacs,own);
240                dls.set_size(mFacs.length());
241                epdfs.set_size(mFacs.length());
242                               
243                ep=&dummy;
244                RV rv=getrv ( true );
245                set_rv ( rv );
246                dummy.set_parameters ( rv._dsize() );
247                setrvc ( ep->_rv(),rvc );
248                // rv and rvc established = > we can link them with mpdfs
249                for ( int i = 0;i < mpdfs.length(); i++ ) {
250                        dls ( i ) = new datalink_m2m;
251                        dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
252                }
253
254                for ( int i=0; i<mpdfs.length(); i++ ) {
255                        epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
256                }
257        };
258
259        double evallogcond ( const vec &val, const vec &cond ) {
260                int i;
261                double res = 0.0;
262                for ( i = mpdfs.length() - 1;i >= 0;i-- ) {
263                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
264                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
265                                                }
266                                                // add logarithms
267                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
268                        res += mpdfs ( i )->evallogcond (
269                                   dls ( i )->pushdown ( val ),
270                                   dls ( i )->get_cond ( val, cond )
271                               );
272                }
273                return res;
274        }
275        vec evallogcond_m(const mat &Dt, const vec &cond) {
276                vec tmp(Dt.cols());
277                for(int i=0;i<Dt.cols(); i++){
278                        tmp(i) = evallogcond(Dt.get_col(i),cond);
279                }
280                return tmp;
281        };
282        vec evallogcond_m(const Array<vec> &Dt, const vec &cond) {
283                vec tmp(Dt.length());
284                for(int i=0;i<Dt.length(); i++){
285                        tmp(i) = evallogcond(Dt(i),cond);
286                }
287                return tmp;             
288        };
289
290
291        //TODO smarter...
292        vec samplecond ( const vec &cond ) {
293                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
294                vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
295                vec smpi;
296                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
297                for ( int i = ( mpdfs.length() - 1 );i >= 0;i-- ) {
298                        if ( mpdfs ( i )->dimensionc() ) {
299                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); // smp is val here!!
300                        }
301                        smpi = epdfs ( i )->sample();
302                        // copy contribution of this pdf into smp
303                        dls ( i )->pushup ( smp, smpi );
304                        // add ith likelihood contribution
305                }
306                return smp;
307        }
308        mat samplecond ( const vec &cond,  int N ) {
309                mat Smp ( dimension(),N );
310                for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
311                return Smp;
312        }
313
314        ~mprod() {};
315        //! Load from structure with elements:
316        //!  \code
317        //! { class='mprod';
318        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
319        //! }
320        //! \endcode
321        //!@}
322        void from_setting(const Setting &set){
323                Array<mpdf*> Atmp; //temporary Array
324                UI::get(Atmp,set, "mpdfs");
325                set_elements(Atmp,true);
326        }
327       
328};
329UIREGISTER(mprod);
330
331//! Product of independent epdfs. For dependent pdfs, use mprod.
332class eprod: public epdf {
333protected:
334        //! Components (epdfs)
335        Array<const epdf*> epdfs;
336        //! Array of indeces
337        Array<datalink*> dls;
338public:
339        eprod () : epdfs ( 0 ),dls ( 0 ) {};
340        void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
341                epdfs=epdfs0;//.set_length ( epdfs0.length() );
342                dls.set_length ( epdfs.length() );
343
344                bool independent=true;
345                if ( named ) {
346                        for ( int i=0;i<epdfs.length();i++ ) {
347                                independent=rv.add ( epdfs ( i )->_rv() );
348                                it_assert_debug ( independent==true, "eprod:: given components are not independent." );
349                        }
350                        dim=rv._dsize();
351                }
352                else {
353                        dim =0; for ( int i=0;i<epdfs.length();i++ ) {
354                                dim+=epdfs ( i )->dimension();
355                        }
356                }
357                //
358                int cumdim=0;
359                int dimi=0;
360                int i;
361                for ( i=0;i<epdfs.length();i++ ) {
362                        dls ( i ) = new datalink;
363                        if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
364                        else {
365                                dimi = epdfs ( i )->dimension();
366                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
367                                cumdim+=dimi;
368                        }
369                }
370        }
371
372        vec mean() const {
373                vec tmp ( dim );
374                for ( int i=0;i<epdfs.length();i++ ) {
375                        vec pom = epdfs ( i )->mean();
376                        dls ( i )->pushup ( tmp, pom );
377                }
378                return tmp;
379        }
380        vec variance() const {
381                vec tmp ( dim ); //second moment
382                for ( int i=0;i<epdfs.length();i++ ) {
383                        vec pom = epdfs ( i )->mean();
384                        dls ( i )->pushup ( tmp, pow ( pom,2 ) );
385                }
386                return tmp-pow ( mean(),2 );
387        }
388        vec sample() const {
389                vec tmp ( dim );
390                for ( int i=0;i<epdfs.length();i++ ) {
391                        vec pom = epdfs ( i )->sample();
392                        dls ( i )->pushup ( tmp, pom );
393                }
394                return tmp;
395        }
396        double evallog ( const vec &val ) const {
397                double tmp=0;
398                for ( int i=0;i<epdfs.length();i++ ) {
399                        tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
400                }
401                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
402                return tmp;
403        }
404        //!access function
405        const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
406
407        //!Destructor
408        ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
409};
410
411
412/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
413
414*/
415class mmix : public mpdf {
416protected:
417        //! Component (epdfs)
418        Array<mpdf*> Coms;
419        //!Internal epdf
420        emix Epdf;
421public:
422        //!Default constructor
423        mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
424        //! Set weights \c w and components \c R
425        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
426                Array<epdf*> Eps ( Coms.length() );
427
428                for ( int i = 0;i < Coms.length();i++ ) {
429                        Eps ( i ) = & ( Coms ( i )->_epdf() );
430                }
431                Epdf.set_parameters ( w, Eps );
432        };
433
434        void condition ( const vec &cond ) {
435                for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
436        };
437};
438
439}
440#endif //MX_H
Note: See TracBrowser for help on using the browser.