root/library/bdm/stat/emix.h @ 546

Revision 546, 12.0 kB (checked in by vbarta, 15 years ago)

using shared_ptr for mprod's datalinks; testsuite now without memory leaks

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93};
94
95/*!
96* \brief Mixture of epdfs
97
98Density function:
99\f[
100f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
101\f]
102where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
103
104*/
105class emix : public epdf {
106protected:
107        //! weights of the components
108        vec w;
109        //! Component (epdfs)
110        Array<shared_ptr<epdf> > Coms;
111
112public:
113        //!Default constructor
114        emix ( ) : epdf ( ) {};
115        //! Set weights \c w and components \c Coms
116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
117        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
118
119        vec sample() const;
120        vec mean() const {
121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
126                return mu;
127        }
128        vec variance() const {
129                //non-central moment
130                vec mom2 = zeros ( dim );
131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
134                //central moment
135                return mom2 - pow ( mean(), 2 );
136        }
137        double evallog ( const vec &val ) const {
138                int i;
139                double sum = 0.0;
140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
148                return tmp;
149        };
150        vec evallog_m ( const mat &Val ) const {
151                vec x = zeros ( Val.cols() );
152                for ( int i = 0; i < w.length(); i++ ) {
153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
154                }
155                return log ( x );
156        };
157        //! Auxiliary function that returns pdflog for each component
158        mat evallog_M ( const mat &Val ) const {
159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
162                }
163                return X;
164        };
165
166        shared_ptr<epdf> marginal ( const RV &rv ) const;
167        //! Update already existing marginal density  \c target
168        void marginal ( const RV &rv, emix &target ) const;
169        shared_ptr<mpdf> condition ( const RV &rv ) const;
170
171//Access methods
172        //! returns a pointer to the internal mean value. Use with Care!
173        vec& _w() {
174                return w;
175        }
176
177        //!access function
178        shared_ptr<epdf> _Coms ( int i ) {
179                return Coms ( i );
180        }
181
182        void set_rv ( const RV &rv ) {
183                epdf::set_rv ( rv );
184                for ( int i = 0; i < Coms.length(); i++ ) {
185                        Coms ( i )->set_rv ( rv );
186                }
187        }
188};
189SHAREDPTR( emix );
190
191/*!
192* \brief Mixture of egiws
193
194*/
195class egiwmix : public egiw {
196protected:
197        //! weights of the components
198        vec w;
199        //! Component (epdfs)
200        Array<egiw*> Coms;
201        //!Flag if owning Coms
202        bool destroyComs;
203public:
204        //!Default constructor
205        egiwmix ( ) : egiw ( ) {};
206
207        //! Set weights \c w and components \c Coms
208        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
209        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
210
211        //!return expected value
212        vec mean() const;
213
214        //!return a sample from the density
215        vec sample() const;
216
217        //!return the expected variance
218        vec variance() const;
219
220        // TODO!!! Defined to follow ANSI and/or for future development
221        void mean_mat ( mat &M, mat&R ) const {};
222        double evallog_nn ( const vec &val ) const {
223                return 0;
224        };
225        double lognc () const {
226                return 0;
227        }
228
229        shared_ptr<epdf> marginal ( const RV &rv ) const;
230        void marginal ( const RV &rv, emix &target ) const;
231
232//Access methods
233        //! returns a pointer to the internal mean value. Use with Care!
234        vec& _w() {
235                return w;
236        }
237        virtual ~egiwmix() {
238                if ( destroyComs ) {
239                        for ( int i = 0; i < Coms.length(); i++ ) {
240                                delete Coms ( i );
241                        }
242                }
243        }
244        //! Auxiliary function for taking ownership of the Coms()
245        void ownComs() {
246                destroyComs = true;
247        }
248
249        //!access function
250        egiw* _Coms ( int i ) {
251                return Coms ( i );
252        }
253
254        void set_rv ( const RV &rv ) {
255                egiw::set_rv ( rv );
256                for ( int i = 0; i < Coms.length(); i++ ) {
257                        Coms ( i )->set_rv ( rv );
258                }
259        }
260
261        //! Approximation of a GiW mix by a single GiW pdf
262        egiw* approx();
263};
264
265/*! \brief Chain rule decomposition of epdf
266
267Probability density in the form of Chain-rule decomposition:
268\[
269f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
270\]
271Note that
272*/
273class mprod: public mpdf {
274private:
275        Array<shared_ptr<mpdf> > mpdfs;
276
277        //! Data link for each mpdfs
278        Array<shared_ptr<datalink_m2m> > dls;
279
280protected:
281        //! dummy epdf used only as storage for RV and dim
282        epdf iepdf;
283
284public:
285        //! \brief Default constructor
286        mprod() { }
287
288        /*!\brief Constructor from list of mFacs
289        */
290        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
291                set_elements ( mFacs );
292        }
293        //! Set internal \c mpdfs from given values
294        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
295
296        double evallogcond ( const vec &val, const vec &cond ) {
297                int i;
298                double res = 0.0;
299                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
300                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
301                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
302                                                }
303                                                // add logarithms
304                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
305                        res += mpdfs ( i )->evallogcond (
306                                   dls ( i )->pushdown ( val ),
307                                   dls ( i )->get_cond ( val, cond )
308                               );
309                }
310                return res;
311        }
312        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
313                vec tmp ( Dt.cols() );
314                for ( int i = 0; i < Dt.cols(); i++ ) {
315                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
316                }
317                return tmp;
318        };
319        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
320                vec tmp ( Dt.length() );
321                for ( int i = 0; i < Dt.length(); i++ ) {
322                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
323                }
324                return tmp;
325        };
326
327
328        //TODO smarter...
329        vec samplecond ( const vec &cond ) {
330                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
331                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
332                vec smpi;
333                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
334                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
335                        // generate contribution of this mpdf
336                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
337                        // copy contribution of this pdf into smp
338                        dls ( i )->pushup ( smp, smpi );
339                }
340                return smp;
341        }
342
343        //! Load from structure with elements:
344        //!  \code
345        //! { class='mprod';
346        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
347        //! }
348        //! \endcode
349        //!@}
350        void from_setting ( const Setting &set ) {
351                Array<shared_ptr<mpdf> > atmp; //temporary Array
352                UI::get ( atmp, set, "mpdfs", UI::compulsory );
353                set_elements ( atmp );
354        }
355
356};
357UIREGISTER ( mprod );
358SHAREDPTR ( mprod );
359
360//! Product of independent epdfs. For dependent pdfs, use mprod.
361class eprod: public epdf {
362protected:
363        //! Components (epdfs)
364        Array<const epdf*> epdfs;
365        //! Array of indeces
366        Array<datalink*> dls;
367public:
368        //! Default constructor
369        eprod () : epdfs ( 0 ), dls ( 0 ) {};
370        //! Set internal
371        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
372                epdfs = epdfs0;//.set_length ( epdfs0.length() );
373                dls.set_length ( epdfs.length() );
374
375                bool independent = true;
376                if ( named ) {
377                        for ( int i = 0; i < epdfs.length(); i++ ) {
378                                independent = rv.add ( epdfs ( i )->_rv() );
379                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
380                        }
381                        dim = rv._dsize();
382                } else {
383                        dim = 0;
384                        for ( int i = 0; i < epdfs.length(); i++ ) {
385                                dim += epdfs ( i )->dimension();
386                        }
387                }
388                //
389                int cumdim = 0;
390                int dimi = 0;
391                int i;
392                for ( i = 0; i < epdfs.length(); i++ ) {
393                        dls ( i ) = new datalink;
394                        if ( named ) {
395                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
396                        } else {
397                                dimi = epdfs ( i )->dimension();
398                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
399                                cumdim += dimi;
400                        }
401                }
402        }
403
404        vec mean() const {
405                vec tmp ( dim );
406                for ( int i = 0; i < epdfs.length(); i++ ) {
407                        vec pom = epdfs ( i )->mean();
408                        dls ( i )->pushup ( tmp, pom );
409                }
410                return tmp;
411        }
412        vec variance() const {
413                vec tmp ( dim ); //second moment
414                for ( int i = 0; i < epdfs.length(); i++ ) {
415                        vec pom = epdfs ( i )->mean();
416                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
417                }
418                return tmp - pow ( mean(), 2 );
419        }
420        vec sample() const {
421                vec tmp ( dim );
422                for ( int i = 0; i < epdfs.length(); i++ ) {
423                        vec pom = epdfs ( i )->sample();
424                        dls ( i )->pushup ( tmp, pom );
425                }
426                return tmp;
427        }
428        double evallog ( const vec &val ) const {
429                double tmp = 0;
430                for ( int i = 0; i < epdfs.length(); i++ ) {
431                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
432                }
433                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
434                return tmp;
435        }
436        //!access function
437        const epdf* operator () ( int i ) const {
438                it_assert_debug ( i < epdfs.length(), "wrong index" );
439                return epdfs ( i );
440        }
441
442        //!Destructor
443        ~eprod() {
444                for ( int i = 0; i < epdfs.length(); i++ ) {
445                        delete dls ( i );
446                }
447        }
448};
449
450
451/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
452
453*/
454class mmix : public mpdf {
455protected:
456        //! Component (mpdfs)
457        Array<shared_ptr<mpdf> > Coms;
458        //!weights of the components
459        vec w;
460        //! dummy epdfs
461        epdf dummy_epdf;
462public:
463        //!Default constructor
464        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
465
466        //! Set weights \c w and components \c R
467        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
468                //!\todo check if all components are OK
469                Coms = Coms0;
470                w=w0;   
471
472                if (Coms0.length()>0){
473                        set_rv(Coms(0)->_rv());
474                        dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
475                        set_rvc(Coms(0)->_rvc());
476                        dimc = rvc._dsize();
477                }
478        }
479        double evallogcond (const vec &dt, const vec &cond) {
480                double ll=0.0;
481                for (int i=0;i<Coms.length();i++){
482                        ll+=Coms(i)->evallogcond(dt,cond);
483                }
484                return ll;
485        }
486
487        vec samplecond (const vec &cond);
488
489};
490
491}
492#endif //MX_H
Note: See TracBrowser for help on using the browser.