root/library/bdm/stat/emix.h @ 507

Revision 507, 12.1 kB (checked in by vbarta, 15 years ago)

removed class compositepdf; keeping mpdfs of mprod and merger_base in shared pointers

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93};
94
95/*!
96* \brief Mixture of epdfs
97
98Density function:
99\f[
100f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
101\f]
102where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
103
104*/
105class emix : public epdf {
106protected:
107        //! weights of the components
108        vec w;
109        //! Component (epdfs)
110        Array<shared_ptr<epdf> > Coms;
111
112public:
113        //!Default constructor
114        emix ( ) : epdf ( ) {};
115        //! Set weights \c w and components \c Coms
116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
117        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
118
119        vec sample() const;
120        vec mean() const {
121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
126                return mu;
127        }
128        vec variance() const {
129                //non-central moment
130                vec mom2 = zeros ( dim );
131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
134                //central moment
135                return mom2 - pow ( mean(), 2 );
136        }
137        double evallog ( const vec &val ) const {
138                int i;
139                double sum = 0.0;
140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
148                return tmp;
149        };
150        vec evallog_m ( const mat &Val ) const {
151                vec x = zeros ( Val.cols() );
152                for ( int i = 0; i < w.length(); i++ ) {
153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
154                }
155                return log ( x );
156        };
157        //! Auxiliary function that returns pdflog for each component
158        mat evallog_M ( const mat &Val ) const {
159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
162                }
163                return X;
164        };
165
166        shared_ptr<epdf> marginal ( const RV &rv ) const;
167        void marginal ( const RV &rv, emix &target ) const;
168        shared_ptr<mpdf> condition ( const RV &rv ) const;
169
170//Access methods
171        //! returns a pointer to the internal mean value. Use with Care!
172        vec& _w() {
173                return w;
174        }
175
176        //!access function
177        shared_ptr<epdf> _Coms ( int i ) {
178                return Coms ( i );
179        }
180
181        void set_rv ( const RV &rv ) {
182                epdf::set_rv ( rv );
183                for ( int i = 0; i < Coms.length(); i++ ) {
184                        Coms ( i )->set_rv ( rv );
185                }
186        }
187};
188
189
190/*!
191* \brief Mixture of egiws
192
193*/
194class egiwmix : public egiw {
195protected:
196        //! weights of the components
197        vec w;
198        //! Component (epdfs)
199        Array<egiw*> Coms;
200        //!Flag if owning Coms
201        bool destroyComs;
202public:
203        //!Default constructor
204        egiwmix ( ) : egiw ( ) {};
205
206        //! Set weights \c w and components \c Coms
207        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
208        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
209
210        //!return expected value
211        vec mean() const;
212
213        //!return a sample from the density
214        vec sample() const;
215
216        //!return the expected variance
217        vec variance() const;
218
219        // TODO!!! Defined to follow ANSI and/or for future development
220        void mean_mat ( mat &M, mat&R ) const {};
221        double evallog_nn ( const vec &val ) const {
222                return 0;
223        };
224        double lognc () const {
225                return 0;
226        }
227
228        shared_ptr<epdf> marginal ( const RV &rv ) const;
229        void marginal ( const RV &rv, emix &target ) const;
230
231//Access methods
232        //! returns a pointer to the internal mean value. Use with Care!
233        vec& _w() {
234                return w;
235        }
236        virtual ~egiwmix() {
237                if ( destroyComs ) {
238                        for ( int i = 0; i < Coms.length(); i++ ) {
239                                delete Coms ( i );
240                        }
241                }
242        }
243        //! Auxiliary function for taking ownership of the Coms()
244        void ownComs() {
245                destroyComs = true;
246        }
247
248        //!access function
249        egiw* _Coms ( int i ) {
250                return Coms ( i );
251        }
252
253        void set_rv ( const RV &rv ) {
254                egiw::set_rv ( rv );
255                for ( int i = 0; i < Coms.length(); i++ ) {
256                        Coms ( i )->set_rv ( rv );
257                }
258        }
259
260        //! Approximation of a GiW mix by a single GiW pdf
261        egiw* approx();
262};
263
264/*! \brief Chain rule decomposition of epdf
265
266Probability density in the form of Chain-rule decomposition:
267\[
268f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
269\]
270Note that
271*/
272class mprod: public mpdf {
273private:
274        Array<shared_ptr<mpdf> > mpdfs;
275
276protected:
277        //! Data link for each mpdfs
278        Array<datalink_m2m*> dls;
279
280        //! dummy epdf used only as storage for RV and dim
281        epdf iepdf;
282
283public:
284        //! \brief Default constructor
285        mprod() { }
286
287        /*!\brief Constructor from list of mFacs
288        */
289        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
290                set_elements ( mFacs );
291        }
292
293        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
294
295        double evallogcond ( const vec &val, const vec &cond ) {
296                int i;
297                double res = 0.0;
298                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
299                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
300                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
301                                                }
302                                                // add logarithms
303                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
304                        res += mpdfs ( i )->evallogcond (
305                                   dls ( i )->pushdown ( val ),
306                                   dls ( i )->get_cond ( val, cond )
307                               );
308                }
309                return res;
310        }
311        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
312                vec tmp ( Dt.cols() );
313                for ( int i = 0; i < Dt.cols(); i++ ) {
314                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
315                }
316                return tmp;
317        };
318        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
319                vec tmp ( Dt.length() );
320                for ( int i = 0; i < Dt.length(); i++ ) {
321                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
322                }
323                return tmp;
324        };
325
326
327        //TODO smarter...
328        vec samplecond ( const vec &cond ) {
329                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
330                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
331                vec smpi;
332                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
333                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
334                        // generate contribution of this mpdf
335                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
336                        // copy contribution of this pdf into smp
337                        dls ( i )->pushup ( smp, smpi );
338                }
339                return smp;
340        }
341        mat samplecond ( const vec &cond,  int N ) {
342                mat Smp ( dimension(), N );
343                for ( int i = 0; i < N; i++ ) {
344                        Smp.set_col ( i, samplecond ( cond ) );
345                }
346                return Smp;
347        }
348
349        //! Load from structure with elements:
350        //!  \code
351        //! { class='mprod';
352        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
353        //! }
354        //! \endcode
355        //!@}
356        void from_setting ( const Setting &set ) {
357                Array<mpdf*> atmp; //temporary Array
358                UI::get ( atmp, set, "mpdfs", UI::compulsory );
359               
360                Array<shared_ptr<mpdf> > btmp ( atmp.length() );
361                for (int i = 0; i < atmp.length(); ++i) {
362                        btmp ( i ) = shared_ptr<mpdf> ( atmp ( i ) );
363                }
364
365                set_elements ( btmp );
366        }
367
368};
369UIREGISTER ( mprod );
370
371//! Product of independent epdfs. For dependent pdfs, use mprod.
372class eprod: public epdf {
373protected:
374        //! Components (epdfs)
375        Array<const epdf*> epdfs;
376        //! Array of indeces
377        Array<datalink*> dls;
378public:
379        eprod () : epdfs ( 0 ), dls ( 0 ) {};
380        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
381                epdfs = epdfs0;//.set_length ( epdfs0.length() );
382                dls.set_length ( epdfs.length() );
383
384                bool independent = true;
385                if ( named ) {
386                        for ( int i = 0; i < epdfs.length(); i++ ) {
387                                independent = rv.add ( epdfs ( i )->_rv() );
388                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
389                        }
390                        dim = rv._dsize();
391                } else {
392                        dim = 0;
393                        for ( int i = 0; i < epdfs.length(); i++ ) {
394                                dim += epdfs ( i )->dimension();
395                        }
396                }
397                //
398                int cumdim = 0;
399                int dimi = 0;
400                int i;
401                for ( i = 0; i < epdfs.length(); i++ ) {
402                        dls ( i ) = new datalink;
403                        if ( named ) {
404                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
405                        } else {
406                                dimi = epdfs ( i )->dimension();
407                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
408                                cumdim += dimi;
409                        }
410                }
411        }
412
413        vec mean() const {
414                vec tmp ( dim );
415                for ( int i = 0; i < epdfs.length(); i++ ) {
416                        vec pom = epdfs ( i )->mean();
417                        dls ( i )->pushup ( tmp, pom );
418                }
419                return tmp;
420        }
421        vec variance() const {
422                vec tmp ( dim ); //second moment
423                for ( int i = 0; i < epdfs.length(); i++ ) {
424                        vec pom = epdfs ( i )->mean();
425                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
426                }
427                return tmp - pow ( mean(), 2 );
428        }
429        vec sample() const {
430                vec tmp ( dim );
431                for ( int i = 0; i < epdfs.length(); i++ ) {
432                        vec pom = epdfs ( i )->sample();
433                        dls ( i )->pushup ( tmp, pom );
434                }
435                return tmp;
436        }
437        double evallog ( const vec &val ) const {
438                double tmp = 0;
439                for ( int i = 0; i < epdfs.length(); i++ ) {
440                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
441                }
442                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
443                return tmp;
444        }
445        //!access function
446        const epdf* operator () ( int i ) const {
447                it_assert_debug ( i < epdfs.length(), "wrong index" );
448                return epdfs ( i );
449        }
450
451        //!Destructor
452        ~eprod() {
453                for ( int i = 0; i < epdfs.length(); i++ ) {
454                        delete dls ( i );
455                }
456        }
457};
458
459
460/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
461
462*/
463class mmix : public mpdf {
464protected:
465        //! Component (mpdfs)
466        Array<shared_ptr<mpdf> > Coms;
467        //!weights of the components
468        vec w;
469        //! dummy epdfs
470        epdf dummy_epdf;
471public:
472        //!Default constructor
473        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
474
475        //! Set weights \c w and components \c R
476        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
477                //!\TODO check if all components are OK
478                Coms = Coms0;
479                w=w0;   
480
481                set_rv(Coms(0)->_rv());
482                dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
483                set_rvc(Coms(0)->_rvc());
484                dimc = Coms(0)->_rvc()._dsize();
485        }
486        double evallogcond (const vec &dt, const vec &cond) {
487                double ll=0.0;
488                for (int i=0;i<Coms.length();i++){
489                        ll+=Coms(i)->evallogcond(dt,cond);
490                }
491                return ll;
492        }
493
494        vec samplecond (const vec &cond);
495
496};
497
498}
499#endif //MX_H
Note: See TracBrowser for help on using the browser.