root/library/bdm/stat/emix.h @ 550

Revision 550, 12.1 kB (checked in by vbarta, 15 years ago)

disabled implicit copy constructor & assignement operator of mratio

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93
94private:
95        // not implemented
96        mratio ( const mratio & );
97        mratio &operator=( const mratio & );
98};
99
100/*!
101* \brief Mixture of epdfs
102
103Density function:
104\f[
105f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
106\f]
107where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
108
109*/
110class emix : public epdf {
111protected:
112        //! weights of the components
113        vec w;
114        //! Component (epdfs)
115        Array<shared_ptr<epdf> > Coms;
116
117public:
118        //!Default constructor
119        emix ( ) : epdf ( ) {};
120        //! Set weights \c w and components \c Coms
121        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
122        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
123
124        vec sample() const;
125        vec mean() const {
126                int i;
127                vec mu = zeros ( dim );
128                for ( i = 0; i < w.length(); i++ ) {
129                        mu += w ( i ) * Coms ( i )->mean();
130                }
131                return mu;
132        }
133        vec variance() const {
134                //non-central moment
135                vec mom2 = zeros ( dim );
136                for ( int i = 0; i < w.length(); i++ ) {
137                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
138                }
139                //central moment
140                return mom2 - pow ( mean(), 2 );
141        }
142        double evallog ( const vec &val ) const {
143                int i;
144                double sum = 0.0;
145                for ( i = 0; i < w.length(); i++ ) {
146                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
147                }
148                if ( sum == 0.0 ) {
149                        sum = std::numeric_limits<double>::epsilon();
150                }
151                double tmp = log ( sum );
152                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
153                return tmp;
154        };
155        vec evallog_m ( const mat &Val ) const {
156                vec x = zeros ( Val.cols() );
157                for ( int i = 0; i < w.length(); i++ ) {
158                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
159                }
160                return log ( x );
161        };
162        //! Auxiliary function that returns pdflog for each component
163        mat evallog_M ( const mat &Val ) const {
164                mat X ( w.length(), Val.cols() );
165                for ( int i = 0; i < w.length(); i++ ) {
166                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
167                }
168                return X;
169        };
170
171        shared_ptr<epdf> marginal ( const RV &rv ) const;
172        //! Update already existing marginal density  \c target
173        void marginal ( const RV &rv, emix &target ) const;
174        shared_ptr<mpdf> condition ( const RV &rv ) const;
175
176//Access methods
177        //! returns a pointer to the internal mean value. Use with Care!
178        vec& _w() {
179                return w;
180        }
181
182        //!access function
183        shared_ptr<epdf> _Coms ( int i ) {
184                return Coms ( i );
185        }
186
187        void set_rv ( const RV &rv ) {
188                epdf::set_rv ( rv );
189                for ( int i = 0; i < Coms.length(); i++ ) {
190                        Coms ( i )->set_rv ( rv );
191                }
192        }
193};
194SHAREDPTR( emix );
195
196/*!
197* \brief Mixture of egiws
198
199*/
200class egiwmix : public egiw {
201protected:
202        //! weights of the components
203        vec w;
204        //! Component (epdfs)
205        Array<egiw*> Coms;
206        //!Flag if owning Coms
207        bool destroyComs;
208public:
209        //!Default constructor
210        egiwmix ( ) : egiw ( ) {};
211
212        //! Set weights \c w and components \c Coms
213        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
214        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
215
216        //!return expected value
217        vec mean() const;
218
219        //!return a sample from the density
220        vec sample() const;
221
222        //!return the expected variance
223        vec variance() const;
224
225        // TODO!!! Defined to follow ANSI and/or for future development
226        void mean_mat ( mat &M, mat&R ) const {};
227        double evallog_nn ( const vec &val ) const {
228                return 0;
229        };
230        double lognc () const {
231                return 0;
232        }
233
234        shared_ptr<epdf> marginal ( const RV &rv ) const;
235        void marginal ( const RV &rv, emix &target ) const;
236
237//Access methods
238        //! returns a pointer to the internal mean value. Use with Care!
239        vec& _w() {
240                return w;
241        }
242        virtual ~egiwmix() {
243                if ( destroyComs ) {
244                        for ( int i = 0; i < Coms.length(); i++ ) {
245                                delete Coms ( i );
246                        }
247                }
248        }
249        //! Auxiliary function for taking ownership of the Coms()
250        void ownComs() {
251                destroyComs = true;
252        }
253
254        //!access function
255        egiw* _Coms ( int i ) {
256                return Coms ( i );
257        }
258
259        void set_rv ( const RV &rv ) {
260                egiw::set_rv ( rv );
261                for ( int i = 0; i < Coms.length(); i++ ) {
262                        Coms ( i )->set_rv ( rv );
263                }
264        }
265
266        //! Approximation of a GiW mix by a single GiW pdf
267        egiw* approx();
268};
269
270/*! \brief Chain rule decomposition of epdf
271
272Probability density in the form of Chain-rule decomposition:
273\[
274f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
275\]
276Note that
277*/
278class mprod: public mpdf {
279private:
280        Array<shared_ptr<mpdf> > mpdfs;
281
282        //! Data link for each mpdfs
283        Array<shared_ptr<datalink_m2m> > dls;
284
285protected:
286        //! dummy epdf used only as storage for RV and dim
287        epdf iepdf;
288
289public:
290        //! \brief Default constructor
291        mprod() { }
292
293        /*!\brief Constructor from list of mFacs
294        */
295        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
296                set_elements ( mFacs );
297        }
298        //! Set internal \c mpdfs from given values
299        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
300
301        double evallogcond ( const vec &val, const vec &cond ) {
302                int i;
303                double res = 0.0;
304                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
305                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
306                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
307                                                }
308                                                // add logarithms
309                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
310                        res += mpdfs ( i )->evallogcond (
311                                   dls ( i )->pushdown ( val ),
312                                   dls ( i )->get_cond ( val, cond )
313                               );
314                }
315                return res;
316        }
317        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
318                vec tmp ( Dt.cols() );
319                for ( int i = 0; i < Dt.cols(); i++ ) {
320                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
321                }
322                return tmp;
323        };
324        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
325                vec tmp ( Dt.length() );
326                for ( int i = 0; i < Dt.length(); i++ ) {
327                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
328                }
329                return tmp;
330        };
331
332
333        //TODO smarter...
334        vec samplecond ( const vec &cond ) {
335                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
336                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
337                vec smpi;
338                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
339                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
340                        // generate contribution of this mpdf
341                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
342                        // copy contribution of this pdf into smp
343                        dls ( i )->pushup ( smp, smpi );
344                }
345                return smp;
346        }
347
348        //! Load from structure with elements:
349        //!  \code
350        //! { class='mprod';
351        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
352        //! }
353        //! \endcode
354        //!@}
355        void from_setting ( const Setting &set ) {
356                Array<shared_ptr<mpdf> > atmp; //temporary Array
357                UI::get ( atmp, set, "mpdfs", UI::compulsory );
358                set_elements ( atmp );
359        }
360};
361UIREGISTER ( mprod );
362SHAREDPTR ( mprod );
363
364//! Product of independent epdfs. For dependent pdfs, use mprod.
365class eprod: public epdf {
366protected:
367        //! Components (epdfs)
368        Array<const epdf*> epdfs;
369        //! Array of indeces
370        Array<datalink*> dls;
371public:
372        //! Default constructor
373        eprod () : epdfs ( 0 ), dls ( 0 ) {};
374        //! Set internal
375        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
376                epdfs = epdfs0;//.set_length ( epdfs0.length() );
377                dls.set_length ( epdfs.length() );
378
379                bool independent = true;
380                if ( named ) {
381                        for ( int i = 0; i < epdfs.length(); i++ ) {
382                                independent = rv.add ( epdfs ( i )->_rv() );
383                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
384                        }
385                        dim = rv._dsize();
386                } else {
387                        dim = 0;
388                        for ( int i = 0; i < epdfs.length(); i++ ) {
389                                dim += epdfs ( i )->dimension();
390                        }
391                }
392                //
393                int cumdim = 0;
394                int dimi = 0;
395                int i;
396                for ( i = 0; i < epdfs.length(); i++ ) {
397                        dls ( i ) = new datalink;
398                        if ( named ) {
399                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
400                        } else {
401                                dimi = epdfs ( i )->dimension();
402                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
403                                cumdim += dimi;
404                        }
405                }
406        }
407
408        vec mean() const {
409                vec tmp ( dim );
410                for ( int i = 0; i < epdfs.length(); i++ ) {
411                        vec pom = epdfs ( i )->mean();
412                        dls ( i )->pushup ( tmp, pom );
413                }
414                return tmp;
415        }
416        vec variance() const {
417                vec tmp ( dim ); //second moment
418                for ( int i = 0; i < epdfs.length(); i++ ) {
419                        vec pom = epdfs ( i )->mean();
420                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
421                }
422                return tmp - pow ( mean(), 2 );
423        }
424        vec sample() const {
425                vec tmp ( dim );
426                for ( int i = 0; i < epdfs.length(); i++ ) {
427                        vec pom = epdfs ( i )->sample();
428                        dls ( i )->pushup ( tmp, pom );
429                }
430                return tmp;
431        }
432        double evallog ( const vec &val ) const {
433                double tmp = 0;
434                for ( int i = 0; i < epdfs.length(); i++ ) {
435                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
436                }
437                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
438                return tmp;
439        }
440        //!access function
441        const epdf* operator () ( int i ) const {
442                it_assert_debug ( i < epdfs.length(), "wrong index" );
443                return epdfs ( i );
444        }
445
446        //!Destructor
447        ~eprod() {
448                for ( int i = 0; i < epdfs.length(); i++ ) {
449                        delete dls ( i );
450                }
451        }
452};
453
454
455/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
456
457*/
458class mmix : public mpdf {
459protected:
460        //! Component (mpdfs)
461        Array<shared_ptr<mpdf> > Coms;
462        //!weights of the components
463        vec w;
464        //! dummy epdfs
465        epdf dummy_epdf;
466public:
467        //!Default constructor
468        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
469
470        //! Set weights \c w and components \c R
471        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
472                //!\todo check if all components are OK
473                Coms = Coms0;
474                w=w0;   
475
476                if (Coms0.length()>0){
477                        set_rv(Coms(0)->_rv());
478                        dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
479                        set_rvc(Coms(0)->_rvc());
480                        dimc = rvc._dsize();
481                }
482        }
483        double evallogcond (const vec &dt, const vec &cond) {
484                double ll=0.0;
485                for (int i=0;i<Coms.length();i++){
486                        ll+=Coms(i)->evallogcond(dt,cond);
487                }
488                return ll;
489        }
490
491        vec samplecond (const vec &cond);
492
493};
494
495}
496#endif //MX_H
Note: See TracBrowser for help on using the browser.