root/library/bdm/stat/emix.h @ 527

Revision 527, 12.0 kB (checked in by vbarta, 15 years ago)

using shared_ptr in UI (optionally so far; loading & saving Array<T *> still works but should be phased out); testsuite run leaks down from 8822 to 480 bytes

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93};
94
95/*!
96* \brief Mixture of epdfs
97
98Density function:
99\f[
100f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
101\f]
102where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
103
104*/
105class emix : public epdf {
106protected:
107        //! weights of the components
108        vec w;
109        //! Component (epdfs)
110        Array<shared_ptr<epdf> > Coms;
111
112public:
113        //!Default constructor
114        emix ( ) : epdf ( ) {};
115        //! Set weights \c w and components \c Coms
116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
117        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
118
119        vec sample() const;
120        vec mean() const {
121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
126                return mu;
127        }
128        vec variance() const {
129                //non-central moment
130                vec mom2 = zeros ( dim );
131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
134                //central moment
135                return mom2 - pow ( mean(), 2 );
136        }
137        double evallog ( const vec &val ) const {
138                int i;
139                double sum = 0.0;
140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
148                return tmp;
149        };
150        vec evallog_m ( const mat &Val ) const {
151                vec x = zeros ( Val.cols() );
152                for ( int i = 0; i < w.length(); i++ ) {
153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
154                }
155                return log ( x );
156        };
157        //! Auxiliary function that returns pdflog for each component
158        mat evallog_M ( const mat &Val ) const {
159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
162                }
163                return X;
164        };
165
166        shared_ptr<epdf> marginal ( const RV &rv ) const;
167        void marginal ( const RV &rv, emix &target ) const;
168        shared_ptr<mpdf> condition ( const RV &rv ) const;
169
170//Access methods
171        //! returns a pointer to the internal mean value. Use with Care!
172        vec& _w() {
173                return w;
174        }
175
176        //!access function
177        shared_ptr<epdf> _Coms ( int i ) {
178                return Coms ( i );
179        }
180
181        void set_rv ( const RV &rv ) {
182                epdf::set_rv ( rv );
183                for ( int i = 0; i < Coms.length(); i++ ) {
184                        Coms ( i )->set_rv ( rv );
185                }
186        }
187};
188
189
190/*!
191* \brief Mixture of egiws
192
193*/
194class egiwmix : public egiw {
195protected:
196        //! weights of the components
197        vec w;
198        //! Component (epdfs)
199        Array<egiw*> Coms;
200        //!Flag if owning Coms
201        bool destroyComs;
202public:
203        //!Default constructor
204        egiwmix ( ) : egiw ( ) {};
205
206        //! Set weights \c w and components \c Coms
207        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
208        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
209
210        //!return expected value
211        vec mean() const;
212
213        //!return a sample from the density
214        vec sample() const;
215
216        //!return the expected variance
217        vec variance() const;
218
219        // TODO!!! Defined to follow ANSI and/or for future development
220        void mean_mat ( mat &M, mat&R ) const {};
221        double evallog_nn ( const vec &val ) const {
222                return 0;
223        };
224        double lognc () const {
225                return 0;
226        }
227
228        shared_ptr<epdf> marginal ( const RV &rv ) const;
229        void marginal ( const RV &rv, emix &target ) const;
230
231//Access methods
232        //! returns a pointer to the internal mean value. Use with Care!
233        vec& _w() {
234                return w;
235        }
236        virtual ~egiwmix() {
237                if ( destroyComs ) {
238                        for ( int i = 0; i < Coms.length(); i++ ) {
239                                delete Coms ( i );
240                        }
241                }
242        }
243        //! Auxiliary function for taking ownership of the Coms()
244        void ownComs() {
245                destroyComs = true;
246        }
247
248        //!access function
249        egiw* _Coms ( int i ) {
250                return Coms ( i );
251        }
252
253        void set_rv ( const RV &rv ) {
254                egiw::set_rv ( rv );
255                for ( int i = 0; i < Coms.length(); i++ ) {
256                        Coms ( i )->set_rv ( rv );
257                }
258        }
259
260        //! Approximation of a GiW mix by a single GiW pdf
261        egiw* approx();
262};
263
264/*! \brief Chain rule decomposition of epdf
265
266Probability density in the form of Chain-rule decomposition:
267\[
268f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
269\]
270Note that
271*/
272class mprod: public mpdf {
273private:
274        Array<shared_ptr<mpdf> > mpdfs;
275
276protected:
277        //! Data link for each mpdfs
278        Array<datalink_m2m*> dls;
279
280        //! dummy epdf used only as storage for RV and dim
281        epdf iepdf;
282
283public:
284        //! \brief Default constructor
285        mprod() { }
286
287        /*!\brief Constructor from list of mFacs
288        */
289        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
290                set_elements ( mFacs );
291        }
292
293        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
294
295        double evallogcond ( const vec &val, const vec &cond ) {
296                int i;
297                double res = 0.0;
298                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
299                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
300                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
301                                                }
302                                                // add logarithms
303                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
304                        res += mpdfs ( i )->evallogcond (
305                                   dls ( i )->pushdown ( val ),
306                                   dls ( i )->get_cond ( val, cond )
307                               );
308                }
309                return res;
310        }
311        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
312                vec tmp ( Dt.cols() );
313                for ( int i = 0; i < Dt.cols(); i++ ) {
314                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
315                }
316                return tmp;
317        };
318        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
319                vec tmp ( Dt.length() );
320                for ( int i = 0; i < Dt.length(); i++ ) {
321                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
322                }
323                return tmp;
324        };
325
326
327        //TODO smarter...
328        vec samplecond ( const vec &cond ) {
329                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
330                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
331                vec smpi;
332                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
333                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
334                        // generate contribution of this mpdf
335                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
336                        // copy contribution of this pdf into smp
337                        dls ( i )->pushup ( smp, smpi );
338                }
339                return smp;
340        }
341        mat samplecond ( const vec &cond,  int N ) {
342                mat Smp ( dimension(), N );
343                for ( int i = 0; i < N; i++ ) {
344                        Smp.set_col ( i, samplecond ( cond ) );
345                }
346                return Smp;
347        }
348
349        //! Load from structure with elements:
350        //!  \code
351        //! { class='mprod';
352        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
353        //! }
354        //! \endcode
355        //!@}
356        void from_setting ( const Setting &set ) {
357                Array<shared_ptr<mpdf> > atmp; //temporary Array
358                UI::get ( atmp, set, "mpdfs", UI::compulsory );
359                set_elements ( atmp );
360        }
361
362};
363UIREGISTER ( mprod );
364
365//! Product of independent epdfs. For dependent pdfs, use mprod.
366class eprod: public epdf {
367protected:
368        //! Components (epdfs)
369        Array<const epdf*> epdfs;
370        //! Array of indeces
371        Array<datalink*> dls;
372public:
373        eprod () : epdfs ( 0 ), dls ( 0 ) {};
374        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
375                epdfs = epdfs0;//.set_length ( epdfs0.length() );
376                dls.set_length ( epdfs.length() );
377
378                bool independent = true;
379                if ( named ) {
380                        for ( int i = 0; i < epdfs.length(); i++ ) {
381                                independent = rv.add ( epdfs ( i )->_rv() );
382                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
383                        }
384                        dim = rv._dsize();
385                } else {
386                        dim = 0;
387                        for ( int i = 0; i < epdfs.length(); i++ ) {
388                                dim += epdfs ( i )->dimension();
389                        }
390                }
391                //
392                int cumdim = 0;
393                int dimi = 0;
394                int i;
395                for ( i = 0; i < epdfs.length(); i++ ) {
396                        dls ( i ) = new datalink;
397                        if ( named ) {
398                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
399                        } else {
400                                dimi = epdfs ( i )->dimension();
401                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
402                                cumdim += dimi;
403                        }
404                }
405        }
406
407        vec mean() const {
408                vec tmp ( dim );
409                for ( int i = 0; i < epdfs.length(); i++ ) {
410                        vec pom = epdfs ( i )->mean();
411                        dls ( i )->pushup ( tmp, pom );
412                }
413                return tmp;
414        }
415        vec variance() const {
416                vec tmp ( dim ); //second moment
417                for ( int i = 0; i < epdfs.length(); i++ ) {
418                        vec pom = epdfs ( i )->mean();
419                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
420                }
421                return tmp - pow ( mean(), 2 );
422        }
423        vec sample() const {
424                vec tmp ( dim );
425                for ( int i = 0; i < epdfs.length(); i++ ) {
426                        vec pom = epdfs ( i )->sample();
427                        dls ( i )->pushup ( tmp, pom );
428                }
429                return tmp;
430        }
431        double evallog ( const vec &val ) const {
432                double tmp = 0;
433                for ( int i = 0; i < epdfs.length(); i++ ) {
434                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
435                }
436                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
437                return tmp;
438        }
439        //!access function
440        const epdf* operator () ( int i ) const {
441                it_assert_debug ( i < epdfs.length(), "wrong index" );
442                return epdfs ( i );
443        }
444
445        //!Destructor
446        ~eprod() {
447                for ( int i = 0; i < epdfs.length(); i++ ) {
448                        delete dls ( i );
449                }
450        }
451};
452
453
454/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
455
456*/
457class mmix : public mpdf {
458protected:
459        //! Component (mpdfs)
460        Array<shared_ptr<mpdf> > Coms;
461        //!weights of the components
462        vec w;
463        //! dummy epdfs
464        epdf dummy_epdf;
465public:
466        //!Default constructor
467        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
468
469        //! Set weights \c w and components \c R
470        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
471                //!\TODO check if all components are OK
472                Coms = Coms0;
473                w=w0;   
474
475                if (Coms0.length()>0){
476                        set_rv(Coms(0)->_rv());
477                        dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
478                        set_rvc(Coms(0)->_rvc());
479                        dimc = rvc._dsize();
480                }
481        }
482        double evallogcond (const vec &dt, const vec &cond) {
483                double ll=0.0;
484                for (int i=0;i<Coms.length();i++){
485                        ll+=Coms(i)->evallogcond(dt,cond);
486                }
487                return ll;
488        }
489
490        vec samplecond (const vec &cond);
491
492};
493
494}
495#endif //MX_H
Note: See TracBrowser for help on using the browser.