root/library/bdm/stat/emix.h @ 1187

Revision 1170, 14.5 kB (checked in by smidl, 14 years ago)

New noise particle + memory leak fix

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26
27which arise e.g. by applying the Bayes rule.
28It represents density in the form:
29\f[
30f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
31\f]
32where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
33
34In particular this type of arise by conditioning of a mixture model.
35
36At present the only supported operation is evallogcond().
37 */
38class mratio: public pdf {
39protected:
40    //! Nominator in the form of pdf
41    const epdf* nom;
42
43    //!Denominator in the form of epdf
44    shared_ptr<epdf> den;
45
46    //!flag for destructor
47    bool destroynom;
48    //!datalink between conditional and nom
49    datalink_m2e dl;
50public:
51    //!Default constructor. By default, the given epdf is not copied!
52    //! It is assumed that this function will be used only temporarily.
53    mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
54        // adjust rv and rvc
55
56        set_rv ( rv );
57        dim = rv._dsize();
58
59        rvc = nom0->_rv().subt ( rv );
60        dimc = rvc._dsize();
61
62        //prepare data structures
63        if ( copy ) {
64            bdm_error ( "todo" );
65            // destroynom = true;
66        } else {
67            nom = nom0;
68            destroynom = false;
69        }
70        bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
71
72        // build denominator
73        den = nom->marginal ( rvc );
74        dl.set_connection ( rv, rvc, nom0->_rv() );
75    };
76
77    double evallogcond ( const vec &val, const vec &cond ) {
78        double tmp;
79        vec nom_val ( dimension() + dimc );
80        dl.pushup_cond ( nom_val, val, cond );
81        tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
82        return tmp;
83    }
84
85    //! Returns a sample from the density conditioned on \c cond, \f$x \sim epdf(rv|cond)\f$. \param cond is numeric value of \c rv
86    virtual vec samplecond ( const vec &cond ) NOT_IMPLEMENTED(0);
87
88    //! Object takes ownership of nom and will destroy it
89    void ownnom() {
90        destroynom = true;
91    }
92    //! Default destructor
93    ~mratio() {
94        if ( destroynom ) {
95            delete nom;
96        }
97    }
98
99
100private:
101    // not implemented
102    mratio ( const mratio & );
103    mratio &operator= ( const mratio & );
104};
105
106class emix; //forward
107
108//! \brief Base class (interface) for mixtures
109class emix_base : public epdf {
110protected:
111    //! reference to vector of weights
112    vec &w;
113    //! function returning ith component
114    virtual const epdf * component(const int &i) const=0;
115
116    virtual int no_coms() const=0;
117
118public:
119
120    emix_base(vec &w0): w(w0) {}
121
122    void validate ();
123
124    vec sample() const;
125
126    vec mean() const;
127
128    vec variance() const;
129
130    double evallog ( const vec &val ) const;
131
132    vec evallog_mat ( const mat &Val ) const;
133
134    //! Auxiliary function that returns pdflog for each component
135    mat evallog_coms ( const mat &Val ) const;
136
137    shared_ptr<epdf> marginal ( const RV &rv ) const;
138    //! Update already existing marginal density  \c target
139    void marginal ( const RV &rv, emix &target ) const;
140    shared_ptr<pdf> condition ( const RV &rv ) const;
141
142    //Access methods
143    //! returns a reference to the internal weights. Use with Care!
144    vec& _w() {
145        return w;
146    }
147
148    const vec& _w() const {
149        return w;
150    }
151    //!access
152    const epdf* _com(int i) const {
153        return component(i);
154    }
155
156};
157
158/*!
159* \brief Mixture of epdfs
160
161Density function:
162\f[
163f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
164\f]
165where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
166
167*/
168class emix : public emix_base {
169protected:
170    //! weights of the components
171    vec weights;
172
173    //! Component (epdfs)
174    Array<shared_ptr<epdf> > Coms;
175
176public:
177    //! Default constructor
178    emix ( ) : emix_base ( weights) { }
179
180    const epdf* component(const int &i) const {
181        return Coms(i).get();
182    }
183    void validate();
184
185
186    int no_coms() const {
187        return Coms.length();
188    }
189
190    /*! Create object from the following structure
191
192    \code
193    class = 'emix';
194
195    pdfs = { list of any bdm::pdf offsprings };   % pdfs in the mixture, bdm::pdf::from_setting
196    weights = [... ];                             % vector of weights of pdfs in the mixture
197    --- inherited fields ---
198    bdm::emix_base::from_setting
199    \endcode
200
201    */
202    void from_setting ( const Setting &set );
203
204    void to_setting  (Setting  &set) const;
205
206    void set_rv ( const RV &rv ) {
207        epdf::set_rv ( rv );
208        for ( int i = 0; i < no_coms(); i++ ) {
209            Coms( i )->set_rv ( rv );
210        }
211    }
212
213    Array<shared_ptr<epdf> >& _Coms ( ) {
214        return Coms;
215    }
216};
217SHAREDPTR ( emix );
218UIREGISTER ( emix );
219
220
221/*! \brief Chain rule decomposition of epdf
222
223Probability density in the form of Chain-rule decomposition:
224\[
225f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
226\]
227Note that
228*/
229class mprod: public pdf {
230private:
231    Array<shared_ptr<pdf> > pdfs;
232
233    //! Data link for each pdfs
234    Array<shared_ptr<datalink_m2m> > dls;
235
236public:
237    //! \brief Default constructor
238    mprod() { }
239
240    /*!\brief Constructor from list of mFacs
241    */
242    mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
243        set_elements ( mFacs );
244    }
245    //! Set internal \c pdfs from given values
246    void set_elements ( const Array<shared_ptr<pdf> > &mFacs );
247
248    double evallogcond ( const vec &val, const vec &cond );
249
250    vec evallogcond_mat ( const mat &Dt, const vec &cond );
251
252    vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond );
253
254    //TODO smarter...
255    vec samplecond ( const vec &cond ) {
256        //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
257        vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
258        vec smpi;
259        // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
260        for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
261            // generate contribution of this pdf
262            smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) );
263            // copy contribution of this pdf into smp
264            dls ( i )->pushup ( smp, smpi );
265        }
266        return smp;
267    }
268
269    //! Create object from the following structure
270    //! \code
271    //! class='mprod';
272    //! pdfs = { list of bdm::pdf };          % list of bdm::pdf offsprings stored in the order of chain rule, bdm::pdf::from_setting
273    //! --- inherited fields ---
274    //! bdm::pdf::from_setting
275    //! \endcode
276    //!@}
277    void from_setting ( const Setting &set ) ;
278    void to_setting  (Setting  &set) const;   
279};
280UIREGISTER ( mprod );
281SHAREDPTR ( mprod );
282
283
284//! \brief Base class (interface) for bdm::eprod
285class eprod_base: public epdf {
286protected:
287    //! Array of indices
288    Array<datalink*> dls;
289    //! interface for a factor
290public:
291    virtual const epdf* factor(int i) const NOT_IMPLEMENTED(NULL);
292    //!number of factors
293    virtual const int no_factors() const NOT_IMPLEMENTED(0);
294    //! Default constructor
295    eprod_base () :  dls (0) {};
296    //! Set internal
297    vec mean() const;
298
299    vec variance() const;
300
301    vec sample() const;
302
303    double evallog ( const vec &val ) const;
304
305    //!Destructor
306    ~eprod_base() {
307        for ( int i = 0; i < dls.length(); i++ ) {
308            delete dls ( i );
309        }
310    }
311    void validate() {
312        epdf::validate();
313        dls.set_length ( no_factors() );
314
315        bool independent = true;
316        dim = 0;
317                rv = RV();
318        for ( int i = 0; i < no_factors(); i++ ) {
319            independent = rv.add ( factor ( i )->_rv() );
320            dim += factor ( i )->dimension();
321            bdm_assert_debug ( independent, "eprod:: given components are not independent." );
322        };
323
324        //
325        int cumdim = 0;
326        int dimi = 0;
327        int i;
328        for ( i = 0; i < no_factors(); i++ ) {
329                        if (!dls(i)){
330                                dls ( i ) = new datalink;
331                        }
332            if ( isnamed() ) { // rvs are complete
333                dls ( i )->set_connection ( factor ( i )->_rv() , rv );
334            } else { //rvs are not reliable
335                dimi = factor ( i )->dimension();
336                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
337                cumdim += dimi;
338            }
339        }
340
341    }
342};
343
344//! \brief Product of independent epdfs. For dependent pdfs, use bdm::mprod.
345class eprod: public eprod_base {
346protected:
347    Array<shared_ptr<epdf> > factors;
348public:
349    const epdf* factor(int i) const {
350        return factors(i).get();
351    }
352    const int no_factors() const {
353        return factors.length();
354    }
355    void set_parameters ( const Array<shared_ptr<epdf> > &epdfs0) {
356        factors = epdfs0;
357    }
358
359    /*! Create object from the following structure
360
361    \code
362    class = 'eprod';
363    pdfs = { list of any bdm::epdf offsprings };   % pdfs in the product, bdm::epdf::from_setting
364    --- inherited fields ---
365    bdm::eprod_base::from_setting
366    \endcode
367
368    */
369    void from_setting(const Setting &set) {
370        UI::get(factors,set,"pdfs",UI::compulsory);
371    }
372};
373UIREGISTER(eprod);
374
375//! \brief Internal class similar to eprod - factors are external pointers. To be used internally! 
376class eprod_internal: public eprod_base {
377protected:
378    Array<epdf* > factors;
379    const epdf* factor(int i) const {
380        return factors(i);
381    }
382    const int no_factors() const {
383        return factors.length();
384    }
385public:
386    void set_parameters ( const Array<epdf *> &epdfs0) {
387        factors = epdfs0;
388    }
389};
390
391/*! \brief Mixture of pdfs with constant weights, all pdfs are of equal RV and RVC
392
393*/
394class mmix : public pdf {
395protected:
396    //! Component (pdfs)
397    Array<shared_ptr<pdf> > Coms;
398    //!weights of the components
399    vec w;
400public:
401    //!Default constructor
402    mmix() : Coms ( 0 ) { };
403
404    double evallogcond ( const vec &dt, const vec &cond ) {
405        double ll = 0.0;
406        for ( int i = 0; i < Coms.length(); i++ ) {
407            ll += Coms ( i )->evallogcond ( dt, cond );
408        }
409        return ll;
410    }
411
412    vec samplecond ( const vec &cond );
413
414    //! }
415    //! \endcode
416    //!@}
417   
418    /*! Create object from the following structure
419    \code
420    class = 'mmix';
421    pdfs = { list of components bdm::pdf };          % list of pdf offsprings, bdm::pdf::from_setting
422    --- optional fields ---
423    weights = [...];                                 % weights of pdfs in the mixture
424    --- inherited fields ---
425    bdm::pdf::from_setting
426    \endcode
427    \endcode
428    If the optional fields are not given, they will be filled as follows:
429    \code
430    weights = 1/n * [1,1,1,...];
431    \endcode
432    */
433    void from_setting ( const Setting &set );
434    void     to_setting  (Setting  &set) const;
435    virtual void validate();
436};
437SHAREDPTR ( mmix );
438UIREGISTER ( mmix );
439
440
441//! \brief Base class for all BM running as parallel update of internal BMs
442class ProdBMBase : public BM {
443protected :
444    Array<vec_from_vec> bm_yt;
445    Array<vec_from_2vec> bm_cond;
446
447    //! \brief Internal class
448    class eprod_bm : public eprod_base {
449        ProdBMBase & pb;
450    public :
451        eprod_bm(ProdBMBase &pb0): pb(pb0) {}
452        const epdf* factor(int i ) const {
453            return &(pb.bm(i)->posterior());
454        }
455        const int no_factors() const {
456            return pb.no_bms();
457        }
458    } est;
459public:
460    ProdBMBase():est(*this) {}
461    virtual BM* bm(int i) NOT_IMPLEMENTED(NULL);
462    virtual int no_bms() const {
463        return 0;
464    }
465    const epdf& posterior() const {
466        return est;
467    }
468    void set_prior(const epdf *pri) {
469        const eprod_base* ep=dynamic_cast<const eprod_base*>(pri);
470        if (ep) {
471            bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+
472                       num2str(no_bms()) + "BMs");
473            for (int i=0; i<no_bms(); i++) {
474                bm(i)->set_prior(ep->factor(i));
475            }
476        }
477    }
478
479    void validate() {
480        est.validate();
481        BM::validate();
482        // set links
483        bm_yt.set_length(no_bms());
484        bm_cond.set_length(no_bms());
485
486        //
487
488        for (int i=0; i<no_bms(); i++) {
489            yrv.add(bm(i)->_yrv());
490            rvc.add(bm(i)->_rvc());
491        }
492        rvc=rvc.subt(yrv);
493
494        dimy = yrv._dsize();
495        dimc = rvc._dsize();
496
497        for (int i=0; i<no_bms(); i++) {
498            bm_yt(i).connect(bm(i)->_yrv(), yrv);
499            bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc);
500        }
501    }
502    void bayes(const vec &dt, const vec &cond) {
503        ll=0;
504        for(int i=0; i<no_bms(); i++) {
505            bm_yt(i).update(dt);
506            bm_cond(i).update(dt,cond);
507            bm(i)->bayes(bm_yt(i), bm_cond(i));
508        }
509    }
510
511};
512
513class ProdBM: public ProdBMBase {
514protected:
515    Array<shared_ptr<BM> > BMs;
516public:
517    virtual BM* bm(int i) {
518        return BMs(i).get();
519    }
520    virtual int no_bms() const {
521        return BMs.length();
522    }
523    void from_setting(const Setting &set) {
524        BM::from_setting(set);
525        UI::get(BMs,set,"BMs");
526    }
527    void to_setting(Setting &set) const {
528        BM::to_setting(set);
529        UI::save(BMs,set,"BMs");
530    }
531
532};
533UIREGISTER(ProdBM);
534
535//! \brief class for on-line model selection
536class ModelComparator: public ProdBM {
537protected:
538    multiBM weights;
539public:
540    void bayes(const vec &yt, const vec &cond) {
541        vec w_nn(no_bms());
542        for (int i=0; i<no_bms(); i++) {
543            bm(i)->bayes(yt,cond);
544            w_nn(i) += bm(i)->_ll();
545        }
546        vec w=exp(w_nn-max(w_nn));
547        weights.bayes(w/sum(w));
548    }
549    void validate() {
550        ProdBM::validate();
551        weights.validate();
552    }
553
554   /*! Create object from the following structure
555
556    \code
557    class = 'ModelComparator';
558    --- optional fields ---
559    frg = [...];                  % vector of weights
560    --- inherited fields ---
561    bdm::ProdBM::from_setting
562    \endcode
563    */
564    void from_setting(const Setting& set) {
565        ProdBM::from_setting(set);
566        UI::get(weights.frg, set, "frg",UI::optional);
567    }
568
569    void to_setting(Setting& set) const {
570        ProdBM::to_setting(set);
571        UI::save(weights.frg, set, "frg");
572    }
573};
574
575}
576#endif //MX_H
Note: See TracBrowser for help on using the browser.