root/library/bdm/stat/emix.h @ 549

Revision 549, 12.1 kB (checked in by vbarta, 15 years ago)

disabled implicit copy constructor & assignement operator of mprod

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public mpdf {
38protected:
39        //! Nominator in the form of mpdf
40        const epdf* nom;
41
42        //!Denominator in the form of epdf
43        shared_ptr<epdf> den;
44
45        //!flag for destructor
46        bool destroynom;
47        //!datalink between conditional and nom
48        datalink_m2e dl;
49        //! dummy epdf that stores only rv and dim
50        epdf iepdf;
51public:
52        //!Default constructor. By default, the given epdf is not copied!
53        //! It is assumed that this function will be used only temporarily.
54        mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
55                // adjust rv and rvc
56                rvc = nom0->_rv().subt ( rv );
57                dimc = rvc._dsize();
58                set_ep ( iepdf );
59                iepdf.set_parameters ( rv._dsize() );
60                iepdf.set_rv ( rv );
61
62                //prepare data structures
63                if ( copy ) {
64                        it_error ( "todo" );
65                        destroynom = true;
66                } else {
67                        nom = nom0;
68                        destroynom = false;
69                }
70                it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72                // build denominator
73                den = nom->marginal ( rvc );
74                dl.set_connection ( rv, rvc, nom0->_rv() );
75        };
76        double evallogcond ( const vec &val, const vec &cond ) {
77                double tmp;
78                vec nom_val ( dimension() + dimc );
79                dl.pushup_cond ( nom_val, val, cond );
80                tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81                return tmp;
82        }
83        //! Object takes ownership of nom and will destroy it
84        void ownnom() {
85                destroynom = true;
86        }
87        //! Default destructor
88        ~mratio() {
89                if ( destroynom ) {
90                        delete nom;
91                }
92        }
93};
94
95/*!
96* \brief Mixture of epdfs
97
98Density function:
99\f[
100f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
101\f]
102where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
103
104*/
105class emix : public epdf {
106protected:
107        //! weights of the components
108        vec w;
109        //! Component (epdfs)
110        Array<shared_ptr<epdf> > Coms;
111
112public:
113        //!Default constructor
114        emix ( ) : epdf ( ) {};
115        //! Set weights \c w and components \c Coms
116        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
117        void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
118
119        vec sample() const;
120        vec mean() const {
121                int i;
122                vec mu = zeros ( dim );
123                for ( i = 0; i < w.length(); i++ ) {
124                        mu += w ( i ) * Coms ( i )->mean();
125                }
126                return mu;
127        }
128        vec variance() const {
129                //non-central moment
130                vec mom2 = zeros ( dim );
131                for ( int i = 0; i < w.length(); i++ ) {
132                        mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
133                }
134                //central moment
135                return mom2 - pow ( mean(), 2 );
136        }
137        double evallog ( const vec &val ) const {
138                int i;
139                double sum = 0.0;
140                for ( i = 0; i < w.length(); i++ ) {
141                        sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
142                }
143                if ( sum == 0.0 ) {
144                        sum = std::numeric_limits<double>::epsilon();
145                }
146                double tmp = log ( sum );
147                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
148                return tmp;
149        };
150        vec evallog_m ( const mat &Val ) const {
151                vec x = zeros ( Val.cols() );
152                for ( int i = 0; i < w.length(); i++ ) {
153                        x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
154                }
155                return log ( x );
156        };
157        //! Auxiliary function that returns pdflog for each component
158        mat evallog_M ( const mat &Val ) const {
159                mat X ( w.length(), Val.cols() );
160                for ( int i = 0; i < w.length(); i++ ) {
161                        X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
162                }
163                return X;
164        };
165
166        shared_ptr<epdf> marginal ( const RV &rv ) const;
167        //! Update already existing marginal density  \c target
168        void marginal ( const RV &rv, emix &target ) const;
169        shared_ptr<mpdf> condition ( const RV &rv ) const;
170
171//Access methods
172        //! returns a pointer to the internal mean value. Use with Care!
173        vec& _w() {
174                return w;
175        }
176
177        //!access function
178        shared_ptr<epdf> _Coms ( int i ) {
179                return Coms ( i );
180        }
181
182        void set_rv ( const RV &rv ) {
183                epdf::set_rv ( rv );
184                for ( int i = 0; i < Coms.length(); i++ ) {
185                        Coms ( i )->set_rv ( rv );
186                }
187        }
188};
189SHAREDPTR( emix );
190
191/*!
192* \brief Mixture of egiws
193
194*/
195class egiwmix : public egiw {
196protected:
197        //! weights of the components
198        vec w;
199        //! Component (epdfs)
200        Array<egiw*> Coms;
201        //!Flag if owning Coms
202        bool destroyComs;
203public:
204        //!Default constructor
205        egiwmix ( ) : egiw ( ) {};
206
207        //! Set weights \c w and components \c Coms
208        //!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
209        void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
210
211        //!return expected value
212        vec mean() const;
213
214        //!return a sample from the density
215        vec sample() const;
216
217        //!return the expected variance
218        vec variance() const;
219
220        // TODO!!! Defined to follow ANSI and/or for future development
221        void mean_mat ( mat &M, mat&R ) const {};
222        double evallog_nn ( const vec &val ) const {
223                return 0;
224        };
225        double lognc () const {
226                return 0;
227        }
228
229        shared_ptr<epdf> marginal ( const RV &rv ) const;
230        void marginal ( const RV &rv, emix &target ) const;
231
232//Access methods
233        //! returns a pointer to the internal mean value. Use with Care!
234        vec& _w() {
235                return w;
236        }
237        virtual ~egiwmix() {
238                if ( destroyComs ) {
239                        for ( int i = 0; i < Coms.length(); i++ ) {
240                                delete Coms ( i );
241                        }
242                }
243        }
244        //! Auxiliary function for taking ownership of the Coms()
245        void ownComs() {
246                destroyComs = true;
247        }
248
249        //!access function
250        egiw* _Coms ( int i ) {
251                return Coms ( i );
252        }
253
254        void set_rv ( const RV &rv ) {
255                egiw::set_rv ( rv );
256                for ( int i = 0; i < Coms.length(); i++ ) {
257                        Coms ( i )->set_rv ( rv );
258                }
259        }
260
261        //! Approximation of a GiW mix by a single GiW pdf
262        egiw* approx();
263};
264
265/*! \brief Chain rule decomposition of epdf
266
267Probability density in the form of Chain-rule decomposition:
268\[
269f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
270\]
271Note that
272*/
273class mprod: public mpdf {
274private:
275        Array<shared_ptr<mpdf> > mpdfs;
276
277        //! Data link for each mpdfs
278        Array<shared_ptr<datalink_m2m> > dls;
279
280protected:
281        //! dummy epdf used only as storage for RV and dim
282        epdf iepdf;
283
284public:
285        //! \brief Default constructor
286        mprod() { }
287
288        /*!\brief Constructor from list of mFacs
289        */
290        mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
291                set_elements ( mFacs );
292        }
293        //! Set internal \c mpdfs from given values
294        void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
295
296        double evallogcond ( const vec &val, const vec &cond ) {
297                int i;
298                double res = 0.0;
299                for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
300                        /*                      if ( mpdfs(i)->_rvc().count() >0) {
301                                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
302                                                }
303                                                // add logarithms
304                                                res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
305                        res += mpdfs ( i )->evallogcond (
306                                   dls ( i )->pushdown ( val ),
307                                   dls ( i )->get_cond ( val, cond )
308                               );
309                }
310                return res;
311        }
312        vec evallogcond_m ( const mat &Dt, const vec &cond ) {
313                vec tmp ( Dt.cols() );
314                for ( int i = 0; i < Dt.cols(); i++ ) {
315                        tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
316                }
317                return tmp;
318        };
319        vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
320                vec tmp ( Dt.length() );
321                for ( int i = 0; i < Dt.length(); i++ ) {
322                        tmp ( i ) = evallogcond ( Dt ( i ), cond );
323                }
324                return tmp;
325        };
326
327
328        //TODO smarter...
329        vec samplecond ( const vec &cond ) {
330                //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
331                vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
332                vec smpi;
333                // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
334                for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
335                        // generate contribution of this mpdf
336                        smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                       
337                        // copy contribution of this pdf into smp
338                        dls ( i )->pushup ( smp, smpi );
339                }
340                return smp;
341        }
342
343        //! Load from structure with elements:
344        //!  \code
345        //! { class='mprod';
346        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
347        //! }
348        //! \endcode
349        //!@}
350        void from_setting ( const Setting &set ) {
351                Array<shared_ptr<mpdf> > atmp; //temporary Array
352                UI::get ( atmp, set, "mpdfs", UI::compulsory );
353                set_elements ( atmp );
354        }
355
356private:
357        // not implemented
358        mprod ( const mprod & );
359        mprod &operator=( const mprod & );
360};
361UIREGISTER ( mprod );
362SHAREDPTR ( mprod );
363
364//! Product of independent epdfs. For dependent pdfs, use mprod.
365class eprod: public epdf {
366protected:
367        //! Components (epdfs)
368        Array<const epdf*> epdfs;
369        //! Array of indeces
370        Array<datalink*> dls;
371public:
372        //! Default constructor
373        eprod () : epdfs ( 0 ), dls ( 0 ) {};
374        //! Set internal
375        void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
376                epdfs = epdfs0;//.set_length ( epdfs0.length() );
377                dls.set_length ( epdfs.length() );
378
379                bool independent = true;
380                if ( named ) {
381                        for ( int i = 0; i < epdfs.length(); i++ ) {
382                                independent = rv.add ( epdfs ( i )->_rv() );
383                                it_assert_debug ( independent == true, "eprod:: given components are not independent." );
384                        }
385                        dim = rv._dsize();
386                } else {
387                        dim = 0;
388                        for ( int i = 0; i < epdfs.length(); i++ ) {
389                                dim += epdfs ( i )->dimension();
390                        }
391                }
392                //
393                int cumdim = 0;
394                int dimi = 0;
395                int i;
396                for ( i = 0; i < epdfs.length(); i++ ) {
397                        dls ( i ) = new datalink;
398                        if ( named ) {
399                                dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
400                        } else {
401                                dimi = epdfs ( i )->dimension();
402                                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
403                                cumdim += dimi;
404                        }
405                }
406        }
407
408        vec mean() const {
409                vec tmp ( dim );
410                for ( int i = 0; i < epdfs.length(); i++ ) {
411                        vec pom = epdfs ( i )->mean();
412                        dls ( i )->pushup ( tmp, pom );
413                }
414                return tmp;
415        }
416        vec variance() const {
417                vec tmp ( dim ); //second moment
418                for ( int i = 0; i < epdfs.length(); i++ ) {
419                        vec pom = epdfs ( i )->mean();
420                        dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
421                }
422                return tmp - pow ( mean(), 2 );
423        }
424        vec sample() const {
425                vec tmp ( dim );
426                for ( int i = 0; i < epdfs.length(); i++ ) {
427                        vec pom = epdfs ( i )->sample();
428                        dls ( i )->pushup ( tmp, pom );
429                }
430                return tmp;
431        }
432        double evallog ( const vec &val ) const {
433                double tmp = 0;
434                for ( int i = 0; i < epdfs.length(); i++ ) {
435                        tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
436                }
437                it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
438                return tmp;
439        }
440        //!access function
441        const epdf* operator () ( int i ) const {
442                it_assert_debug ( i < epdfs.length(), "wrong index" );
443                return epdfs ( i );
444        }
445
446        //!Destructor
447        ~eprod() {
448                for ( int i = 0; i < epdfs.length(); i++ ) {
449                        delete dls ( i );
450                }
451        }
452};
453
454
455/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC
456
457*/
458class mmix : public mpdf {
459protected:
460        //! Component (mpdfs)
461        Array<shared_ptr<mpdf> > Coms;
462        //!weights of the components
463        vec w;
464        //! dummy epdfs
465        epdf dummy_epdf;
466public:
467        //!Default constructor
468        mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
469
470        //! Set weights \c w and components \c R
471        void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
472                //!\todo check if all components are OK
473                Coms = Coms0;
474                w=w0;   
475
476                if (Coms0.length()>0){
477                        set_rv(Coms(0)->_rv());
478                        dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
479                        set_rvc(Coms(0)->_rvc());
480                        dimc = rvc._dsize();
481                }
482        }
483        double evallogcond (const vec &dt, const vec &cond) {
484                double ll=0.0;
485                for (int i=0;i<Coms.length();i++){
486                        ll+=Coms(i)->evallogcond(dt,cond);
487                }
488                return ll;
489        }
490
491        vec samplecond (const vec &cond);
492
493};
494
495}
496#endif //MX_H
Note: See TracBrowser for help on using the browser.