root/library/bdm/stat/emix.h @ 716

Revision 716, 13.2 kB (checked in by mido, 15 years ago)

clean up within testsuite - pdf_harness is used even in the case of emix and mmix test

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public pdf {
38protected:
39        //! Nominator in the form of pdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49public:
50        //!Default constructor. By default, the given epdf is not copied!
51        //! It is assumed that this function will be used only temporarily.
52        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
53                // adjust rv and rvc
54
55                set_rv( rv ); // TODO co kdyby tohle samo uz nastavovalo dimension?!?!
56                dim = rv._dsize();
57
58                rvc = nom0->_rv().subt ( rv );
59                dimc = rvc._dsize();
60       
61                //prepare data structures
62                if ( copy ) {
63                        bdm_error ( "todo" );
64                        // destroynom = true;
65                } else {
66                        nom = nom0;
67                        destroynom = false;
68                }
69                bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
70
71                // build denominator
72                den = nom->marginal ( rvc );
73                dl.set_connection ( rv, rvc, nom0->_rv() );
74        };
75
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93
94private:
95        // not implemented
96        mratio ( const mratio & );
97        mratio &operator=( const mratio & );
98};
99
100/*!
101* \brief Mixture of epdfs
102
103Density function:
104\f[
105f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
106\f]
107where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
108
109*/
110class emix : public epdf {
111protected:
112        //! weights of the components
113        vec w;
114
115        //! Component (epdfs)
116        Array<shared_ptr<epdf> > Coms;
117
118public:
119        //! Default constructor
120        emix ( ) : epdf ( ) { }
121
122          /*!
123            \brief Set weights \c w and components \c Coms
124
125            Shared pointers in Coms are kept inside this instance and
126            shouldn't be modified after being passed to this method.
127          */
128        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
129
130        vec sample() const;
131        vec mean() const {
132                int i;
133                vec mu = zeros ( dim );
134                for ( i = 0; i < w.length(); i++ ) {
135                        mu += w ( i ) * Coms ( i )->mean();
136                }
137                return mu;
138        }
139        vec variance() const {
140                //non-central moment
141                vec mom2 = zeros ( dim );
142                for ( int i = 0; i < w.length(); i++ ) {
143                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
144                }
145                //central moment
146                return mom2 - pow ( mean(), 2 );
147        }
148        double evallog ( const vec &val ) const {
149                int i;
150                double sum = 0.0;
151                for ( i = 0; i < w.length(); i++ ) {
152                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
153                }
154                if ( sum == 0.0 ) {
155                        sum = std::numeric_limits<double>::epsilon();
156                }
157                double tmp = log ( sum );
158                bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
159                return tmp;
160        };
161
162        vec evallog_mat ( const mat &Val ) const {
163                vec x = zeros ( Val.cols() );
164                for ( int i = 0; i < w.length(); i++ ) {
165                        x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
166                }
167                return log ( x );
168        };
169
170        //! Auxiliary function that returns pdflog for each component
171        mat evallog_coms ( const mat &Val ) const {
172                mat X ( w.length(), Val.cols() );
173                for ( int i = 0; i < w.length(); i++ ) {
174                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
175                }
176                return X;
177        };
178
179        shared_ptr<epdf> marginal ( const RV &rv ) const;
180        //! Update already existing marginal density  \c target
181        void marginal ( const RV &rv, emix &target ) const;
182        shared_ptr<pdf> condition ( const RV &rv ) const;
183
184        //Access methods
185        //! returns a pointer to the internal mean value. Use with Care!
186        vec& _w() {
187                return w;
188        }
189
190        //!access function
191        shared_ptr<epdf> _Coms ( int i ) {
192                return Coms ( i );
193        }
194
195        void set_rv ( const RV &rv ) {
196                epdf::set_rv ( rv );
197                for ( int i = 0; i < Coms.length(); i++ ) {
198                        Coms ( i )->set_rv ( rv );
199                }
200        }
201
202        //! Load from structure with elements:
203        //!  \code
204        //! { class='emix';
205        //!   pdfs = (..., ...);     // list of pdfs in the mixture
206        //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
207        //! }
208        //! \endcode
209        //!@}
210        void from_setting ( const Setting &set ) {
211       
212                vec w0;                 
213                Array<shared_ptr<epdf> > Coms0;
214       
215                UI::get ( Coms0, set, "pdfs", UI::compulsory );
216
217                if( !UI::get( w0, set, "weights", UI::optional ) )
218                {
219                        int len = Coms.length();
220                        w0.set_length( len );
221                        w0 = 1.0 / len;
222                }
223
224                // TODO asi lze nacitat primocare do w a coms, jen co bude hotovy validate()
225                set_parameters( w0, Coms0 );
226        }
227};
228SHAREDPTR( emix );
229UIREGISTER ( emix );
230
231/*!
232* \brief Mixture of egiws
233
234*/
235class egiwmix : public egiw {
236protected:
237        //! weights of the components
238        vec w;
239        //! Component (epdfs)
240        Array<egiw*> Coms;
241        //!Flag if owning Coms
242        bool destroyComs;
243public:
244        //!Default constructor
245        egiwmix ( ) : egiw ( ) {};
246
247        //! Set weights \c w and components \c Coms
248        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
249        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
250
251        //!return expected value
252        vec mean() const;
253
254        //!return a sample from the density
255        vec sample() const;
256
257        //!return the expected variance
258        vec variance() const;
259
260        // TODO!!! Defined to follow ANSI and/or for future development
261        void mean_mat ( mat &M, mat&R ) const {};
262        double evallog_nn ( const vec &val ) const {
263                return 0;
264        };
265        double lognc () const {
266                return 0;
267        }
268
269        shared_ptr<epdf> marginal ( const RV &rv ) const;
270        //! marginal density update
271        void marginal ( const RV &rv, emix &target ) const;
272
273//Access methods
274        //! returns a pointer to the internal mean value. Use with Care!
275        vec& _w() {
276                return w;
277        }
278        virtual ~egiwmix() {
279                if ( destroyComs ) {
280                        for ( int i = 0; i < Coms.length(); i++ ) {
281                                delete Coms ( i );
282                        }
283                }
284        }
285        //! Auxiliary function for taking ownership of the Coms()
286        void ownComs() {
287                destroyComs = true;
288        }
289
290        //!access function
291        egiw* _Coms ( int i ) {
292                return Coms ( i );
293        }
294
295        void set_rv ( const RV &rv ) {
296                egiw::set_rv ( rv );
297                for ( int i = 0; i < Coms.length(); i++ ) {
298                        Coms ( i )->set_rv ( rv );
299                }
300        }
301
302        //! Approximation of a GiW mix by a single GiW pdf
303        egiw* approx();
304};
305
306/*! \brief Chain rule decomposition of epdf
307
308Probability density in the form of Chain-rule decomposition:
309\[
310f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
311\]
312Note that
313*/
314class mprod: public pdf {
315private:
316        Array<shared_ptr<pdf> > pdfs;
317
318        //! Data link for each pdfs
319        Array<shared_ptr<datalink_m2m> > dls;
320
321protected:
322        //! dummy epdf used only as storage for RV and dim
323        epdf iepdf;
324
325public:
326        //! \brief Default constructor
327        mprod() { }
328
329        /*!\brief Constructor from list of mFacs
330        */
331        mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
332                set_elements ( mFacs );
333        }
334        //! Set internal \c pdfs from given values
335        void set_elements (const Array<shared_ptr<pdf> > &mFacs );
336
337        double evallogcond ( const vec &val, const vec &cond ) {
338                int i;
339                double res = 0.0;
340                for ( i = pdfs.length() - 1; i >= 0; i-- ) {
341                        /*                      if ( pdfs(i)->_rvc().count() >0) {
342                                                        pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
343                                                }
344                                                // add logarithms
345                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
346                        res += pdfs ( i )->evallogcond (
347                                   dls ( i )->pushdown ( val ),
348                                   dls ( i )->get_cond ( val, cond )
349                               );
350                }
351                return res;
352        }
353        vec evallogcond_mat ( const mat &Dt, const vec &cond ) {
354                vec tmp ( Dt.cols() );
355                for ( int i = 0; i < Dt.cols(); i++ ) {
356                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
357                }
358                return tmp;
359        };
360        vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
361                vec tmp ( Dt.length() );
362                for ( int i = 0; i < Dt.length(); i++ ) {
363                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
364                }
365                return tmp;
366        };
367
368
369        //TODO smarter...
370        vec samplecond ( const vec &cond ) {
371                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
372                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
373                vec smpi;
374                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
375                for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
376                        // generate contribution of this pdf
377                        smpi = pdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                 
378                        // copy contribution of this pdf into smp
379                        dls ( i )->pushup ( smp, smpi );
380                }
381                return smp;
382        }
383
384        //! Load from structure with elements:
385        //!  \code
386        //! { class='mprod';
387        //!   pdfs = (..., ...);     // list of pdfs in the order of chain rule
388        //! }
389        //! \endcode
390        //!@}
391        void from_setting ( const Setting &set ) {
392                Array<shared_ptr<pdf> > atmp; //temporary Array
393                UI::get ( atmp, set, "pdfs", UI::compulsory );
394                set_elements ( atmp );
395        }
396};
397UIREGISTER ( mprod );
398SHAREDPTR ( mprod );
399
400//! Product of independent epdfs. For dependent pdfs, use mprod.
401class eprod: public epdf {
402protected:
403        //! Components (epdfs)
404        Array<const epdf*> epdfs;
405        //! Array of indeces
406        Array<datalink*> dls;
407public:
408        //! Default constructor
409        eprod () : epdfs ( 0 ), dls ( 0 ) {};
410        //! Set internal
411        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
412                epdfs = epdfs0;//.set_length ( epdfs0.length() );
413                dls.set_length ( epdfs.length() );
414
415                bool independent = true;
416                if ( named ) {
417                        for ( int i = 0; i < epdfs.length(); i++ ) {
418                                independent = rv.add ( epdfs ( i )->_rv() );
419                                bdm_assert_debug ( independent, "eprod:: given components are not independent." );
420                        }
421                        dim = rv._dsize();
422                } else {
423                        dim = 0;
424                        for ( int i = 0; i < epdfs.length(); i++ ) {
425                                dim += epdfs ( i )->dimension();
426                        }
427                }
428                //
429                int cumdim = 0;
430                int dimi = 0;
431                int i;
432                for ( i = 0; i < epdfs.length(); i++ ) {
433                        dls ( i ) = new datalink;
434                        if ( named ) {
435                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
436                        } else {
437                                dimi = epdfs ( i )->dimension();
438                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
439                                cumdim += dimi;
440                        }
441                }
442        }
443
444        vec mean() const {
445                vec tmp ( dim );
446                for ( int i = 0; i < epdfs.length(); i++ ) {
447                        vec pom = epdfs ( i )->mean();
448                        dls ( i )->pushup ( tmp, pom );
449                }
450                return tmp;
451        }
452        vec variance() const {
453                vec tmp ( dim ); //second moment
454                for ( int i = 0; i < epdfs.length(); i++ ) {
455                        vec pom = epdfs ( i )->mean();
456                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
457                }
458                return tmp - pow ( mean(), 2 );
459        }
460        vec sample() const {
461                vec tmp ( dim );
462                for ( int i = 0; i < epdfs.length(); i++ ) {
463                        vec pom = epdfs ( i )->sample();
464                        dls ( i )->pushup ( tmp, pom );
465                }
466                return tmp;
467        }
468        double evallog ( const vec &val ) const {
469                double tmp = 0;
470                for ( int i = 0; i < epdfs.length(); i++ ) {
471                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
472                }
473                bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
474                return tmp;
475        }
476        //!access function
477        const epdf* operator () ( int i ) const {
478                bdm_assert_debug ( i < epdfs.length(), "wrong index" );
479                return epdfs ( i );
480        }
481
482        //!Destructor
483        ~eprod() {
484                for ( int i = 0; i < epdfs.length(); i++ ) {
485                        delete dls ( i );
486                }
487        }
488};
489
490
491/*! \brief Mixture of pdfs with constant weights, all pdfs are of equal RV and RVC
492
493*/
494class mmix : public pdf {
495protected:
496        //! Component (pdfs)
497        Array<shared_ptr<pdf> > Coms;
498        //!weights of the components
499        vec w;
500public:
501        //!Default constructor
502        mmix() : Coms(0) { }
503
504        //! Set weights \c w and components \c R
505        void set_parameters ( const vec &w0, const Array<shared_ptr<pdf> > &Coms0 ) {
506                //!\todo check if all components are OK
507                Coms = Coms0;
508                w=w0;   
509
510                if (Coms0.length()>0){
511                        set_rv(Coms(0)->_rv());
512                        dim = rv._dsize();
513                        set_rvc(Coms(0)->_rvc());
514                        dimc = rvc._dsize();
515                }
516        }
517        double evallogcond (const vec &dt, const vec &cond) {
518                double ll=0.0;
519                for (int i=0;i<Coms.length();i++){
520                        ll+=Coms(i)->evallogcond(dt,cond);
521                }
522                return ll;
523        }
524
525        vec samplecond (const vec &cond);
526
527        //! Load from structure with elements:
528        //!  \code
529        //! { class='mmix';
530        //!   pdfs = (..., ...);     // list of pdfs in the mixture
531        //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
532        //! }
533        //! \endcode
534        //!@}
535        void from_setting ( const Setting &set ) {
536                UI::get ( Coms, set, "pdfs", UI::compulsory ); 
537
538                // TODO ma byt zde, ci ve validate()?
539                if (Coms.length()>0){
540                        set_rv(Coms(0)->_rv());
541                        dim = rv._dsize();
542                        set_rvc(Coms(0)->_rvc());
543                        dimc = rvc._dsize();
544                }
545
546                if( !UI::get( w, set, "weights", UI::optional ) )
547                {
548                        int len = Coms.length();
549                        w.set_length( len );
550                        w = 1.0 / len;
551                }
552        }
553};
554SHAREDPTR( mmix );
555UIREGISTER ( mmix );
556
557}
558#endif //MX_H
Note: See TracBrowser for help on using the browser.