root/library/bdm/stat/emix.h @ 559

Revision 559, 12.1 kB (checked in by vbarta, 15 years ago)

updated obsolete comment

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93
94private:
95        // not implemented
96        mratio ( const mratio & );
97        mratio &operator=( const mratio & );
98};
99
100/*!
101* \brief Mixture of epdfs
102
103Density function:
104\f[
105f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
106\f]
107where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
108
109*/
110class emix : public epdf {
111protected:
112        //! weights of the components
113        vec w;
114
115        //! Component (epdfs)
116        Array<shared_ptr<epdf> > Coms;
117
118public:
119        //! Default constructor
120        emix ( ) : epdf ( ) { }
121
122          /*!
123            \brief Set weights \c w and components \c Coms
124
125            Shared pointers in Coms are kept inside this instance and
126            shouldn't be modified after being passed to this method.
127          */
128        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
129
130        vec sample() const;
131        vec mean() const {
132                int i;
133                vec mu = zeros ( dim );
134                for ( i = 0; i < w.length(); i++ ) {
135                        mu += w ( i ) * Coms ( i )->mean();
136                }
137                return mu;
138        }
139        vec variance() const {
140                //non-central moment
141                vec mom2 = zeros ( dim );
142                for ( int i = 0; i < w.length(); i++ ) {
143                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
144                }
145                //central moment
146                return mom2 - pow ( mean(), 2 );
147        }
148        double evallog ( const vec &val ) const {
149                int i;
150                double sum = 0.0;
151                for ( i = 0; i < w.length(); i++ ) {
152                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
153                }
154                if ( sum == 0.0 ) {
155                        sum = std::numeric_limits<double>::epsilon();
156                }
157                double tmp = log ( sum );
158                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
159                return tmp;
160        };
161        vec evallog_m ( const mat &Val ) const {
162                vec x = zeros ( Val.cols() );
163                for ( int i = 0; i < w.length(); i++ ) {
164                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
165                }
166                return log ( x );
167        };
168        //! Auxiliary function that returns pdflog for each component
169        mat evallog_M ( const mat &Val ) const {
170                mat X ( w.length(), Val.cols() );
171                for ( int i = 0; i < w.length(); i++ ) {
172                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
173                }
174                return X;
175        };
176
177        shared_ptr<epdf> marginal ( const RV &rv ) const;
178        //! Update already existing marginal density  \c target
179        void marginal ( const RV &rv, emix &target ) const;
180        shared_ptr<mpdf> condition ( const RV &rv ) const;
181
182//Access methods
183        //! returns a pointer to the internal mean value. Use with Care!
184        vec& _w() {
185                return w;
186        }
187
188        //!access function
189        shared_ptr<epdf> _Coms ( int i ) {
190                return Coms ( i );
191        }
192
193        void set_rv ( const RV &rv ) {
194                epdf::set_rv ( rv );
195                for ( int i = 0; i < Coms.length(); i++ ) {
196                        Coms ( i )->set_rv ( rv );
197                }
198        }
199};
200SHAREDPTR( emix );
201
202/*!
203* \brief Mixture of egiws
204
205*/
206class egiwmix : public egiw {
207protected:
208        //! weights of the components
209        vec w;
210        //! Component (epdfs)
211        Array<egiw*> Coms;
212        //!Flag if owning Coms
213        bool destroyComs;
214public:
215        //!Default constructor
216        egiwmix ( ) : egiw ( ) {};
217
218        //! Set weights \c w and components \c Coms
219        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
220        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
221
222        //!return expected value
223        vec mean() const;
224
225        //!return a sample from the density
226        vec sample() const;
227
228        //!return the expected variance
229        vec variance() const;
230
231        // TODO!!! Defined to follow ANSI and/or for future development
232        void mean_mat ( mat &M, mat&R ) const {};
233        double evallog_nn ( const vec &val ) const {
234                return 0;
235        };
236        double lognc () const {
237                return 0;
238        }
239
240        shared_ptr<epdf> marginal ( const RV &rv ) const;
241        void marginal ( const RV &rv, emix &target ) const;
242
243//Access methods
244        //! returns a pointer to the internal mean value. Use with Care!
245        vec& _w() {
246                return w;
247        }
248        virtual ~egiwmix() {
249                if ( destroyComs ) {
250                        for ( int i = 0; i < Coms.length(); i++ ) {
251                                delete Coms ( i );
252                        }
253                }
254        }
255        //! Auxiliary function for taking ownership of the Coms()
256        void ownComs() {
257                destroyComs = true;
258        }
259
260        //!access function
261        egiw* _Coms ( int i ) {
262                return Coms ( i );
263        }
264
265        void set_rv ( const RV &rv ) {
266                egiw::set_rv ( rv );
267                for ( int i = 0; i < Coms.length(); i++ ) {
268                        Coms ( i )->set_rv ( rv );
269                }
270        }
271
272        //! Approximation of a GiW mix by a single GiW pdf
273        egiw* approx();
274};
275
276/*! \brief Chain rule decomposition of epdf
277
278Probability density in the form of Chain-rule decomposition:
279\[
280f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
281\]
282Note that
283*/
284class mprod: public mpdf {
285private:
286        Array<shared_ptr<mpdf> > mpdfs;
287
288        //! Data link for each mpdfs
289        Array<shared_ptr<datalink_m2m> > dls;
290
291protected:
292        //! dummy epdf used only as storage for RV and dim
293        epdf iepdf;
294
295public:
296        //! \brief Default constructor
297        mprod() { }
298
299        /*!\brief Constructor from list of mFacs
300        */
301        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
302                set_elements ( mFacs );
303        }
304        //! Set internal \c mpdfs from given values
305        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
306
307        double evallogcond ( const vec &val, const vec &cond ) {
308                int i;
309                double res = 0.0;
310                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
311                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
312                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
313                                                }
314                                                // add logarithms
315                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
316                        res += mpdfs ( i )->evallogcond (
317                                   dls ( i )->pushdown ( val ),
318                                   dls ( i )->get_cond ( val, cond )
319                               );
320                }
321                return res;
322        }
323        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
324                vec tmp ( Dt.cols() );
325                for ( int i = 0; i < Dt.cols(); i++ ) {
326                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
327                }
328                return tmp;
329        };
330        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
331                vec tmp ( Dt.length() );
332                for ( int i = 0; i < Dt.length(); i++ ) {
333                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
334                }
335                return tmp;
336        };
337
338
339        //TODO smarter...
340        vec samplecond ( const vec &cond ) {
341                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
342                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
343                vec smpi;
344                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
345                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
346                        // generate contribution of this mpdf
347                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
348                        // copy contribution of this pdf into smp
349                        dls ( i )->pushup ( smp, smpi );
350                }
351                return smp;
352        }
353
354        //! Load from structure with elements:
355        //!  \code
356        //! { class='mprod';
357        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
358        //! }
359        //! \endcode
360        //!@}
361        void from_setting ( const Setting &set ) {
362                Array<shared_ptr<mpdf> > atmp; //temporary Array
363                UI::get ( atmp, set, "mpdfs", UI::compulsory );
364                set_elements ( atmp );
365        }
366};
367UIREGISTER ( mprod );
368SHAREDPTR ( mprod );
369
370//! Product of independent epdfs. For dependent pdfs, use mprod.
371class eprod: public epdf {
372protected:
373        //! Components (epdfs)
374        Array<const epdf*> epdfs;
375        //! Array of indeces
376        Array<datalink*> dls;
377public:
378        //! Default constructor
379        eprod () : epdfs ( 0 ), dls ( 0 ) {};
380        //! Set internal
381        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
382                epdfs = epdfs0;//.set_length ( epdfs0.length() );
383                dls.set_length ( epdfs.length() );
384
385                bool independent = true;
386                if ( named ) {
387                        for ( int i = 0; i < epdfs.length(); i++ ) {
388                                independent = rv.add ( epdfs ( i )->_rv() );
389                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
390                        }
391                        dim = rv._dsize();
392                } else {
393                        dim = 0;
394                        for ( int i = 0; i < epdfs.length(); i++ ) {
395                                dim += epdfs ( i )->dimension();
396                        }
397                }
398                //
399                int cumdim = 0;
400                int dimi = 0;
401                int i;
402                for ( i = 0; i < epdfs.length(); i++ ) {
403                        dls ( i ) = new datalink;
404                        if ( named ) {
405                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
406                        } else {
407                                dimi = epdfs ( i )->dimension();
408                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
409                                cumdim += dimi;
410                        }
411                }
412        }
413
414        vec mean() const {
415                vec tmp ( dim );
416                for ( int i = 0; i < epdfs.length(); i++ ) {
417                        vec pom = epdfs ( i )->mean();
418                        dls ( i )->pushup ( tmp, pom );
419                }
420                return tmp;
421        }
422        vec variance() const {
423                vec tmp ( dim ); //second moment
424                for ( int i = 0; i < epdfs.length(); i++ ) {
425                        vec pom = epdfs ( i )->mean();
426                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
427                }
428                return tmp - pow ( mean(), 2 );
429        }
430        vec sample() const {
431                vec tmp ( dim );
432                for ( int i = 0; i < epdfs.length(); i++ ) {
433                        vec pom = epdfs ( i )->sample();
434                        dls ( i )->pushup ( tmp, pom );
435                }
436                return tmp;
437        }
438        double evallog ( const vec &val ) const {
439                double tmp = 0;
440                for ( int i = 0; i < epdfs.length(); i++ ) {
441                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
442                }
443                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
444                return tmp;
445        }
446        //!access function
447        const epdf* operator () ( int i ) const {
448                it_assert_debug ( i < epdfs.length(), "wrong index" );
449                return epdfs ( i );
450        }
451
452        //!Destructor
453        ~eprod() {
454                for ( int i = 0; i < epdfs.length(); i++ ) {
455                        delete dls ( i );
456                }
457        }
458};
459
460
461/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
462
463*/
464class mmix : public mpdf {
465protected:
466        //! Component (mpdfs)
467        Array<shared_ptr<mpdf> > Coms;
468        //!weights of the components
469        vec w;
470        //! dummy epdfs
471        epdf dummy_epdf;
472public:
473        //!Default constructor
474        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
475
476        //! Set weights \c w and components \c R
477        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
478                //!\todo check if all components are OK
479                Coms = Coms0;
480                w=w0;   
481
482                if (Coms0.length()>0){
483                        set_rv(Coms(0)->_rv());
484                        dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
485                        set_rvc(Coms(0)->_rvc());
486                        dimc = rvc._dsize();
487                }
488        }
489        double evallogcond (const vec &dt, const vec &cond) {
490                double ll=0.0;
491                for (int i=0;i<Coms.length();i++){
492                        ll+=Coms(i)->evallogcond(dt,cond);
493                }
494                return ll;
495        }
496
497        vec samplecond (const vec &cond);
498
499};
500
501}
502#endif //MX_H
Note: See TracBrowser for help on using the browser.