root/library/bdm/stat/emix.h @ 504

Revision 504, 12.4 kB (checked in by vbarta, 15 years ago)

returning shared pointers from epdf::marginal & epdf::condition; testsuite run leaks down from 8402 to 6510 bytes

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93};
94
95/*!
96* \brief Mixture of epdfs
97
98Density function:
99\f[
100f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
101\f]
102where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
103
104*/
105class emix : public epdf {
106protected:
107        //! weights of the components
108        vec w;
109        //! Component (epdfs)
110        Array<shared_ptr<epdf> > Coms;
111
112public:
113        //!Default constructor
114        emix ( ) : epdf ( ) {};
115        //! Set weights \c w and components \c Coms
116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
117        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
118
119        vec sample() const;
120        vec mean() const {
121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
126                return mu;
127        }
128        vec variance() const {
129                //non-central moment
130                vec mom2 = zeros ( dim );
131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
134                //central moment
135                return mom2 - pow ( mean(), 2 );
136        }
137        double evallog ( const vec &val ) const {
138                int i;
139                double sum = 0.0;
140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
148                return tmp;
149        };
150        vec evallog_m ( const mat &Val ) const {
151                vec x = zeros ( Val.cols() );
152                for ( int i = 0; i < w.length(); i++ ) {
153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
154                }
155                return log ( x );
156        };
157        //! Auxiliary function that returns pdflog for each component
158        mat evallog_M ( const mat &Val ) const {
159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
162                }
163                return X;
164        };
165
166        shared_ptr<epdf> marginal ( const RV &rv ) const;
167        void marginal ( const RV &rv, emix &target ) const;
168        shared_ptr<mpdf> condition ( const RV &rv ) const;
169
170//Access methods
171        //! returns a pointer to the internal mean value. Use with Care!
172        vec& _w() {
173                return w;
174        }
175
176        //!access function
177        shared_ptr<epdf> _Coms ( int i ) {
178                return Coms ( i );
179        }
180
181        void set_rv ( const RV &rv ) {
182                epdf::set_rv ( rv );
183                for ( int i = 0; i < Coms.length(); i++ ) {
184                        Coms ( i )->set_rv ( rv );
185                }
186        }
187};
188
189
190/*!
191* \brief Mixture of egiws
192
193*/
194class egiwmix : public egiw {
195protected:
196        //! weights of the components
197        vec w;
198        //! Component (epdfs)
199        Array<egiw*> Coms;
200        //!Flag if owning Coms
201        bool destroyComs;
202public:
203        //!Default constructor
204        egiwmix ( ) : egiw ( ) {};
205
206        //! Set weights \c w and components \c Coms
207        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
208        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
209
210        //!return expected value
211        vec mean() const;
212
213        //!return a sample from the density
214        vec sample() const;
215
216        //!return the expected variance
217        vec variance() const;
218
219        // TODO!!! Defined to follow ANSI and/or for future development
220        void mean_mat ( mat &M, mat&R ) const {};
221        double evallog_nn ( const vec &val ) const {
222                return 0;
223        };
224        double lognc () const {
225                return 0;
226        }
227
228        shared_ptr<epdf> marginal ( const RV &rv ) const;
229        void marginal ( const RV &rv, emix &target ) const;
230
231//Access methods
232        //! returns a pointer to the internal mean value. Use with Care!
233        vec& _w() {
234                return w;
235        }
236        virtual ~egiwmix() {
237                if ( destroyComs ) {
238                        for ( int i = 0; i < Coms.length(); i++ ) {
239                                delete Coms ( i );
240                        }
241                }
242        }
243        //! Auxiliary function for taking ownership of the Coms()
244        void ownComs() {
245                destroyComs = true;
246        }
247
248        //!access function
249        egiw* _Coms ( int i ) {
250                return Coms ( i );
251        }
252
253        void set_rv ( const RV &rv ) {
254                egiw::set_rv ( rv );
255                for ( int i = 0; i < Coms.length(); i++ ) {
256                        Coms ( i )->set_rv ( rv );
257                }
258        }
259
260        //! Approximation of a GiW mix by a single GiW pdf
261        egiw* approx();
262};
263
264/*! \brief Chain rule decomposition of epdf
265
266Probability density in the form of Chain-rule decomposition:
267\[
268f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
269\]
270Note that
271*/
272class mprod: public compositepdf, public mpdf {
273protected:
274        //! Data link for each mpdfs
275        Array<datalink_m2m*> dls;
276
277        //! dummy epdf used only as storage for RV and dim
278        epdf iepdf;
279
280public:
281        /*!\brief Constructor from list of mFacs,
282        */
283        mprod() : iepdf( ) { }
284        mprod ( Array<mpdf*> mFacs ) :
285                        iepdf ( ) {
286                set_elements ( mFacs );
287        }
288
289        void set_elements ( Array<mpdf*> mFacs , bool own = false ) {
290
291                compositepdf::set_elements ( mFacs, own );
292                dls.set_size ( mFacs.length() );
293
294                set_ep ( iepdf);
295                RV rv = getrv ( true );
296                set_rv ( rv );
297                iepdf.set_parameters ( rv._dsize() );
298                setrvc (_rv(), rvc );
299                // rv and rvc established = > we can link them with mpdfs
300                for ( int i = 0; i < mpdfs.length(); i++ ) {
301                        dls ( i ) = new datalink_m2m;
302                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
303                }
304
305        };
306
307        double evallogcond ( const vec &val, const vec &cond ) {
308                int i;
309                double res = 0.0;
310                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
311                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
312                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
313                                                }
314                                                // add logarithms
315                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
316                        res += mpdfs ( i )->evallogcond (
317                                   dls ( i )->pushdown ( val ),
318                                   dls ( i )->get_cond ( val, cond )
319                               );
320                }
321                return res;
322        }
323        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
324                vec tmp ( Dt.cols() );
325                for ( int i = 0; i < Dt.cols(); i++ ) {
326                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
327                }
328                return tmp;
329        };
330        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
331                vec tmp ( Dt.length() );
332                for ( int i = 0; i < Dt.length(); i++ ) {
333                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
334                }
335                return tmp;
336        };
337
338
339        //TODO smarter...
340        vec samplecond ( const vec &cond ) {
341                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
342                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
343                vec smpi;
344                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
345                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
346                        // generate contribution of this mpdf
347                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
348                        // copy contribution of this pdf into smp
349                        dls ( i )->pushup ( smp, smpi );
350                }
351                return smp;
352        }
353        mat samplecond ( const vec &cond,  int N ) {
354                mat Smp ( dimension(), N );
355                for ( int i = 0; i < N; i++ ) {
356                        Smp.set_col ( i, samplecond ( cond ) );
357                }
358                return Smp;
359        }
360
361        ~mprod() {};
362        //! Load from structure with elements:
363        //!  \code
364        //! { class='mprod';
365        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
366        //! }
367        //! \endcode
368        //!@}
369        void from_setting ( const Setting &set ) {
370                Array<mpdf*> Atmp; //temporary Array
371                UI::get ( Atmp, set, "mpdfs", UI::compulsory );
372                set_elements ( Atmp, true );
373        }
374
375};
376UIREGISTER ( mprod );
377
378//! Product of independent epdfs. For dependent pdfs, use mprod.
379class eprod: public epdf {
380protected:
381        //! Components (epdfs)
382        Array<const epdf*> epdfs;
383        //! Array of indeces
384        Array<datalink*> dls;
385public:
386        eprod () : epdfs ( 0 ), dls ( 0 ) {};
387        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
388                epdfs = epdfs0;//.set_length ( epdfs0.length() );
389                dls.set_length ( epdfs.length() );
390
391                bool independent = true;
392                if ( named ) {
393                        for ( int i = 0; i < epdfs.length(); i++ ) {
394                                independent = rv.add ( epdfs ( i )->_rv() );
395                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
396                        }
397                        dim = rv._dsize();
398                } else {
399                        dim = 0;
400                        for ( int i = 0; i < epdfs.length(); i++ ) {
401                                dim += epdfs ( i )->dimension();
402                        }
403                }
404                //
405                int cumdim = 0;
406                int dimi = 0;
407                int i;
408                for ( i = 0; i < epdfs.length(); i++ ) {
409                        dls ( i ) = new datalink;
410                        if ( named ) {
411                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
412                        } else {
413                                dimi = epdfs ( i )->dimension();
414                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
415                                cumdim += dimi;
416                        }
417                }
418        }
419
420        vec mean() const {
421                vec tmp ( dim );
422                for ( int i = 0; i < epdfs.length(); i++ ) {
423                        vec pom = epdfs ( i )->mean();
424                        dls ( i )->pushup ( tmp, pom );
425                }
426                return tmp;
427        }
428        vec variance() const {
429                vec tmp ( dim ); //second moment
430                for ( int i = 0; i < epdfs.length(); i++ ) {
431                        vec pom = epdfs ( i )->mean();
432                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
433                }
434                return tmp - pow ( mean(), 2 );
435        }
436        vec sample() const {
437                vec tmp ( dim );
438                for ( int i = 0; i < epdfs.length(); i++ ) {
439                        vec pom = epdfs ( i )->sample();
440                        dls ( i )->pushup ( tmp, pom );
441                }
442                return tmp;
443        }
444        double evallog ( const vec &val ) const {
445                double tmp = 0;
446                for ( int i = 0; i < epdfs.length(); i++ ) {
447                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
448                }
449                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
450                return tmp;
451        }
452        //!access function
453        const epdf* operator () ( int i ) const {
454                it_assert_debug ( i < epdfs.length(), "wrong index" );
455                return epdfs ( i );
456        }
457
458        //!Destructor
459        ~eprod() {
460                for ( int i = 0; i < epdfs.length(); i++ ) {
461                        delete dls ( i );
462                }
463        }
464};
465
466
467/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
468
469*/
470class mmix : public mpdf {
471protected:
472        //! Component (mpdfs)
473        Array<shared_ptr<mpdf> > Coms;
474        //!weights of the components
475        vec w;
476        //! dummy epdfs
477        epdf dummy_epdf;
478public:
479        //!Default constructor
480        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
481
482        //! Set weights \c w and components \c R
483        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
484                //!\TODO check if all components are OK
485                Coms = Coms0;
486                w=w0;   
487
488                set_rv(Coms(0)->_rv());
489                dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
490                set_rvc(Coms(0)->_rvc());
491                dimc = Coms(0)->_rvc()._dsize();
492        }
493        double evallogcond (const vec &dt, const vec &cond) {
494                double ll=0.0;
495                for (int i=0;i<Coms.length();i++){
496                        ll+=Coms(i)->evallogcond(dt,cond);
497                }
498                return ll;
499        }
500
501        vec samplecond (const vec &cond);
502
503};
504
505}
506#endif //MX_H
Note: See TracBrowser for help on using the browser.