root/bdm/stat/emix.h @ 286

Revision 286, 9.6 kB (checked in by smidl, 15 years ago)

make mpdm work again

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MX_H
14#define MX_H
15
16#include "libBM.h"
17#include "libEF.h"
18//#include <std>
19
20namespace bdm {
21
22//this comes first because it is used inside emix!
23
24/*! \brief Class representing ratio of two densities
25which arise e.g. by applying the Bayes rule.
26It represents density in the form:
27\f[
28f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
29\f]
30where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
31
32In particular this type of arise by conditioning of a mixture model.
33
34At present the only supported operation is evallogcond().
35 */
36class mratio: public mpdf {
37protected:
38        //! Nominator in the form of mpdf
39        const epdf* nom;
40        //!Denominator in the form of epdf
41        epdf* den;
42        //!flag for destructor
43        bool destroynom;
44        //!datalink between conditional and nom
45        datalink_m2e dl;
46public:
47        //!Default constructor. By default, the given epdf is not copied!
48        //! It is assumed that this function will be used only temporarily.
49        mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
50                // adjust rv and rvc
51                rvc = nom0->_rv().subt ( rv );
52                dimc = rvc._dsize();
53                ep = new epdf;
54                ep->set_parameters(rv._dsize());
55                ep->set_rv(rv);
56               
57                //prepare data structures
58                if ( copy ) {it_error ( "todo" ); destroynom=true; }
59                else { nom = nom0; destroynom = false; }
60                it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
61               
62                // build denominator
63                den = nom->marginal ( rvc );
64                dl.set_connection(rv,rvc,nom0->_rv());
65        };
66        double evallogcond ( const vec &val, const vec &cond ) {
67                double tmp;
68                vec nom_val ( ep->dimension() + dimc );
69                dl.pushup_cond ( nom_val,val,cond );
70                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
71                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
72                return tmp;
73        }
74        //! Object takes ownership of nom and will destroy it
75        void ownnom() {destroynom=true;}
76        //! Default destructor
77        ~mratio() {delete den; if ( destroynom ) {delete nom;}}
78};
79
80/*!
81* \brief Mixture of epdfs
82
83Density function:
84\f[
85f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
86\f]
87where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
88
89*/
90class emix : public epdf {
91protected:
92        //! weights of the components
93        vec w;
94        //! Component (epdfs)
95        Array<epdf*> Coms;
96        //!Flag if owning Coms
97        bool destroyComs;
98public:
99        //!Default constructor
100        emix ( ) : epdf ( ) {};
101        //! Set weights \c w and components \c Coms
102        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
103        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
104
105        vec sample() const;
106        vec mean() const {
107                int i; vec mu = zeros ( dim );
108                for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
109                return mu;
110        }
111        vec variance() const {
112                //non-central moment
113                vec mom2 = zeros ( dim );
114                for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * pow ( Coms ( i )->mean(),2 ); }
115                //central moment
116                return mom2-pow ( mean(),2 );
117        }
118        double evallog ( const vec &val ) const {
119                int i;
120                double sum = 0.0;
121                for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
122                if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
123                double tmp=log ( sum );
124                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
125                return tmp;
126        };
127        vec evallog_m ( const mat &Val ) const {
128                vec x=zeros ( Val.cols() );
129                for ( int i = 0; i < w.length(); i++ ) {
130                        x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
131                }
132                return log ( x );
133        };
134        //! Auxiliary function that returns pdflog for each component
135        mat evallog_M ( const mat &Val ) const {
136                mat X ( w.length(), Val.cols() );
137                for ( int i = 0; i < w.length(); i++ ) {
138                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
139                }
140                return X;
141        };
142
143        emix* marginal ( const RV &rv ) const;
144        mratio* condition ( const RV &rv ) const; //why not mratio!!
145
146//Access methods
147        //! returns a pointer to the internal mean value. Use with Care!
148        vec& _w() {return w;}
149        virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
150        //! Auxiliary function for taking ownership of the Coms()
151        void ownComs() {destroyComs=true;}
152
153        //!access function
154        epdf* _Coms ( int i ) {return Coms ( i );}
155        void set_rv(const RV &rv){
156                epdf::set_rv(rv);
157                for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
158        }
159};
160
161/*! \brief Chain rule decomposition of epdf
162
163Probability density in the form of Chain-rule decomposition:
164\[
165f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
166\]
167Note that
168*/
169class mprod: public compositepdf, public mpdf {
170protected:
171        //! pointers to epdfs - shortcut to mpdfs().posterior()
172        Array<epdf*> epdfs;
173        //! Data link for each mpdfs
174        Array<datalink_m2m*> dls;
175        //! dummy ep
176        epdf dummy;
177public:
178        /*!\brief Constructor from list of mFacs,
179        */
180        mprod ( Array<mpdf*> mFacs ) : compositepdf ( mFacs ), mpdf (), epdfs ( n ), dls ( n ) {
181                ep=&dummy;
182                RV rv=getrv ( true );
183                set_rv ( rv );dummy.set_parameters ( rv._dsize() );
184                setrvc ( ep->_rv(),rvc );
185                // rv and rvc established = > we can link them with mpdfs
186                for ( int i = 0;i < n;i++ ) {
187                        dls ( i ) = new datalink_m2m;
188                        dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
189                }
190
191                for ( int i=0;i<n;i++ ) {
192                        epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
193                }
194        };
195
196        double evallogcond ( const vec &val, const vec &cond ) {
197                int i;
198                double res = 0.0;
199                for ( i = n - 1;i >= 0;i-- ) {
200                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
201                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
202                                                }
203                                                // add logarithms
204                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
205                        res += mpdfs ( i )->evallogcond (
206                                   dls ( i )->pushdown ( val ),
207                                   dls ( i )->get_cond ( val, cond )
208                               );
209                }
210                return res;
211        }
212        //TODO smarter...
213        vec samplecond ( const vec &cond ) {
214                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
215                vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
216                vec smpi;
217                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
218                for ( int i = ( n - 1 );i >= 0;i-- ) {
219                        if ( mpdfs ( i )->dimensionc() ) {
220                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); // smp is val here!!
221                        }
222                        smpi = epdfs ( i )->sample();
223                        // copy contribution of this pdf into smp
224                        dls ( i )->pushup ( smp, smpi );
225                        // add ith likelihood contribution
226                }
227                return smp;
228        }
229        mat samplecond ( const vec &cond,  int N ) {
230                mat Smp ( dimension(),N );
231                for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
232                return Smp;
233        }
234
235        ~mprod() {};
236};
237
238//! Product of independent epdfs. For dependent pdfs, use mprod.
239class eprod: public epdf {
240protected:
241        //! Components (epdfs)
242        Array<const epdf*> epdfs;
243        //! Array of indeces
244        Array<datalink*> dls;
245public:
246        eprod () : epdfs ( 0 ),dls ( 0 ) {};
247        void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
248                epdfs=epdfs0;//.set_length ( epdfs0.length() );
249                dls.set_length ( epdfs.length() );
250
251                bool independent=true;
252                if ( named ) {
253                        for ( int i=0;i<epdfs.length();i++ ) {
254                                independent=rv.add ( epdfs ( i )->_rv() );
255                                it_assert_debug ( independent==true, "eprod:: given components are not independent." );
256                        }
257                        dim=rv._dsize();
258                }
259                else {
260                        dim =0; for ( int i=0;i<epdfs.length();i++ ) {
261                                dim+=epdfs ( i )->dimension();
262                        }
263                }
264                //
265                int cumdim=0;
266                int dimi=0;
267                int i;
268                for ( i=0;i<epdfs.length();i++ ) {
269                        dls ( i ) = new datalink;
270                        if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
271                        else {
272                                dimi = epdfs ( i )->dimension();
273                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
274                                cumdim+=dimi;
275                        }
276                }
277        }
278
279        vec mean() const {
280                vec tmp ( dim );
281                for ( int i=0;i<epdfs.length();i++ ) {
282                        vec pom = epdfs ( i )->mean();
283                        dls ( i )->pushup ( tmp, pom );
284                }
285                return tmp;
286        }
287        vec variance() const {
288                vec tmp ( dim ); //second moment
289                for ( int i=0;i<epdfs.length();i++ ) {
290                        vec pom = epdfs ( i )->mean();
291                        dls ( i )->pushup ( tmp, pow ( pom,2 ) );
292                }
293                return tmp-pow ( mean(),2 );
294        }
295        vec sample() const {
296                vec tmp ( dim );
297                for ( int i=0;i<epdfs.length();i++ ) {
298                        vec pom = epdfs ( i )->sample();
299                        dls ( i )->pushup ( tmp, pom );
300                }
301                return tmp;
302        }
303        double evallog ( const vec &val ) const {
304                double tmp=0;
305                for ( int i=0;i<epdfs.length();i++ ) {
306                        tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
307                }
308                it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
309                return tmp;
310        }
311        //!access function
312        const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
313
314        //!Destructor
315        ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
316};
317
318
319/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
320
321*/
322class mmix : public mpdf {
323protected:
324        //! Component (epdfs)
325        Array<mpdf*> Coms;
326        //!Internal epdf
327        emix Epdf;
328public:
329        //!Default constructor
330        mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
331        //! Set weights \c w and components \c R
332        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
333                Array<epdf*> Eps ( Coms.length() );
334
335                for ( int i = 0;i < Coms.length();i++ ) {
336                        Eps ( i ) = & ( Coms ( i )->_epdf() );
337                }
338                Epdf.set_parameters ( w, Eps );
339        };
340
341        void condition ( const vec &cond ) {
342                for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
343        };
344};
345
346}
347#endif //MX_H
Note: See TracBrowser for help on using the browser.