root/library/bdm/stat/emix.h @ 477

Revision 477, 12.6 kB (checked in by mido, 15 years ago)

panove, vite, jak jsem peclivej na upravu kodu.. snad se vam bude libit:) konfigurace je v souboru /system/astylerc

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41        //!Denominator in the form of epdf
42        epdf* den;
43        //!flag for destructor
44        bool destroynom;
45        //!datalink between conditional and nom
46        datalink_m2e dl;
47public:
48        //!Default constructor. By default, the given epdf is not copied!
49        //! It is assumed that this function will be used only temporarily.
50        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ) {
51                // adjust rv and rvc
52                rvc = nom0->_rv().subt ( rv );
53                dimc = rvc._dsize();
54                set_ep ( shared_ptr<epdf> ( new epdf ) );
55                e()->set_parameters ( rv._dsize() );
56                e()->set_rv ( rv );
57
58                //prepare data structures
59                if ( copy ) {
60                        it_error ( "todo" );
61                        destroynom = true;
62                } else {
63                        nom = nom0;
64                        destroynom = false;
65                }
66                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
67
68                // build denominator
69                den = nom->marginal ( rvc );
70                dl.set_connection ( rv, rvc, nom0->_rv() );
71        };
72        double evallogcond ( const vec &val, const vec &cond ) {
73                double tmp;
74                vec nom_val ( e()->dimension() + dimc );
75                dl.pushup_cond ( nom_val, val, cond );
76                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
77                it_assert_debug ( std::isfinite ( tmp ), "Infinite value" );
78                return tmp;
79        }
80        //! Object takes ownership of nom and will destroy it
81        void ownnom() {
82                destroynom = true;
83        }
84        //! Default destructor
85        ~mratio() {
86                delete den;
87                if ( destroynom ) {
88                        delete nom;
89                }
90        }
91};
92
93/*!
94* \brief Mixture of epdfs
95
96Density function:
97\f[
98f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
99\f]
100where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
101
102*/
103class emix : public epdf {
104protected:
105        //! weights of the components
106        vec w;
107        //! Component (epdfs)
108        Array<epdf*> Coms;
109        //!Flag if owning Coms
110        bool destroyComs;
111public:
112        //!Default constructor
113        emix ( ) : epdf ( ) {};
114        //! Set weights \c w and components \c Coms
115        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
116        void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy = false );
117
118        vec sample() const;
119        vec mean() const {
120                int i;
121                vec mu = zeros ( dim );
122                for ( i = 0; i < w.length(); i++ ) {
123                        mu += w ( i ) * Coms ( i )->mean();
124                }
125                return mu;
126        }
127        vec variance() const {
128                //non-central moment
129                vec mom2 = zeros ( dim );
130                for ( int i = 0; i < w.length(); i++ ) {
131                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
132                }
133                //central moment
134                return mom2 - pow ( mean(), 2 );
135        }
136        double evallog ( const vec &val ) const {
137                int i;
138                double sum = 0.0;
139                for ( i = 0; i < w.length(); i++ ) {
140                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
141                }
142                if ( sum == 0.0 ) {
143                        sum = std::numeric_limits<double>::epsilon();
144                }
145                double tmp = log ( sum );
146                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
147                return tmp;
148        };
149        vec evallog_m ( const mat &Val ) const {
150                vec x = zeros ( Val.cols() );
151                for ( int i = 0; i < w.length(); i++ ) {
152                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
153                }
154                return log ( x );
155        };
156        //! Auxiliary function that returns pdflog for each component
157        mat evallog_M ( const mat &Val ) const {
158                mat X ( w.length(), Val.cols() );
159                for ( int i = 0; i < w.length(); i++ ) {
160                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
161                }
162                return X;
163        };
164
165        emix* marginal ( const RV &rv ) const;
166        mratio* condition ( const RV &rv ) const; //why not mratio!!
167
168//Access methods
169        //! returns a pointer to the internal mean value. Use with Care!
170        vec& _w() {
171                return w;
172        }
173        virtual ~emix() {
174                if ( destroyComs ) {
175                        for ( int i = 0; i < Coms.length(); i++ ) {
176                                delete Coms ( i );
177                        }
178                }
179        }
180        //! Auxiliary function for taking ownership of the Coms()
181        void ownComs() {
182                destroyComs = true;
183        }
184
185        //!access function
186        epdf* _Coms ( int i ) {
187                return Coms ( i );
188        }
189        void set_rv ( const RV &rv ) {
190                epdf::set_rv ( rv );
191                for ( int i = 0; i < Coms.length(); i++ ) {
192                        Coms ( i )->set_rv ( rv );
193                }
194        }
195};
196
197
198/*!
199* \brief Mixture of egiws
200
201*/
202class egiwmix : public egiw {
203protected:
204        //! weights of the components
205        vec w;
206        //! Component (epdfs)
207        Array<egiw*> Coms;
208        //!Flag if owning Coms
209        bool destroyComs;
210public:
211        //!Default constructor
212        egiwmix ( ) : egiw ( ) {};
213
214        //! Set weights \c w and components \c Coms
215        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
216        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
217
218        //!return expected value
219        vec mean() const;
220
221        //!return a sample from the density
222        vec sample() const;
223
224        //!return the expected variance
225        vec variance() const;
226
227        // TODO!!! Defined to follow ANSI and/or for future development
228        void mean_mat ( mat &M, mat&R ) const {};
229        double evallog_nn ( const vec &val ) const {
230                return 0;
231        };
232        double lognc () const {
233                return 0;
234        };
235        emix* marginal ( const RV &rv ) const;
236
237//Access methods
238        //! returns a pointer to the internal mean value. Use with Care!
239        vec& _w() {
240                return w;
241        }
242        virtual ~egiwmix() {
243                if ( destroyComs ) {
244                        for ( int i = 0; i < Coms.length(); i++ ) {
245                                delete Coms ( i );
246                        }
247                }
248        }
249        //! Auxiliary function for taking ownership of the Coms()
250        void ownComs() {
251                destroyComs = true;
252        }
253
254        //!access function
255        egiw* _Coms ( int i ) {
256                return Coms ( i );
257        }
258
259        void set_rv ( const RV &rv ) {
260                egiw::set_rv ( rv );
261                for ( int i = 0; i < Coms.length(); i++ ) {
262                        Coms ( i )->set_rv ( rv );
263                }
264        }
265
266        //! Approximation of a GiW mix by a single GiW pdf
267        egiw* approx();
268};
269
270/*! \brief Chain rule decomposition of epdf
271
272Probability density in the form of Chain-rule decomposition:
273\[
274f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
275\]
276Note that
277*/
278class mprod: public compositepdf, public mpdf {
279protected:
280        //! pointers to epdfs - shortcut to mpdfs().posterior()
281        Array<epdf*> epdfs;
282        //! Data link for each mpdfs
283        Array<datalink_m2m*> dls;
284
285        //! dummy ep
286        shared_ptr<epdf> dummy;
287
288public:
289        /*!\brief Constructor from list of mFacs,
290        */
291        mprod() : dummy ( new epdf() ) { }
292        mprod ( Array<mpdf*> mFacs ) :
293                        dummy ( new epdf() ) {
294                set_elements ( mFacs );
295        }
296
297        void set_elements ( Array<mpdf*> mFacs , bool own = false ) {
298
299                compositepdf::set_elements ( mFacs, own );
300                dls.set_size ( mFacs.length() );
301                epdfs.set_size ( mFacs.length() );
302
303                set_ep ( dummy );
304                RV rv = getrv ( true );
305                set_rv ( rv );
306                dummy->set_parameters ( rv._dsize() );
307                setrvc ( e()->_rv(), rvc );
308                // rv and rvc established = > we can link them with mpdfs
309                for ( int i = 0; i < mpdfs.length(); i++ ) {
310                        dls ( i ) = new datalink_m2m;
311                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
312                }
313
314                for ( int i = 0; i < mpdfs.length(); i++ ) {
315                        epdfs ( i ) = mpdfs ( i )->e();
316                }
317        };
318
319        double evallogcond ( const vec &val, const vec &cond ) {
320                int i;
321                double res = 0.0;
322                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
323                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
324                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
325                                                }
326                                                // add logarithms
327                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
328                        res += mpdfs ( i )->evallogcond (
329                                   dls ( i )->pushdown ( val ),
330                                   dls ( i )->get_cond ( val, cond )
331                               );
332                }
333                return res;
334        }
335        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
336                vec tmp ( Dt.cols() );
337                for ( int i = 0; i < Dt.cols(); i++ ) {
338                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
339                }
340                return tmp;
341        };
342        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
343                vec tmp ( Dt.length() );
344                for ( int i = 0; i < Dt.length(); i++ ) {
345                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
346                }
347                return tmp;
348        };
349
350
351        //TODO smarter...
352        vec samplecond ( const vec &cond ) {
353                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
354                vec smp = std::numeric_limits<double>::infinity() * ones ( e()->dimension() );
355                vec smpi;
356                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
357                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
358                        if ( mpdfs ( i )->dimensionc() ) {
359                                mpdfs ( i )->condition ( dls ( i )->get_cond ( smp , cond ) ); // smp is val here!!
360                        }
361                        smpi = epdfs ( i )->sample();
362                        // copy contribution of this pdf into smp
363                        dls ( i )->pushup ( smp, smpi );
364                        // add ith likelihood contribution
365                }
366                return smp;
367        }
368        mat samplecond ( const vec &cond,  int N ) {
369                mat Smp ( dimension(), N );
370                for ( int i = 0; i < N; i++ ) {
371                        Smp.set_col ( i, samplecond ( cond ) );
372                }
373                return Smp;
374        }
375
376        ~mprod() {};
377        //! Load from structure with elements:
378        //!  \code
379        //! { class='mprod';
380        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
381        //! }
382        //! \endcode
383        //!@}
384        void from_setting ( const Setting &set ) {
385                Array<mpdf*> Atmp; //temporary Array
386                UI::get ( Atmp, set, "mpdfs", UI::compulsory );
387                set_elements ( Atmp, true );
388        }
389
390};
391UIREGISTER ( mprod );
392
393//! Product of independent epdfs. For dependent pdfs, use mprod.
394class eprod: public epdf {
395protected:
396        //! Components (epdfs)
397        Array<const epdf*> epdfs;
398        //! Array of indeces
399        Array<datalink*> dls;
400public:
401        eprod () : epdfs ( 0 ), dls ( 0 ) {};
402        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
403                epdfs = epdfs0;//.set_length ( epdfs0.length() );
404                dls.set_length ( epdfs.length() );
405
406                bool independent = true;
407                if ( named ) {
408                        for ( int i = 0; i < epdfs.length(); i++ ) {
409                                independent = rv.add ( epdfs ( i )->_rv() );
410                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
411                        }
412                        dim = rv._dsize();
413                } else {
414                        dim = 0;
415                        for ( int i = 0; i < epdfs.length(); i++ ) {
416                                dim += epdfs ( i )->dimension();
417                        }
418                }
419                //
420                int cumdim = 0;
421                int dimi = 0;
422                int i;
423                for ( i = 0; i < epdfs.length(); i++ ) {
424                        dls ( i ) = new datalink;
425                        if ( named ) {
426                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
427                        } else {
428                                dimi = epdfs ( i )->dimension();
429                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
430                                cumdim += dimi;
431                        }
432                }
433        }
434
435        vec mean() const {
436                vec tmp ( dim );
437                for ( int i = 0; i < epdfs.length(); i++ ) {
438                        vec pom = epdfs ( i )->mean();
439                        dls ( i )->pushup ( tmp, pom );
440                }
441                return tmp;
442        }
443        vec variance() const {
444                vec tmp ( dim ); //second moment
445                for ( int i = 0; i < epdfs.length(); i++ ) {
446                        vec pom = epdfs ( i )->mean();
447                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
448                }
449                return tmp - pow ( mean(), 2 );
450        }
451        vec sample() const {
452                vec tmp ( dim );
453                for ( int i = 0; i < epdfs.length(); i++ ) {
454                        vec pom = epdfs ( i )->sample();
455                        dls ( i )->pushup ( tmp, pom );
456                }
457                return tmp;
458        }
459        double evallog ( const vec &val ) const {
460                double tmp = 0;
461                for ( int i = 0; i < epdfs.length(); i++ ) {
462                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
463                }
464                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
465                return tmp;
466        }
467        //!access function
468        const epdf* operator () ( int i ) const {
469                it_assert_debug ( i < epdfs.length(), "wrong index" );
470                return epdfs ( i );
471        }
472
473        //!Destructor
474        ~eprod() {
475                for ( int i = 0; i < epdfs.length(); i++ ) {
476                        delete dls ( i );
477                }
478        }
479};
480
481
482/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
483
484*/
485class mmix : public mpdf {
486protected:
487        //! Component (epdfs)
488        Array<mpdf*> Coms;
489
490        //!Internal epdf
491        shared_ptr<emix> iepdf;
492
493public:
494        //!Default constructor
495        mmix() : iepdf ( new emix() ) {
496                set_ep ( iepdf );
497        }
498
499        //! Set weights \c w and components \c R
500        void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
501                Array<epdf*> Eps ( Coms.length() );
502
503                for ( int i = 0; i < Coms.length(); i++ ) {
504                        Eps ( i ) = Coms ( i )->e();
505                }
506
507                iepdf->set_parameters ( w, Eps );
508        }
509
510        void condition ( const vec &cond ) {
511                for ( int i = 0; i < Coms.length(); i++ ) {
512                        Coms ( i )->condition ( cond );
513                }
514        };
515};
516
517}
518#endif //MX_H
Note: See TracBrowser for help on using the browser.