root/library/bdm/stat/emix.h @ 737

Revision 737, 13.3 kB (checked in by mido, 15 years ago)

ASTYLER RUN OVER THE WHOLE LIBRARY, JUPEE

  • Property svn:eol-style set to native
RevLine 
[107]1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[394]13#ifndef EMIX_H
14#define EMIX_H
[107]15
[477]16#define LOG2  0.69314718055995
[333]17
[461]18#include "../shared_ptr.h"
[384]19#include "exp_family.h"
[107]20
[286]21namespace bdm {
[107]22
[182]23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
[211]35At present the only supported operation is evallogcond().
[182]36 */
[693]37class mratio: public pdf {
[192]38protected:
[693]39        //! Nominator in the form of pdf
[192]40        const epdf* nom;
[504]41
[182]42        //!Denominator in the form of epdf
[504]43        shared_ptr<epdf> den;
44
[192]45        //!flag for destructor
46        bool destroynom;
[193]47        //!datalink between conditional and nom
48        datalink_m2e dl;
[192]49public:
50        //!Default constructor. By default, the given epdf is not copied!
[182]51        //! It is assumed that this function will be used only temporarily.
[693]52        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
[286]53                // adjust rv and rvc
[675]54
[737]55                set_rv ( rv ); // TODO co kdyby tohle samo uz nastavovalo dimension?!?!
[675]56                dim = rv._dsize();
57
[286]58                rvc = nom0->_rv().subt ( rv );
59                dimc = rvc._dsize();
[737]60
[286]61                //prepare data structures
[477]62                if ( copy ) {
[565]63                        bdm_error ( "todo" );
64                        // destroynom = true;
[477]65                } else {
66                        nom = nom0;
67                        destroynom = false;
68                }
[565]69                bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
[477]70
[286]71                // build denominator
[192]72                den = nom->marginal ( rvc );
[477]73                dl.set_connection ( rv, rvc, nom0->_rv() );
[192]74        };
[675]75
[211]76        double evallogcond ( const vec &val, const vec &cond ) {
[214]77                double tmp;
[487]78                vec nom_val ( dimension() + dimc );
[477]79                dl.pushup_cond ( nom_val, val, cond );
[214]80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
[192]82        }
[182]83        //! Object takes ownership of nom and will destroy it
[477]84        void ownnom() {
85                destroynom = true;
86        }
[182]87        //! Default destructor
[477]88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
[550]93
94private:
95        // not implemented
96        mratio ( const mratio & );
[737]97        mratio &operator= ( const mratio & );
[182]98};
99
[107]100/*!
101* \brief Mixture of epdfs
102
103Density function:
104\f[
105f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
106\f]
107where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
108
109*/
[145]110class emix : public epdf {
[162]111protected:
112        //! weights of the components
113        vec w;
[559]114
[162]115        //! Component (epdfs)
[504]116        Array<shared_ptr<epdf> > Coms;
117
[162]118public:
[559]119        //! Default constructor
120        emix ( ) : epdf ( ) { }
121
[737]122        /*!
123          \brief Set weights \c w and components \c Coms
[559]124
[737]125          Shared pointers in Coms are kept inside this instance and
126          shouldn't be modified after being passed to this method.
127        */
[504]128        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
[107]129
[162]130        vec sample() const;
131        vec mean() const {
[477]132                int i;
133                vec mu = zeros ( dim );
134                for ( i = 0; i < w.length(); i++ ) {
135                        mu += w ( i ) * Coms ( i )->mean();
136                }
[162]137                return mu;
138        }
[229]139        vec variance() const {
140                //non-central moment
[286]141                vec mom2 = zeros ( dim );
[477]142                for ( int i = 0; i < w.length(); i++ ) {
143                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
144                }
[229]145                //central moment
[477]146                return mom2 - pow ( mean(), 2 );
[229]147        }
[211]148        double evallog ( const vec &val ) const {
[162]149                int i;
150                double sum = 0.0;
[477]151                for ( i = 0; i < w.length(); i++ ) {
152                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
153                }
154                if ( sum == 0.0 ) {
155                        sum = std::numeric_limits<double>::epsilon();
156                }
157                double tmp = log ( sum );
[565]158                bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
[214]159                return tmp;
[162]160        };
[715]161
[713]162        vec evallog_mat ( const mat &Val ) const {
[477]163                vec x = zeros ( Val.cols() );
[192]164                for ( int i = 0; i < w.length(); i++ ) {
[713]165                        x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
[182]166                }
[192]167                return log ( x );
[182]168        };
[715]169
[214]170        //! Auxiliary function that returns pdflog for each component
[715]171        mat evallog_coms ( const mat &Val ) const {
[192]172                mat X ( w.length(), Val.cols() );
173                for ( int i = 0; i < w.length(); i++ ) {
[713]174                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
[189]175                }
176                return X;
177        };
[107]178
[504]179        shared_ptr<epdf> marginal ( const RV &rv ) const;
[536]180        //! Update already existing marginal density  \c target
[504]181        void marginal ( const RV &rv, emix &target ) const;
[693]182        shared_ptr<pdf> condition ( const RV &rv ) const;
[182]183
[713]184        //Access methods
[162]185        //! returns a pointer to the internal mean value. Use with Care!
[477]186        vec& _w() {
187                return w;
188        }
[204]189
[193]190        //!access function
[504]191        shared_ptr<epdf> _Coms ( int i ) {
[477]192                return Coms ( i );
[286]193        }
[504]194
[477]195        void set_rv ( const RV &rv ) {
196                epdf::set_rv ( rv );
197                for ( int i = 0; i < Coms.length(); i++ ) {
198                        Coms ( i )->set_rv ( rv );
199                }
200        }
[713]201
202        //! Load from structure with elements:
203        //!  \code
204        //! { class='emix';
205        //!   pdfs = (..., ...);     // list of pdfs in the mixture
206        //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
207        //! }
208        //! \endcode
209        //!@}
210        void from_setting ( const Setting &set ) {
[737]211
212                vec w0;
[716]213                Array<shared_ptr<epdf> > Coms0;
[737]214
[716]215                UI::get ( Coms0, set, "pdfs", UI::compulsory );
[713]216
[737]217                if ( !UI::get ( w0, set, "weights", UI::optional ) ) {
[713]218                        int len = Coms.length();
[737]219                        w0.set_length ( len );
[716]220                        w0 = 1.0 / len;
[713]221                }
[716]222
223                // TODO asi lze nacitat primocare do w a coms, jen co bude hotovy validate()
[737]224                set_parameters ( w0, Coms0 );
[713]225        }
[107]226};
[737]227SHAREDPTR ( emix );
[713]228UIREGISTER ( emix );
[107]229
[333]230/*!
231* \brief Mixture of egiws
232
233*/
234class egiwmix : public egiw {
235protected:
236        //! weights of the components
237        vec w;
238        //! Component (epdfs)
239        Array<egiw*> Coms;
240        //!Flag if owning Coms
241        bool destroyComs;
242public:
243        //!Default constructor
244        egiwmix ( ) : egiw ( ) {};
245
246        //! Set weights \c w and components \c Coms
247        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
[477]248        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
[333]249
250        //!return expected value
251        vec mean() const;
252
253        //!return a sample from the density
254        vec sample() const;
255
256        //!return the expected variance
[477]257        vec variance() const;
[333]258
259        // TODO!!! Defined to follow ANSI and/or for future development
260        void mean_mat ( mat &M, mat&R ) const {};
[477]261        double evallog_nn ( const vec &val ) const {
262                return 0;
263        };
264        double lognc () const {
265                return 0;
[504]266        }
[333]267
[504]268        shared_ptr<epdf> marginal ( const RV &rv ) const;
[660]269        //! marginal density update
[504]270        void marginal ( const RV &rv, emix &target ) const;
271
[333]272//Access methods
273        //! returns a pointer to the internal mean value. Use with Care!
[477]274        vec& _w() {
275                return w;
276        }
277        virtual ~egiwmix() {
278                if ( destroyComs ) {
279                        for ( int i = 0; i < Coms.length(); i++ ) {
280                                delete Coms ( i );
281                        }
282                }
283        }
[333]284        //! Auxiliary function for taking ownership of the Coms()
[477]285        void ownComs() {
286                destroyComs = true;
287        }
[333]288
289        //!access function
[477]290        egiw* _Coms ( int i ) {
291                return Coms ( i );
292        }
[333]293
[477]294        void set_rv ( const RV &rv ) {
295                egiw::set_rv ( rv );
296                for ( int i = 0; i < Coms.length(); i++ ) {
297                        Coms ( i )->set_rv ( rv );
298                }
[333]299        }
300
301        //! Approximation of a GiW mix by a single GiW pdf
302        egiw* approx();
303};
304
[115]305/*! \brief Chain rule decomposition of epdf
306
[145]307Probability density in the form of Chain-rule decomposition:
308\[
309f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
310\]
311Note that
[115]312*/
[693]313class mprod: public pdf {
[507]314private:
[693]315        Array<shared_ptr<pdf> > pdfs;
[507]316
[693]317        //! Data link for each pdfs
[546]318        Array<shared_ptr<datalink_m2m> > dls;
[461]319
[546]320protected:
[487]321        //! dummy epdf used only as storage for RV and dim
322        epdf iepdf;
[461]323
[162]324public:
[507]325        //! \brief Default constructor
326        mprod() { }
327
328        /*!\brief Constructor from list of mFacs
[165]329        */
[693]330        mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
[477]331                set_elements ( mFacs );
[461]332        }
[693]333        //! Set internal \c pdfs from given values
[737]334        void set_elements ( const Array<shared_ptr<pdf> > &mFacs );
[477]335
[211]336        double evallogcond ( const vec &val, const vec &cond ) {
[162]337                int i;
[270]338                double res = 0.0;
[693]339                for ( i = pdfs.length() - 1; i >= 0; i-- ) {
340                        /*                      if ( pdfs(i)->_rvc().count() >0) {
341                                                        pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
[193]342                                                }
343                                                // add logarithms
[270]344                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
[693]345                        res += pdfs ( i )->evallogcond (
[270]346                                   dls ( i )->pushdown ( val ),
[204]347                                   dls ( i )->get_cond ( val, cond )
348                               );
[145]349                }
[193]350                return res;
[162]351        }
[713]352        vec evallogcond_mat ( const mat &Dt, const vec &cond ) {
[477]353                vec tmp ( Dt.cols() );
354                for ( int i = 0; i < Dt.cols(); i++ ) {
355                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
[395]356                }
357                return tmp;
358        };
[713]359        vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
[477]360                vec tmp ( Dt.length() );
361                for ( int i = 0; i < Dt.length(); i++ ) {
362                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
[395]363                }
[477]364                return tmp;
[395]365        };
366
367
[270]368        //TODO smarter...
369        vec samplecond ( const vec &cond ) {
[477]370                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
[487]371                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
[165]372                vec smpi;
[192]373                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
[693]374                for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
375                        // generate contribution of this pdf
[737]376                        smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) );
[162]377                        // copy contribution of this pdf into smp
[270]378                        dls ( i )->pushup ( smp, smpi );
[145]379                }
[162]380                return smp;
381        }
382
[395]383        //! Load from structure with elements:
384        //!  \code
385        //! { class='mprod';
[693]386        //!   pdfs = (..., ...);     // list of pdfs in the order of chain rule
[395]387        //! }
388        //! \endcode
389        //!@}
[477]390        void from_setting ( const Setting &set ) {
[693]391                Array<shared_ptr<pdf> > atmp; //temporary Array
392                UI::get ( atmp, set, "pdfs", UI::compulsory );
[527]393                set_elements ( atmp );
[395]394        }
[107]395};
[477]396UIREGISTER ( mprod );
[529]397SHAREDPTR ( mprod );
[107]398
[168]399//! Product of independent epdfs. For dependent pdfs, use mprod.
400class eprod: public epdf {
401protected:
402        //! Components (epdfs)
[170]403        Array<const epdf*> epdfs;
[168]404        //! Array of indeces
[270]405        Array<datalink*> dls;
[168]406public:
[536]407        //! Default constructor
[477]408        eprod () : epdfs ( 0 ), dls ( 0 ) {};
[737]409        //! Set internal
[477]410        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
411                epdfs = epdfs0;//.set_length ( epdfs0.length() );
[286]412                dls.set_length ( epdfs.length() );
413
[477]414                bool independent = true;
[286]415                if ( named ) {
[477]416                        for ( int i = 0; i < epdfs.length(); i++ ) {
417                                independent = rv.add ( epdfs ( i )->_rv() );
[565]418                                bdm_assert_debug ( independent, "eprod:: given components are not independent." );
[286]419                        }
[477]420                        dim = rv._dsize();
421                } else {
422                        dim = 0;
423                        for ( int i = 0; i < epdfs.length(); i++ ) {
424                                dim += epdfs ( i )->dimension();
[286]425                        }
[204]426                }
[286]427                //
[477]428                int cumdim = 0;
429                int dimi = 0;
[286]430                int i;
[477]431                for ( i = 0; i < epdfs.length(); i++ ) {
[286]432                        dls ( i ) = new datalink;
[477]433                        if ( named ) {
434                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
435                        } else {
[286]436                                dimi = epdfs ( i )->dimension();
[477]437                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
438                                cumdim += dimi;
[286]439                        }
440                }
[168]441        }
442
443        vec mean() const {
[270]444                vec tmp ( dim );
[477]445                for ( int i = 0; i < epdfs.length(); i++ ) {
[168]446                        vec pom = epdfs ( i )->mean();
[270]447                        dls ( i )->pushup ( tmp, pom );
[168]448                }
449                return tmp;
450        }
[229]451        vec variance() const {
[270]452                vec tmp ( dim ); //second moment
[477]453                for ( int i = 0; i < epdfs.length(); i++ ) {
[229]454                        vec pom = epdfs ( i )->mean();
[477]455                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
[229]456                }
[477]457                return tmp - pow ( mean(), 2 );
[229]458        }
[168]459        vec sample() const {
[270]460                vec tmp ( dim );
[477]461                for ( int i = 0; i < epdfs.length(); i++ ) {
[168]462                        vec pom = epdfs ( i )->sample();
[270]463                        dls ( i )->pushup ( tmp, pom );
[168]464                }
465                return tmp;
466        }
[211]467        double evallog ( const vec &val ) const {
[477]468                double tmp = 0;
469                for ( int i = 0; i < epdfs.length(); i++ ) {
470                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
[168]471                }
[565]472                bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
[168]473                return tmp;
474        }
[170]475        //!access function
[477]476        const epdf* operator () ( int i ) const {
[565]477                bdm_assert_debug ( i < epdfs.length(), "wrong index" );
[477]478                return epdfs ( i );
479        }
[204]480
[193]481        //!Destructor
[477]482        ~eprod() {
483                for ( int i = 0; i < epdfs.length(); i++ ) {
484                        delete dls ( i );
485                }
486        }
[168]487};
488
489
[693]490/*! \brief Mixture of pdfs with constant weights, all pdfs are of equal RV and RVC
[124]491
492*/
[693]493class mmix : public pdf {
[488]494protected:
[693]495        //! Component (pdfs)
496        Array<shared_ptr<pdf> > Coms;
[488]497        //!weights of the components
498        vec w;
499public:
500        //!Default constructor
[737]501        mmix() : Coms ( 0 ) { }
[461]502
[488]503        //! Set weights \c w and components \c R
[693]504        void set_parameters ( const vec &w0, const Array<shared_ptr<pdf> > &Coms0 ) {
[536]505                //!\todo check if all components are OK
[488]506                Coms = Coms0;
[737]507                w = w0;
[503]508
[737]509                if ( Coms0.length() > 0 ) {
510                        set_rv ( Coms ( 0 )->_rv() );
[675]511                        dim = rv._dsize();
[737]512                        set_rvc ( Coms ( 0 )->_rvc() );
[513]513                        dimc = rvc._dsize();
514                }
[488]515        }
[737]516        double evallogcond ( const vec &dt, const vec &cond ) {
517                double ll = 0.0;
518                for ( int i = 0; i < Coms.length(); i++ ) {
519                        ll += Coms ( i )->evallogcond ( dt, cond );
[488]520                }
521                return ll;
522        }
523
[737]524        vec samplecond ( const vec &cond );
[488]525
[711]526        //! Load from structure with elements:
527        //!  \code
528        //! { class='mmix';
529        //!   pdfs = (..., ...);     // list of pdfs in the mixture
530        //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
531        //! }
532        //! \endcode
533        //!@}
534        void from_setting ( const Setting &set ) {
[737]535                UI::get ( Coms, set, "pdfs", UI::compulsory );
[711]536
[716]537                // TODO ma byt zde, ci ve validate()?
[737]538                if ( Coms.length() > 0 ) {
539                        set_rv ( Coms ( 0 )->_rv() );
[716]540                        dim = rv._dsize();
[737]541                        set_rvc ( Coms ( 0 )->_rvc() );
[716]542                        dimc = rvc._dsize();
543                }
544
[737]545                if ( !UI::get ( w, set, "weights", UI::optional ) ) {
[711]546                        int len = Coms.length();
[737]547                        w.set_length ( len );
[715]548                        w = 1.0 / len;
[711]549                }
550        }
[488]551};
[737]552SHAREDPTR ( mmix );
[711]553UIREGISTER ( mmix );
[488]554
[254]555}
[107]556#endif //MX_H
Note: See TracBrowser for help on using the browser.