root/library/bdm/stat/emix.h @ 515

Revision 515, 12.1 kB (checked in by vbarta, 15 years ago)

fixed previous fix

  • Property svn:eol-style set to native
RevLine 
[107]1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[394]13#ifndef EMIX_H
14#define EMIX_H
[107]15
[477]16#define LOG2  0.69314718055995
[333]17
[461]18#include "../shared_ptr.h"
[384]19#include "exp_family.h"
[107]20
[286]21namespace bdm {
[107]22
[182]23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
[211]35At present the only supported operation is evallogcond().
[182]36 */
37class mratio: public mpdf {
[192]38protected:
[182]39        //! Nominator in the form of mpdf
[192]40        const epdf* nom;
[504]41
[182]42        //!Denominator in the form of epdf
[504]43        shared_ptr<epdf> den;
44
[192]45        //!flag for destructor
46        bool destroynom;
[193]47        //!datalink between conditional and nom
48        datalink_m2e dl;
[487]49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
[192]51public:
52        //!Default constructor. By default, the given epdf is not copied!
[182]53        //! It is assumed that this function will be used only temporarily.
[487]54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
[286]55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
[487]58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
[477]61
[286]62                //prepare data structures
[477]63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
[286]72                // build denominator
[192]73                den = nom->marginal ( rvc );
[477]74                dl.set_connection ( rv, rvc, nom0->_rv() );
[192]75        };
[211]76        double evallogcond ( const vec &val, const vec &cond ) {
[214]77                double tmp;
[487]78                vec nom_val ( dimension() + dimc );
[477]79                dl.pushup_cond ( nom_val, val, cond );
[214]80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
[192]82        }
[182]83        //! Object takes ownership of nom and will destroy it
[477]84        void ownnom() {
85                destroynom = true;
86        }
[182]87        //! Default destructor
[477]88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
[182]93};
94
[107]95/*!
96* \brief Mixture of epdfs
97
98Density function:
99\f[
100f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
101\f]
102where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
103
104*/
[145]105class emix : public epdf {
[162]106protected:
107        //! weights of the components
108        vec w;
109        //! Component (epdfs)
[504]110        Array<shared_ptr<epdf> > Coms;
111
[162]112public:
113        //!Default constructor
[286]114        emix ( ) : epdf ( ) {};
[182]115        //! Set weights \c w and components \c Coms
[224]116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
[504]117        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
[107]118
[162]119        vec sample() const;
120        vec mean() const {
[477]121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
[162]126                return mu;
127        }
[229]128        vec variance() const {
129                //non-central moment
[286]130                vec mom2 = zeros ( dim );
[477]131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
[229]134                //central moment
[477]135                return mom2 - pow ( mean(), 2 );
[229]136        }
[211]137        double evallog ( const vec &val ) const {
[162]138                int i;
139                double sum = 0.0;
[477]140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
[214]148                return tmp;
[162]149        };
[211]150        vec evallog_m ( const mat &Val ) const {
[477]151                vec x = zeros ( Val.cols() );
[192]152                for ( int i = 0; i < w.length(); i++ ) {
[477]153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
[182]154                }
[192]155                return log ( x );
[182]156        };
[214]157        //! Auxiliary function that returns pdflog for each component
[211]158        mat evallog_M ( const mat &Val ) const {
[192]159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
[211]161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
[189]162                }
163                return X;
164        };
[107]165
[504]166        shared_ptr<epdf> marginal ( const RV &rv ) const;
167        void marginal ( const RV &rv, emix &target ) const;
168        shared_ptr<mpdf> condition ( const RV &rv ) const;
[182]169
[107]170//Access methods
[162]171        //! returns a pointer to the internal mean value. Use with Care!
[477]172        vec& _w() {
173                return w;
174        }
[204]175
[193]176        //!access function
[504]177        shared_ptr<epdf> _Coms ( int i ) {
[477]178                return Coms ( i );
[286]179        }
[504]180
[477]181        void set_rv ( const RV &rv ) {
182                epdf::set_rv ( rv );
183                for ( int i = 0; i < Coms.length(); i++ ) {
184                        Coms ( i )->set_rv ( rv );
185                }
186        }
[107]187};
188
[333]189
190/*!
191* \brief Mixture of egiws
192
193*/
194class egiwmix : public egiw {
195protected:
196        //! weights of the components
197        vec w;
198        //! Component (epdfs)
199        Array<egiw*> Coms;
200        //!Flag if owning Coms
201        bool destroyComs;
202public:
203        //!Default constructor
204        egiwmix ( ) : egiw ( ) {};
205
206        //! Set weights \c w and components \c Coms
207        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
[477]208        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
[333]209
210        //!return expected value
211        vec mean() const;
212
213        //!return a sample from the density
214        vec sample() const;
215
216        //!return the expected variance
[477]217        vec variance() const;
[333]218
219        // TODO!!! Defined to follow ANSI and/or for future development
220        void mean_mat ( mat &M, mat&R ) const {};
[477]221        double evallog_nn ( const vec &val ) const {
222                return 0;
223        };
224        double lognc () const {
225                return 0;
[504]226        }
[333]227
[504]228        shared_ptr<epdf> marginal ( const RV &rv ) const;
229        void marginal ( const RV &rv, emix &target ) const;
230
[333]231//Access methods
232        //! returns a pointer to the internal mean value. Use with Care!
[477]233        vec& _w() {
234                return w;
235        }
236        virtual ~egiwmix() {
237                if ( destroyComs ) {
238                        for ( int i = 0; i < Coms.length(); i++ ) {
239                                delete Coms ( i );
240                        }
241                }
242        }
[333]243        //! Auxiliary function for taking ownership of the Coms()
[477]244        void ownComs() {
245                destroyComs = true;
246        }
[333]247
248        //!access function
[477]249        egiw* _Coms ( int i ) {
250                return Coms ( i );
251        }
[333]252
[477]253        void set_rv ( const RV &rv ) {
254                egiw::set_rv ( rv );
255                for ( int i = 0; i < Coms.length(); i++ ) {
256                        Coms ( i )->set_rv ( rv );
257                }
[333]258        }
259
260        //! Approximation of a GiW mix by a single GiW pdf
261        egiw* approx();
262};
263
[115]264/*! \brief Chain rule decomposition of epdf
265
[145]266Probability density in the form of Chain-rule decomposition:
267\[
268f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
269\]
270Note that
[115]271*/
[507]272class mprod: public mpdf {
273private:
274        Array<shared_ptr<mpdf> > mpdfs;
275
[162]276protected:
[192]277        //! Data link for each mpdfs
278        Array<datalink_m2m*> dls;
[461]279
[487]280        //! dummy epdf used only as storage for RV and dim
281        epdf iepdf;
[461]282
[162]283public:
[507]284        //! \brief Default constructor
285        mprod() { }
286
287        /*!\brief Constructor from list of mFacs
[165]288        */
[507]289        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
[477]290                set_elements ( mFacs );
[461]291        }
292
[507]293        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
[477]294
[211]295        double evallogcond ( const vec &val, const vec &cond ) {
[162]296                int i;
[270]297                double res = 0.0;
[477]298                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
[193]299                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
300                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
301                                                }
302                                                // add logarithms
[270]303                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
304                        res += mpdfs ( i )->evallogcond (
305                                   dls ( i )->pushdown ( val ),
[204]306                                   dls ( i )->get_cond ( val, cond )
307                               );
[145]308                }
[193]309                return res;
[162]310        }
[477]311        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
312                vec tmp ( Dt.cols() );
313                for ( int i = 0; i < Dt.cols(); i++ ) {
314                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
[395]315                }
316                return tmp;
317        };
[477]318        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
319                vec tmp ( Dt.length() );
320                for ( int i = 0; i < Dt.length(); i++ ) {
321                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
[395]322                }
[477]323                return tmp;
[395]324        };
325
326
[270]327        //TODO smarter...
328        vec samplecond ( const vec &cond ) {
[477]329                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
[487]330                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
[165]331                vec smpi;
[192]332                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
[477]333                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
[487]334                        // generate contribution of this mpdf
335                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
[162]336                        // copy contribution of this pdf into smp
[270]337                        dls ( i )->pushup ( smp, smpi );
[145]338                }
[162]339                return smp;
340        }
[270]341        mat samplecond ( const vec &cond,  int N ) {
[477]342                mat Smp ( dimension(), N );
343                for ( int i = 0; i < N; i++ ) {
344                        Smp.set_col ( i, samplecond ( cond ) );
345                }
[162]346                return Smp;
347        }
348
[395]349        //! Load from structure with elements:
350        //!  \code
351        //! { class='mprod';
352        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
353        //! }
354        //! \endcode
355        //!@}
[477]356        void from_setting ( const Setting &set ) {
[507]357                Array<mpdf*> atmp; //temporary Array
358                UI::get ( atmp, set, "mpdfs", UI::compulsory );
359               
360                Array<shared_ptr<mpdf> > btmp ( atmp.length() );
361                for (int i = 0; i < atmp.length(); ++i) {
362                        btmp ( i ) = shared_ptr<mpdf> ( atmp ( i ) );
363                }
364
365                set_elements ( btmp );
[395]366        }
[477]367
[107]368};
[477]369UIREGISTER ( mprod );
[107]370
[168]371//! Product of independent epdfs. For dependent pdfs, use mprod.
372class eprod: public epdf {
373protected:
374        //! Components (epdfs)
[170]375        Array<const epdf*> epdfs;
[168]376        //! Array of indeces
[270]377        Array<datalink*> dls;
[168]378public:
[477]379        eprod () : epdfs ( 0 ), dls ( 0 ) {};
380        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
381                epdfs = epdfs0;//.set_length ( epdfs0.length() );
[286]382                dls.set_length ( epdfs.length() );
383
[477]384                bool independent = true;
[286]385                if ( named ) {
[477]386                        for ( int i = 0; i < epdfs.length(); i++ ) {
387                                independent = rv.add ( epdfs ( i )->_rv() );
388                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
[286]389                        }
[477]390                        dim = rv._dsize();
391                } else {
392                        dim = 0;
393                        for ( int i = 0; i < epdfs.length(); i++ ) {
394                                dim += epdfs ( i )->dimension();
[286]395                        }
[204]396                }
[286]397                //
[477]398                int cumdim = 0;
399                int dimi = 0;
[286]400                int i;
[477]401                for ( i = 0; i < epdfs.length(); i++ ) {
[286]402                        dls ( i ) = new datalink;
[477]403                        if ( named ) {
404                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
405                        } else {
[286]406                                dimi = epdfs ( i )->dimension();
[477]407                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
408                                cumdim += dimi;
[286]409                        }
410                }
[168]411        }
412
413        vec mean() const {
[270]414                vec tmp ( dim );
[477]415                for ( int i = 0; i < epdfs.length(); i++ ) {
[168]416                        vec pom = epdfs ( i )->mean();
[270]417                        dls ( i )->pushup ( tmp, pom );
[168]418                }
419                return tmp;
420        }
[229]421        vec variance() const {
[270]422                vec tmp ( dim ); //second moment
[477]423                for ( int i = 0; i < epdfs.length(); i++ ) {
[229]424                        vec pom = epdfs ( i )->mean();
[477]425                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
[229]426                }
[477]427                return tmp - pow ( mean(), 2 );
[229]428        }
[168]429        vec sample() const {
[270]430                vec tmp ( dim );
[477]431                for ( int i = 0; i < epdfs.length(); i++ ) {
[168]432                        vec pom = epdfs ( i )->sample();
[270]433                        dls ( i )->pushup ( tmp, pom );
[168]434                }
435                return tmp;
436        }
[211]437        double evallog ( const vec &val ) const {
[477]438                double tmp = 0;
439                for ( int i = 0; i < epdfs.length(); i++ ) {
440                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
[168]441                }
[477]442                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
[168]443                return tmp;
444        }
[170]445        //!access function
[477]446        const epdf* operator () ( int i ) const {
447                it_assert_debug ( i < epdfs.length(), "wrong index" );
448                return epdfs ( i );
449        }
[204]450
[193]451        //!Destructor
[477]452        ~eprod() {
453                for ( int i = 0; i < epdfs.length(); i++ ) {
454                        delete dls ( i );
455                }
456        }
[168]457};
458
459
[488]460/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
[124]461
462*/
[488]463class mmix : public mpdf {
464protected:
465        //! Component (mpdfs)
466        Array<shared_ptr<mpdf> > Coms;
467        //!weights of the components
468        vec w;
469        //! dummy epdfs
470        epdf dummy_epdf;
471public:
472        //!Default constructor
473        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
[461]474
[488]475        //! Set weights \c w and components \c R
476        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
477                //!\TODO check if all components are OK
478                Coms = Coms0;
479                w=w0;   
[503]480
[513]481                if (Coms0.length()>0){
482                        set_rv(Coms(0)->_rv());
[515]483                        dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
[513]484                        set_rvc(Coms(0)->_rvc());
485                        dimc = rvc._dsize();
486                }
[488]487        }
488        double evallogcond (const vec &dt, const vec &cond) {
489                double ll=0.0;
490                for (int i=0;i<Coms.length();i++){
491                        ll+=Coms(i)->evallogcond(dt,cond);
492                }
493                return ll;
494        }
495
496        vec samplecond (const vec &cond);
497
498};
499
[254]500}
[107]501#endif //MX_H
Note: See TracBrowser for help on using the browser.