root/library/bdm/stat/emix.h @ 503

Revision 503, 12.5 kB (checked in by vbarta, 15 years ago)

extended mmix::set_parameters to also set dimensions

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41        //!Denominator in the form of epdf
42        epdf* den;
43        //!flag for destructor
44        bool destroynom;
45        //!datalink between conditional and nom
46        datalink_m2e dl;
47        //! dummy epdf that stores only rv and dim
48        epdf iepdf;
49public:
50        //!Default constructor. By default, the given epdf is not copied!
51        //! It is assumed that this function will be used only temporarily.
52        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
53                // adjust rv and rvc
54                rvc = nom0->_rv().subt ( rv );
55                dimc = rvc._dsize();
56                set_ep ( iepdf );
57                iepdf.set_parameters ( rv._dsize() );
58                iepdf.set_rv ( rv );
59
60                //prepare data structures
61                if ( copy ) {
62                        it_error ( "todo" );
63                        destroynom = true;
64                } else {
65                        nom = nom0;
66                        destroynom = false;
67                }
68                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
69
70                // build denominator
71                den = nom->marginal ( rvc );
72                dl.set_connection ( rv, rvc, nom0->_rv() );
73        };
74        double evallogcond ( const vec &val, const vec &cond ) {
75                double tmp;
76                vec nom_val ( dimension() + dimc );
77                dl.pushup_cond ( nom_val, val, cond );
78                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
79                return tmp;
80        }
81        //! Object takes ownership of nom and will destroy it
82        void ownnom() {
83                destroynom = true;
84        }
85        //! Default destructor
86        ~mratio() {
87                delete den;
88                if ( destroynom ) {
89                        delete nom;
90                }
91        }
92};
93
94/*!
95* \brief Mixture of epdfs
96
97Density function:
98\f[
99f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
100\f]
101where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
102
103*/
104class emix : public epdf {
105protected:
106        //! weights of the components
107        vec w;
108        //! Component (epdfs)
109        Array<epdf*> Coms;
110        //!Flag if owning Coms
111        bool destroyComs;
112public:
113        //!Default constructor
114        emix ( ) : epdf ( ) {};
115        //! Set weights \c w and components \c Coms
116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
117        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy = false );
118
119        vec sample() const;
120        vec mean() const {
121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
126                return mu;
127        }
128        vec variance() const {
129                //non-central moment
130                vec mom2 = zeros ( dim );
131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
134                //central moment
135                return mom2 - pow ( mean(), 2 );
136        }
137        double evallog ( const vec &val ) const {
138                int i;
139                double sum = 0.0;
140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
148                return tmp;
149        };
150        vec evallog_m ( const mat &Val ) const {
151                vec x = zeros ( Val.cols() );
152                for ( int i = 0; i < w.length(); i++ ) {
153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
154                }
155                return log ( x );
156        };
157        //! Auxiliary function that returns pdflog for each component
158        mat evallog_M ( const mat &Val ) const {
159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
162                }
163                return X;
164        };
165
166        emix* marginal ( const RV &rv ) const;
167        mratio* condition ( const RV &rv ) const; //why not mratio!!
168
169//Access methods
170        //! returns a pointer to the internal mean value. Use with Care!
171        vec& _w() {
172                return w;
173        }
174        virtual ~emix() {
175                if ( destroyComs ) {
176                        for ( int i = 0; i < Coms.length(); i++ ) {
177                                delete Coms ( i );
178                        }
179                }
180        }
181        //! Auxiliary function for taking ownership of the Coms()
182        void ownComs() {
183                destroyComs = true;
184        }
185
186        //!access function
187        epdf* _Coms ( int i ) {
188                return Coms ( i );
189        }
190        void set_rv ( const RV &rv ) {
191                epdf::set_rv ( rv );
192                for ( int i = 0; i < Coms.length(); i++ ) {
193                        Coms ( i )->set_rv ( rv );
194                }
195        }
196};
197
198
199/*!
200* \brief Mixture of egiws
201
202*/
203class egiwmix : public egiw {
204protected:
205        //! weights of the components
206        vec w;
207        //! Component (epdfs)
208        Array<egiw*> Coms;
209        //!Flag if owning Coms
210        bool destroyComs;
211public:
212        //!Default constructor
213        egiwmix ( ) : egiw ( ) {};
214
215        //! Set weights \c w and components \c Coms
216        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
217        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
218
219        //!return expected value
220        vec mean() const;
221
222        //!return a sample from the density
223        vec sample() const;
224
225        //!return the expected variance
226        vec variance() const;
227
228        // TODO!!! Defined to follow ANSI and/or for future development
229        void mean_mat ( mat &M, mat&R ) const {};
230        double evallog_nn ( const vec &val ) const {
231                return 0;
232        };
233        double lognc () const {
234                return 0;
235        };
236        emix* marginal ( const RV &rv ) const;
237
238//Access methods
239        //! returns a pointer to the internal mean value. Use with Care!
240        vec& _w() {
241                return w;
242        }
243        virtual ~egiwmix() {
244                if ( destroyComs ) {
245                        for ( int i = 0; i < Coms.length(); i++ ) {
246                                delete Coms ( i );
247                        }
248                }
249        }
250        //! Auxiliary function for taking ownership of the Coms()
251        void ownComs() {
252                destroyComs = true;
253        }
254
255        //!access function
256        egiw* _Coms ( int i ) {
257                return Coms ( i );
258        }
259
260        void set_rv ( const RV &rv ) {
261                egiw::set_rv ( rv );
262                for ( int i = 0; i < Coms.length(); i++ ) {
263                        Coms ( i )->set_rv ( rv );
264                }
265        }
266
267        //! Approximation of a GiW mix by a single GiW pdf
268        egiw* approx();
269};
270
271/*! \brief Chain rule decomposition of epdf
272
273Probability density in the form of Chain-rule decomposition:
274\[
275f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
276\]
277Note that
278*/
279class mprod: public compositepdf, public mpdf {
280protected:
281        //! Data link for each mpdfs
282        Array<datalink_m2m*> dls;
283
284        //! dummy epdf used only as storage for RV and dim
285        epdf iepdf;
286
287public:
288        /*!\brief Constructor from list of mFacs,
289        */
290        mprod() : iepdf( ) { }
291        mprod ( Array<mpdf*> mFacs ) :
292                        iepdf ( ) {
293                set_elements ( mFacs );
294        }
295
296        void set_elements ( Array<mpdf*> mFacs , bool own = false ) {
297
298                compositepdf::set_elements ( mFacs, own );
299                dls.set_size ( mFacs.length() );
300
301                set_ep ( iepdf);
302                RV rv = getrv ( true );
303                set_rv ( rv );
304                iepdf.set_parameters ( rv._dsize() );
305                setrvc (_rv(), rvc );
306                // rv and rvc established = > we can link them with mpdfs
307                for ( int i = 0; i < mpdfs.length(); i++ ) {
308                        dls ( i ) = new datalink_m2m;
309                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
310                }
311
312        };
313
314        double evallogcond ( const vec &val, const vec &cond ) {
315                int i;
316                double res = 0.0;
317                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
318                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
319                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
320                                                }
321                                                // add logarithms
322                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
323                        res += mpdfs ( i )->evallogcond (
324                                   dls ( i )->pushdown ( val ),
325                                   dls ( i )->get_cond ( val, cond )
326                               );
327                }
328                return res;
329        }
330        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
331                vec tmp ( Dt.cols() );
332                for ( int i = 0; i < Dt.cols(); i++ ) {
333                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
334                }
335                return tmp;
336        };
337        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
338                vec tmp ( Dt.length() );
339                for ( int i = 0; i < Dt.length(); i++ ) {
340                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
341                }
342                return tmp;
343        };
344
345
346        //TODO smarter...
347        vec samplecond ( const vec &cond ) {
348                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
349                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
350                vec smpi;
351                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
352                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
353                        // generate contribution of this mpdf
354                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
355                        // copy contribution of this pdf into smp
356                        dls ( i )->pushup ( smp, smpi );
357                }
358                return smp;
359        }
360        mat samplecond ( const vec &cond,  int N ) {
361                mat Smp ( dimension(), N );
362                for ( int i = 0; i < N; i++ ) {
363                        Smp.set_col ( i, samplecond ( cond ) );
364                }
365                return Smp;
366        }
367
368        ~mprod() {};
369        //! Load from structure with elements:
370        //!  \code
371        //! { class='mprod';
372        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
373        //! }
374        //! \endcode
375        //!@}
376        void from_setting ( const Setting &set ) {
377                Array<mpdf*> Atmp; //temporary Array
378                UI::get ( Atmp, set, "mpdfs", UI::compulsory );
379                set_elements ( Atmp, true );
380        }
381
382};
383UIREGISTER ( mprod );
384
385//! Product of independent epdfs. For dependent pdfs, use mprod.
386class eprod: public epdf {
387protected:
388        //! Components (epdfs)
389        Array<const epdf*> epdfs;
390        //! Array of indeces
391        Array<datalink*> dls;
392public:
393        eprod () : epdfs ( 0 ), dls ( 0 ) {};
394        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
395                epdfs = epdfs0;//.set_length ( epdfs0.length() );
396                dls.set_length ( epdfs.length() );
397
398                bool independent = true;
399                if ( named ) {
400                        for ( int i = 0; i < epdfs.length(); i++ ) {
401                                independent = rv.add ( epdfs ( i )->_rv() );
402                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
403                        }
404                        dim = rv._dsize();
405                } else {
406                        dim = 0;
407                        for ( int i = 0; i < epdfs.length(); i++ ) {
408                                dim += epdfs ( i )->dimension();
409                        }
410                }
411                //
412                int cumdim = 0;
413                int dimi = 0;
414                int i;
415                for ( i = 0; i < epdfs.length(); i++ ) {
416                        dls ( i ) = new datalink;
417                        if ( named ) {
418                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
419                        } else {
420                                dimi = epdfs ( i )->dimension();
421                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
422                                cumdim += dimi;
423                        }
424                }
425        }
426
427        vec mean() const {
428                vec tmp ( dim );
429                for ( int i = 0; i < epdfs.length(); i++ ) {
430                        vec pom = epdfs ( i )->mean();
431                        dls ( i )->pushup ( tmp, pom );
432                }
433                return tmp;
434        }
435        vec variance() const {
436                vec tmp ( dim ); //second moment
437                for ( int i = 0; i < epdfs.length(); i++ ) {
438                        vec pom = epdfs ( i )->mean();
439                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
440                }
441                return tmp - pow ( mean(), 2 );
442        }
443        vec sample() const {
444                vec tmp ( dim );
445                for ( int i = 0; i < epdfs.length(); i++ ) {
446                        vec pom = epdfs ( i )->sample();
447                        dls ( i )->pushup ( tmp, pom );
448                }
449                return tmp;
450        }
451        double evallog ( const vec &val ) const {
452                double tmp = 0;
453                for ( int i = 0; i < epdfs.length(); i++ ) {
454                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
455                }
456                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
457                return tmp;
458        }
459        //!access function
460        const epdf* operator () ( int i ) const {
461                it_assert_debug ( i < epdfs.length(), "wrong index" );
462                return epdfs ( i );
463        }
464
465        //!Destructor
466        ~eprod() {
467                for ( int i = 0; i < epdfs.length(); i++ ) {
468                        delete dls ( i );
469                }
470        }
471};
472
473
474/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
475
476*/
477class mmix : public mpdf {
478protected:
479        //! Component (mpdfs)
480        Array<shared_ptr<mpdf> > Coms;
481        //!weights of the components
482        vec w;
483        //! dummy epdfs
484        epdf dummy_epdf;
485public:
486        //!Default constructor
487        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
488
489        //! Set weights \c w and components \c R
490        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
491                //!\TODO check if all components are OK
492                Coms = Coms0;
493                w=w0;   
494
495                set_rv(Coms(0)->_rv());
496                dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
497                set_rvc(Coms(0)->_rvc());
498                dimc = Coms(0)->_rvc()._dsize();
499        }
500        double evallogcond (const vec &dt, const vec &cond) {
501                double ll=0.0;
502                for (int i=0;i<Coms.length();i++){
503                        ll+=Coms(i)->evallogcond(dt,cond);
504                }
505                return ll;
506        }
507
508        vec samplecond (const vec &cond);
509
510};
511
512}
513#endif //MX_H
Note: See TracBrowser for help on using the browser.