root/library/bdm/stat/emix.h @ 660

Revision 660, 12.1 kB (checked in by smidl, 15 years ago)

doc - doxygen warnings

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        bdm_error ( "todo" );
65                        // destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93
94private:
95        // not implemented
96        mratio ( const mratio & );
97        mratio &operator=( const mratio & );
98};
99
100/*!
101* \brief Mixture of epdfs
102
103Density function:
104\f[
105f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
106\f]
107where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
108
109*/
110class emix : public epdf {
111protected:
112        //! weights of the components
113        vec w;
114
115        //! Component (epdfs)
116        Array<shared_ptr<epdf> > Coms;
117
118public:
119        //! Default constructor
120        emix ( ) : epdf ( ) { }
121
122          /*!
123            \brief Set weights \c w and components \c Coms
124
125            Shared pointers in Coms are kept inside this instance and
126            shouldn't be modified after being passed to this method.
127          */
128        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
129
130        vec sample() const;
131        vec mean() const {
132                int i;
133                vec mu = zeros ( dim );
134                for ( i = 0; i < w.length(); i++ ) {
135                        mu += w ( i ) * Coms ( i )->mean();
136                }
137                return mu;
138        }
139        vec variance() const {
140                //non-central moment
141                vec mom2 = zeros ( dim );
142                for ( int i = 0; i < w.length(); i++ ) {
143                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
144                }
145                //central moment
146                return mom2 - pow ( mean(), 2 );
147        }
148        double evallog ( const vec &val ) const {
149                int i;
150                double sum = 0.0;
151                for ( i = 0; i < w.length(); i++ ) {
152                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
153                }
154                if ( sum == 0.0 ) {
155                        sum = std::numeric_limits<double>::epsilon();
156                }
157                double tmp = log ( sum );
158                bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
159                return tmp;
160        };
161        vec evallog_m ( const mat &Val ) const {
162                vec x = zeros ( Val.cols() );
163                for ( int i = 0; i < w.length(); i++ ) {
164                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
165                }
166                return log ( x );
167        };
168        //! Auxiliary function that returns pdflog for each component
169        mat evallog_M ( const mat &Val ) const {
170                mat X ( w.length(), Val.cols() );
171                for ( int i = 0; i < w.length(); i++ ) {
172                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
173                }
174                return X;
175        };
176
177        shared_ptr<epdf> marginal ( const RV &rv ) const;
178        //! Update already existing marginal density  \c target
179        void marginal ( const RV &rv, emix &target ) const;
180        shared_ptr<mpdf> condition ( const RV &rv ) const;
181
182//Access methods
183        //! returns a pointer to the internal mean value. Use with Care!
184        vec& _w() {
185                return w;
186        }
187
188        //!access function
189        shared_ptr<epdf> _Coms ( int i ) {
190                return Coms ( i );
191        }
192
193        void set_rv ( const RV &rv ) {
194                epdf::set_rv ( rv );
195                for ( int i = 0; i < Coms.length(); i++ ) {
196                        Coms ( i )->set_rv ( rv );
197                }
198        }
199};
200SHAREDPTR( emix );
201
202/*!
203* \brief Mixture of egiws
204
205*/
206class egiwmix : public egiw {
207protected:
208        //! weights of the components
209        vec w;
210        //! Component (epdfs)
211        Array<egiw*> Coms;
212        //!Flag if owning Coms
213        bool destroyComs;
214public:
215        //!Default constructor
216        egiwmix ( ) : egiw ( ) {};
217
218        //! Set weights \c w and components \c Coms
219        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
220        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
221
222        //!return expected value
223        vec mean() const;
224
225        //!return a sample from the density
226        vec sample() const;
227
228        //!return the expected variance
229        vec variance() const;
230
231        // TODO!!! Defined to follow ANSI and/or for future development
232        void mean_mat ( mat &M, mat&R ) const {};
233        double evallog_nn ( const vec &val ) const {
234                return 0;
235        };
236        double lognc () const {
237                return 0;
238        }
239
240        shared_ptr<epdf> marginal ( const RV &rv ) const;
241        //! marginal density update
242        void marginal ( const RV &rv, emix &target ) const;
243
244//Access methods
245        //! returns a pointer to the internal mean value. Use with Care!
246        vec& _w() {
247                return w;
248        }
249        virtual ~egiwmix() {
250                if ( destroyComs ) {
251                        for ( int i = 0; i < Coms.length(); i++ ) {
252                                delete Coms ( i );
253                        }
254                }
255        }
256        //! Auxiliary function for taking ownership of the Coms()
257        void ownComs() {
258                destroyComs = true;
259        }
260
261        //!access function
262        egiw* _Coms ( int i ) {
263                return Coms ( i );
264        }
265
266        void set_rv ( const RV &rv ) {
267                egiw::set_rv ( rv );
268                for ( int i = 0; i < Coms.length(); i++ ) {
269                        Coms ( i )->set_rv ( rv );
270                }
271        }
272
273        //! Approximation of a GiW mix by a single GiW pdf
274        egiw* approx();
275};
276
277/*! \brief Chain rule decomposition of epdf
278
279Probability density in the form of Chain-rule decomposition:
280\[
281f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
282\]
283Note that
284*/
285class mprod: public mpdf {
286private:
287        Array<shared_ptr<mpdf> > mpdfs;
288
289        //! Data link for each mpdfs
290        Array<shared_ptr<datalink_m2m> > dls;
291
292protected:
293        //! dummy epdf used only as storage for RV and dim
294        epdf iepdf;
295
296public:
297        //! \brief Default constructor
298        mprod() { }
299
300        /*!\brief Constructor from list of mFacs
301        */
302        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
303                set_elements ( mFacs );
304        }
305        //! Set internal \c mpdfs from given values
306        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
307
308        double evallogcond ( const vec &val, const vec &cond ) {
309                int i;
310                double res = 0.0;
311                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
312                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
313                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
314                                                }
315                                                // add logarithms
316                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
317                        res += mpdfs ( i )->evallogcond (
318                                   dls ( i )->pushdown ( val ),
319                                   dls ( i )->get_cond ( val, cond )
320                               );
321                }
322                return res;
323        }
324        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
325                vec tmp ( Dt.cols() );
326                for ( int i = 0; i < Dt.cols(); i++ ) {
327                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
328                }
329                return tmp;
330        };
331        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
332                vec tmp ( Dt.length() );
333                for ( int i = 0; i < Dt.length(); i++ ) {
334                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
335                }
336                return tmp;
337        };
338
339
340        //TODO smarter...
341        vec samplecond ( const vec &cond ) {
342                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
343                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
344                vec smpi;
345                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
346                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
347                        // generate contribution of this mpdf
348                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
349                        // copy contribution of this pdf into smp
350                        dls ( i )->pushup ( smp, smpi );
351                }
352                return smp;
353        }
354
355        //! Load from structure with elements:
356        //!  \code
357        //! { class='mprod';
358        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
359        //! }
360        //! \endcode
361        //!@}
362        void from_setting ( const Setting &set ) {
363                Array<shared_ptr<mpdf> > atmp; //temporary Array
364                UI::get ( atmp, set, "mpdfs", UI::compulsory );
365                set_elements ( atmp );
366        }
367};
368UIREGISTER ( mprod );
369SHAREDPTR ( mprod );
370
371//! Product of independent epdfs. For dependent pdfs, use mprod.
372class eprod: public epdf {
373protected:
374        //! Components (epdfs)
375        Array<const epdf*> epdfs;
376        //! Array of indeces
377        Array<datalink*> dls;
378public:
379        //! Default constructor
380        eprod () : epdfs ( 0 ), dls ( 0 ) {};
381        //! Set internal
382        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
383                epdfs = epdfs0;//.set_length ( epdfs0.length() );
384                dls.set_length ( epdfs.length() );
385
386                bool independent = true;
387                if ( named ) {
388                        for ( int i = 0; i < epdfs.length(); i++ ) {
389                                independent = rv.add ( epdfs ( i )->_rv() );
390                                bdm_assert_debug ( independent, "eprod:: given components are not independent." );
391                        }
392                        dim = rv._dsize();
393                } else {
394                        dim = 0;
395                        for ( int i = 0; i < epdfs.length(); i++ ) {
396                                dim += epdfs ( i )->dimension();
397                        }
398                }
399                //
400                int cumdim = 0;
401                int dimi = 0;
402                int i;
403                for ( i = 0; i < epdfs.length(); i++ ) {
404                        dls ( i ) = new datalink;
405                        if ( named ) {
406                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
407                        } else {
408                                dimi = epdfs ( i )->dimension();
409                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
410                                cumdim += dimi;
411                        }
412                }
413        }
414
415        vec mean() const {
416                vec tmp ( dim );
417                for ( int i = 0; i < epdfs.length(); i++ ) {
418                        vec pom = epdfs ( i )->mean();
419                        dls ( i )->pushup ( tmp, pom );
420                }
421                return tmp;
422        }
423        vec variance() const {
424                vec tmp ( dim ); //second moment
425                for ( int i = 0; i < epdfs.length(); i++ ) {
426                        vec pom = epdfs ( i )->mean();
427                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
428                }
429                return tmp - pow ( mean(), 2 );
430        }
431        vec sample() const {
432                vec tmp ( dim );
433                for ( int i = 0; i < epdfs.length(); i++ ) {
434                        vec pom = epdfs ( i )->sample();
435                        dls ( i )->pushup ( tmp, pom );
436                }
437                return tmp;
438        }
439        double evallog ( const vec &val ) const {
440                double tmp = 0;
441                for ( int i = 0; i < epdfs.length(); i++ ) {
442                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
443                }
444                bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
445                return tmp;
446        }
447        //!access function
448        const epdf* operator () ( int i ) const {
449                bdm_assert_debug ( i < epdfs.length(), "wrong index" );
450                return epdfs ( i );
451        }
452
453        //!Destructor
454        ~eprod() {
455                for ( int i = 0; i < epdfs.length(); i++ ) {
456                        delete dls ( i );
457                }
458        }
459};
460
461
462/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
463
464*/
465class mmix : public mpdf {
466protected:
467        //! Component (mpdfs)
468        Array<shared_ptr<mpdf> > Coms;
469        //!weights of the components
470        vec w;
471        //! dummy epdfs
472        epdf dummy_epdf;
473public:
474        //!Default constructor
475        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
476
477        //! Set weights \c w and components \c R
478        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
479                //!\todo check if all components are OK
480                Coms = Coms0;
481                w=w0;   
482
483                if (Coms0.length()>0){
484                        set_rv(Coms(0)->_rv());
485                        dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
486                        set_rvc(Coms(0)->_rvc());
487                        dimc = rvc._dsize();
488                }
489        }
490        double evallogcond (const vec &dt, const vec &cond) {
491                double ll=0.0;
492                for (int i=0;i<Coms.length();i++){
493                        ll+=Coms(i)->evallogcond(dt,cond);
494                }
495                return ll;
496        }
497
498        vec samplecond (const vec &cond);
499
500};
501
502}
503#endif //MX_H
Note: See TracBrowser for help on using the browser.