root/library/bdm/stat/emix.h @ 503

Revision 503, 12.5 kB (checked in by vbarta, 15 years ago)

extended mmix::set_parameters to also set dimensions

  • Property svn:eol-style set to native
RevLine 
[107]1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[394]13#ifndef EMIX_H
14#define EMIX_H
[107]15
[477]16#define LOG2  0.69314718055995
[333]17
[461]18#include "../shared_ptr.h"
[384]19#include "exp_family.h"
[107]20
[286]21namespace bdm {
[107]22
[182]23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
[211]35At present the only supported operation is evallogcond().
[182]36 */
37class mratio: public mpdf {
[192]38protected:
[182]39        //! Nominator in the form of mpdf
[192]40        const epdf* nom;
[182]41        //!Denominator in the form of epdf
[192]42        epdf* den;
43        //!flag for destructor
44        bool destroynom;
[193]45        //!datalink between conditional and nom
46        datalink_m2e dl;
[487]47        //! dummy epdf that stores only rv and dim
48        epdf iepdf;
[192]49public:
50        //!Default constructor. By default, the given epdf is not copied!
[182]51        //! It is assumed that this function will be used only temporarily.
[487]52        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
[286]53                // adjust rv and rvc
54                rvc = nom0->_rv().subt ( rv );
55                dimc = rvc._dsize();
[487]56                set_ep ( iepdf );
57                iepdf.set_parameters ( rv._dsize() );
58                iepdf.set_rv ( rv );
[477]59
[286]60                //prepare data structures
[477]61                if ( copy ) {
62                        it_error ( "todo" );
63                        destroynom = true;
64                } else {
65                        nom = nom0;
66                        destroynom = false;
67                }
68                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
69
[286]70                // build denominator
[192]71                den = nom->marginal ( rvc );
[477]72                dl.set_connection ( rv, rvc, nom0->_rv() );
[192]73        };
[211]74        double evallogcond ( const vec &val, const vec &cond ) {
[214]75                double tmp;
[487]76                vec nom_val ( dimension() + dimc );
[477]77                dl.pushup_cond ( nom_val, val, cond );
[214]78                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
79                return tmp;
[192]80        }
[182]81        //! Object takes ownership of nom and will destroy it
[477]82        void ownnom() {
83                destroynom = true;
84        }
[182]85        //! Default destructor
[477]86        ~mratio() {
87                delete den;
88                if ( destroynom ) {
89                        delete nom;
90                }
91        }
[182]92};
93
[107]94/*!
95* \brief Mixture of epdfs
96
97Density function:
98\f[
99f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
100\f]
101where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
102
103*/
[145]104class emix : public epdf {
[162]105protected:
106        //! weights of the components
107        vec w;
108        //! Component (epdfs)
109        Array<epdf*> Coms;
[178]110        //!Flag if owning Coms
111        bool destroyComs;
[162]112public:
113        //!Default constructor
[286]114        emix ( ) : epdf ( ) {};
[182]115        //! Set weights \c w and components \c Coms
[224]116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
[477]117        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy = false );
[107]118
[162]119        vec sample() const;
120        vec mean() const {
[477]121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
[162]126                return mu;
127        }
[229]128        vec variance() const {
129                //non-central moment
[286]130                vec mom2 = zeros ( dim );
[477]131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
[229]134                //central moment
[477]135                return mom2 - pow ( mean(), 2 );
[229]136        }
[211]137        double evallog ( const vec &val ) const {
[162]138                int i;
139                double sum = 0.0;
[477]140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
[214]148                return tmp;
[162]149        };
[211]150        vec evallog_m ( const mat &Val ) const {
[477]151                vec x = zeros ( Val.cols() );
[192]152                for ( int i = 0; i < w.length(); i++ ) {
[477]153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
[182]154                }
[192]155                return log ( x );
[182]156        };
[214]157        //! Auxiliary function that returns pdflog for each component
[211]158        mat evallog_M ( const mat &Val ) const {
[192]159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
[211]161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
[189]162                }
163                return X;
164        };
[107]165
[182]166        emix* marginal ( const RV &rv ) const;
167        mratio* condition ( const RV &rv ) const; //why not mratio!!
168
[107]169//Access methods
[162]170        //! returns a pointer to the internal mean value. Use with Care!
[477]171        vec& _w() {
172                return w;
173        }
174        virtual ~emix() {
175                if ( destroyComs ) {
176                        for ( int i = 0; i < Coms.length(); i++ ) {
177                                delete Coms ( i );
178                        }
179                }
180        }
[178]181        //! Auxiliary function for taking ownership of the Coms()
[477]182        void ownComs() {
183                destroyComs = true;
184        }
[204]185
[193]186        //!access function
[477]187        epdf* _Coms ( int i ) {
188                return Coms ( i );
[286]189        }
[477]190        void set_rv ( const RV &rv ) {
191                epdf::set_rv ( rv );
192                for ( int i = 0; i < Coms.length(); i++ ) {
193                        Coms ( i )->set_rv ( rv );
194                }
195        }
[107]196};
197
[333]198
199/*!
200* \brief Mixture of egiws
201
202*/
203class egiwmix : public egiw {
204protected:
205        //! weights of the components
206        vec w;
207        //! Component (epdfs)
208        Array<egiw*> Coms;
209        //!Flag if owning Coms
210        bool destroyComs;
211public:
212        //!Default constructor
213        egiwmix ( ) : egiw ( ) {};
214
215        //! Set weights \c w and components \c Coms
216        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
[477]217        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
[333]218
219        //!return expected value
220        vec mean() const;
221
222        //!return a sample from the density
223        vec sample() const;
224
225        //!return the expected variance
[477]226        vec variance() const;
[333]227
228        // TODO!!! Defined to follow ANSI and/or for future development
229        void mean_mat ( mat &M, mat&R ) const {};
[477]230        double evallog_nn ( const vec &val ) const {
231                return 0;
232        };
233        double lognc () const {
234                return 0;
235        };
[333]236        emix* marginal ( const RV &rv ) const;
237
238//Access methods
239        //! returns a pointer to the internal mean value. Use with Care!
[477]240        vec& _w() {
241                return w;
242        }
243        virtual ~egiwmix() {
244                if ( destroyComs ) {
245                        for ( int i = 0; i < Coms.length(); i++ ) {
246                                delete Coms ( i );
247                        }
248                }
249        }
[333]250        //! Auxiliary function for taking ownership of the Coms()
[477]251        void ownComs() {
252                destroyComs = true;
253        }
[333]254
255        //!access function
[477]256        egiw* _Coms ( int i ) {
257                return Coms ( i );
258        }
[333]259
[477]260        void set_rv ( const RV &rv ) {
261                egiw::set_rv ( rv );
262                for ( int i = 0; i < Coms.length(); i++ ) {
263                        Coms ( i )->set_rv ( rv );
264                }
[333]265        }
266
267        //! Approximation of a GiW mix by a single GiW pdf
268        egiw* approx();
269};
270
[115]271/*! \brief Chain rule decomposition of epdf
272
[145]273Probability density in the form of Chain-rule decomposition:
274\[
275f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
276\]
277Note that
[115]278*/
[175]279class mprod: public compositepdf, public mpdf {
[162]280protected:
[192]281        //! Data link for each mpdfs
282        Array<datalink_m2m*> dls;
[461]283
[487]284        //! dummy epdf used only as storage for RV and dim
285        epdf iepdf;
[461]286
[162]287public:
[168]288        /*!\brief Constructor from list of mFacs,
[165]289        */
[487]290        mprod() : iepdf( ) { }
[477]291        mprod ( Array<mpdf*> mFacs ) :
[487]292                        iepdf ( ) {
[477]293                set_elements ( mFacs );
[461]294        }
295
[477]296        void set_elements ( Array<mpdf*> mFacs , bool own = false ) {
297
298                compositepdf::set_elements ( mFacs, own );
299                dls.set_size ( mFacs.length() );
300
[487]301                set_ep ( iepdf);
[477]302                RV rv = getrv ( true );
[395]303                set_rv ( rv );
[487]304                iepdf.set_parameters ( rv._dsize() );
305                setrvc (_rv(), rvc );
[192]306                // rv and rvc established = > we can link them with mpdfs
[477]307                for ( int i = 0; i < mpdfs.length(); i++ ) {
[286]308                        dls ( i ) = new datalink_m2m;
[477]309                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
[181]310                }
311
[175]312        };
[124]313
[211]314        double evallogcond ( const vec &val, const vec &cond ) {
[162]315                int i;
[270]316                double res = 0.0;
[477]317                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
[193]318                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
319                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
320                                                }
321                                                // add logarithms
[270]322                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
323                        res += mpdfs ( i )->evallogcond (
324                                   dls ( i )->pushdown ( val ),
[204]325                                   dls ( i )->get_cond ( val, cond )
326                               );
[145]327                }
[193]328                return res;
[162]329        }
[477]330        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
331                vec tmp ( Dt.cols() );
332                for ( int i = 0; i < Dt.cols(); i++ ) {
333                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
[395]334                }
335                return tmp;
336        };
[477]337        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
338                vec tmp ( Dt.length() );
339                for ( int i = 0; i < Dt.length(); i++ ) {
340                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
[395]341                }
[477]342                return tmp;
[395]343        };
344
345
[270]346        //TODO smarter...
347        vec samplecond ( const vec &cond ) {
[477]348                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
[487]349                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
[165]350                vec smpi;
[192]351                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
[477]352                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
[487]353                        // generate contribution of this mpdf
354                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
[162]355                        // copy contribution of this pdf into smp
[270]356                        dls ( i )->pushup ( smp, smpi );
[145]357                }
[162]358                return smp;
359        }
[270]360        mat samplecond ( const vec &cond,  int N ) {
[477]361                mat Smp ( dimension(), N );
362                for ( int i = 0; i < N; i++ ) {
363                        Smp.set_col ( i, samplecond ( cond ) );
364                }
[162]365                return Smp;
366        }
367
368        ~mprod() {};
[395]369        //! Load from structure with elements:
370        //!  \code
371        //! { class='mprod';
372        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
373        //! }
374        //! \endcode
375        //!@}
[477]376        void from_setting ( const Setting &set ) {
[395]377                Array<mpdf*> Atmp; //temporary Array
[477]378                UI::get ( Atmp, set, "mpdfs", UI::compulsory );
379                set_elements ( Atmp, true );
[395]380        }
[477]381
[107]382};
[477]383UIREGISTER ( mprod );
[107]384
[168]385//! Product of independent epdfs. For dependent pdfs, use mprod.
386class eprod: public epdf {
387protected:
388        //! Components (epdfs)
[170]389        Array<const epdf*> epdfs;
[168]390        //! Array of indeces
[270]391        Array<datalink*> dls;
[168]392public:
[477]393        eprod () : epdfs ( 0 ), dls ( 0 ) {};
394        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
395                epdfs = epdfs0;//.set_length ( epdfs0.length() );
[286]396                dls.set_length ( epdfs.length() );
397
[477]398                bool independent = true;
[286]399                if ( named ) {
[477]400                        for ( int i = 0; i < epdfs.length(); i++ ) {
401                                independent = rv.add ( epdfs ( i )->_rv() );
402                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
[286]403                        }
[477]404                        dim = rv._dsize();
405                } else {
406                        dim = 0;
407                        for ( int i = 0; i < epdfs.length(); i++ ) {
408                                dim += epdfs ( i )->dimension();
[286]409                        }
[204]410                }
[286]411                //
[477]412                int cumdim = 0;
413                int dimi = 0;
[286]414                int i;
[477]415                for ( i = 0; i < epdfs.length(); i++ ) {
[286]416                        dls ( i ) = new datalink;
[477]417                        if ( named ) {
418                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
419                        } else {
[286]420                                dimi = epdfs ( i )->dimension();
[477]421                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
422                                cumdim += dimi;
[286]423                        }
424                }
[168]425        }
426
427        vec mean() const {
[270]428                vec tmp ( dim );
[477]429                for ( int i = 0; i < epdfs.length(); i++ ) {
[168]430                        vec pom = epdfs ( i )->mean();
[270]431                        dls ( i )->pushup ( tmp, pom );
[168]432                }
433                return tmp;
434        }
[229]435        vec variance() const {
[270]436                vec tmp ( dim ); //second moment
[477]437                for ( int i = 0; i < epdfs.length(); i++ ) {
[229]438                        vec pom = epdfs ( i )->mean();
[477]439                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
[229]440                }
[477]441                return tmp - pow ( mean(), 2 );
[229]442        }
[168]443        vec sample() const {
[270]444                vec tmp ( dim );
[477]445                for ( int i = 0; i < epdfs.length(); i++ ) {
[168]446                        vec pom = epdfs ( i )->sample();
[270]447                        dls ( i )->pushup ( tmp, pom );
[168]448                }
449                return tmp;
450        }
[211]451        double evallog ( const vec &val ) const {
[477]452                double tmp = 0;
453                for ( int i = 0; i < epdfs.length(); i++ ) {
454                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
[168]455                }
[477]456                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
[168]457                return tmp;
458        }
[170]459        //!access function
[477]460        const epdf* operator () ( int i ) const {
461                it_assert_debug ( i < epdfs.length(), "wrong index" );
462                return epdfs ( i );
463        }
[204]464
[193]465        //!Destructor
[477]466        ~eprod() {
467                for ( int i = 0; i < epdfs.length(); i++ ) {
468                        delete dls ( i );
469                }
470        }
[168]471};
472
473
[488]474/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
[124]475
476*/
[488]477class mmix : public mpdf {
478protected:
479        //! Component (mpdfs)
480        Array<shared_ptr<mpdf> > Coms;
481        //!weights of the components
482        vec w;
483        //! dummy epdfs
484        epdf dummy_epdf;
485public:
486        //!Default constructor
487        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
[461]488
[488]489        //! Set weights \c w and components \c R
490        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
491                //!\TODO check if all components are OK
492                Coms = Coms0;
493                w=w0;   
[503]494
[488]495                set_rv(Coms(0)->_rv());
[503]496                dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
[488]497                set_rvc(Coms(0)->_rvc());
[503]498                dimc = Coms(0)->_rvc()._dsize();
[488]499        }
500        double evallogcond (const vec &dt, const vec &cond) {
501                double ll=0.0;
502                for (int i=0;i<Coms.length();i++){
503                        ll+=Coms(i)->evallogcond(dt,cond);
504                }
505                return ll;
506        }
507
508        vec samplecond (const vec &cond);
509
510};
511
[254]512}
[107]513#endif //MX_H
Note: See TracBrowser for help on using the browser.