root/library/bdm/stat/emix.h @ 1064

Revision 1064, 13.3 kB (checked in by mido, 14 years ago)

astyle applied all over the library

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef EMIX_H
14#define EMIX_H
15
16#define LOG2  0.69314718055995
17
18#include "../shared_ptr.h"
19#include "exp_family.h"
20
21namespace bdm {
22
23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
35At present the only supported operation is evallogcond().
36 */
37class mratio: public pdf {
38protected:
39    //! Nominator in the form of pdf
40    const epdf* nom;
41
42    //!Denominator in the form of epdf
43    shared_ptr<epdf> den;
44
45    //!flag for destructor
46    bool destroynom;
47    //!datalink between conditional and nom
48    datalink_m2e dl;
49public:
50    //!Default constructor. By default, the given epdf is not copied!
51    //! It is assumed that this function will be used only temporarily.
52    mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
53        // adjust rv and rvc
54
55        set_rv ( rv );
56        dim = rv._dsize();
57
58        rvc = nom0->_rv().subt ( rv );
59        dimc = rvc._dsize();
60
61        //prepare data structures
62        if ( copy ) {
63            bdm_error ( "todo" );
64            // destroynom = true;
65        } else {
66            nom = nom0;
67            destroynom = false;
68        }
69        bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
70
71        // build denominator
72        den = nom->marginal ( rvc );
73        dl.set_connection ( rv, rvc, nom0->_rv() );
74    };
75
76    double evallogcond ( const vec &val, const vec &cond ) {
77        double tmp;
78        vec nom_val ( dimension() + dimc );
79        dl.pushup_cond ( nom_val, val, cond );
80        tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81        return tmp;
82    }
83
84    //! Returns a sample from the density conditioned on \c cond, \f$x \sim epdf(rv|cond)\f$. \param cond is numeric value of \c rv
85    virtual vec samplecond ( const vec &cond ) NOT_IMPLEMENTED(0);
86
87    //! Object takes ownership of nom and will destroy it
88    void ownnom() {
89        destroynom = true;
90    }
91    //! Default destructor
92    ~mratio() {
93        if ( destroynom ) {
94            delete nom;
95        }
96    }
97
98
99private:
100    // not implemented
101    mratio ( const mratio & );
102    mratio &operator= ( const mratio & );
103};
104
105class emix; //forward
106
107/*! Base class (Interface) for mixtures
108*/
109class emix_base : public epdf {
110protected:
111    //! reference to vector of weights
112    vec &w;
113    //! function returning ith component
114    virtual const epdf * component(const int &i) const=0;
115
116    virtual int no_coms() const=0;
117
118public:
119
120    emix_base(vec &w0): w(w0) {}
121
122    void validate ();
123
124    vec sample() const;
125
126    vec mean() const;
127
128    vec variance() const;
129
130    double evallog ( const vec &val ) const;
131
132    vec evallog_mat ( const mat &Val ) const;
133
134    //! Auxiliary function that returns pdflog for each component
135    mat evallog_coms ( const mat &Val ) const;
136
137    shared_ptr<epdf> marginal ( const RV &rv ) const;
138    //! Update already existing marginal density  \c target
139    void marginal ( const RV &rv, emix &target ) const;
140    shared_ptr<pdf> condition ( const RV &rv ) const;
141
142    //Access methods
143    //! returns a reference to the internal weights. Use with Care!
144    vec& _w() {
145        return w;
146    }
147
148    const vec& _w() const {
149        return w;
150    }
151    //!access
152    const epdf* _com(int i) const {
153        return component(i);
154    }
155
156};
157
158/*!
159* \brief Mixture of epdfs
160
161Density function:
162\f[
163f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
164\f]
165where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
166
167*/
168class emix : public emix_base {
169protected:
170    //! weights of the components
171    vec weights;
172
173    //! Component (epdfs)
174    Array<shared_ptr<epdf> > Coms;
175
176public:
177    //! Default constructor
178    emix ( ) : emix_base ( weights) { }
179
180    const epdf* component(const int &i) const {
181        return Coms(i).get();
182    }
183    void validate();
184
185
186    int no_coms() const {
187        return Coms.length();
188    }
189
190    //! Load from structure with elements:
191    //!  \code
192    //! { class='emix';
193    //!   pdfs = (..., ...);     // list of pdfs in the mixture
194    //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
195    //! }
196    //! \endcode
197    //!@}
198    void from_setting ( const Setting &set );
199    void to_setting  (Setting  &set) const;
200
201    void set_rv ( const RV &rv ) {
202        epdf::set_rv ( rv );
203        for ( int i = 0; i < no_coms(); i++ ) {
204            Coms( i )->set_rv ( rv );
205        }
206    }
207
208    Array<shared_ptr<epdf> >& _Coms ( ) {
209        return Coms;
210    }
211};
212SHAREDPTR ( emix );
213UIREGISTER ( emix );
214
215
216/*! \brief Chain rule decomposition of epdf
217
218Probability density in the form of Chain-rule decomposition:
219\[
220f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
221\]
222Note that
223*/
224class mprod: public pdf {
225private:
226    Array<shared_ptr<pdf> > pdfs;
227
228    //! Data link for each pdfs
229    Array<shared_ptr<datalink_m2m> > dls;
230
231public:
232    //! \brief Default constructor
233    mprod() { }
234
235    /*!\brief Constructor from list of mFacs
236    */
237    mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
238        set_elements ( mFacs );
239    }
240    //! Set internal \c pdfs from given values
241    void set_elements ( const Array<shared_ptr<pdf> > &mFacs );
242
243    double evallogcond ( const vec &val, const vec &cond );
244
245    vec evallogcond_mat ( const mat &Dt, const vec &cond );
246
247    vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond );
248
249    //TODO smarter...
250    vec samplecond ( const vec &cond ) {
251        //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
252        vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
253        vec smpi;
254        // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
255        for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
256            // generate contribution of this pdf
257            smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) );
258            // copy contribution of this pdf into smp
259            dls ( i )->pushup ( smp, smpi );
260        }
261        return smp;
262    }
263
264    //! Load from structure with elements:
265    //!  \code
266    //! { class='mprod';
267    //!   pdfs = (..., ...);     // list of pdfs in the order of chain rule
268    //! }
269    //! \endcode
270    //!@}
271    void from_setting ( const Setting &set ) ;
272    void        to_setting  (Setting  &set) const;
273
274
275};
276UIREGISTER ( mprod );
277SHAREDPTR ( mprod );
278
279
280//! Product of independent epdfs. For dependent pdfs, use mprod.
281class eprod_base: public epdf {
282protected:
283    //! Array of indices
284    Array<datalink*> dls;
285    //! interface for a factor
286public:
287    virtual const epdf* factor(int i) const NOT_IMPLEMENTED(NULL);
288    //!number of factors
289    virtual const int no_factors() const NOT_IMPLEMENTED(0);
290    //! Default constructor
291    eprod_base () :  dls (0) {};
292    //! Set internal
293    vec mean() const;
294
295    vec variance() const;
296
297    vec sample() const;
298
299    double evallog ( const vec &val ) const;
300
301    //!Destructor
302    ~eprod_base() {
303        for ( int i = 0; i < dls.length(); i++ ) {
304            delete dls ( i );
305        }
306    }
307    void validate() {
308        epdf::validate();
309        dls.set_length ( no_factors() );
310
311        bool independent = true;
312        dim = 0;
313        for ( int i = 0; i < no_factors(); i++ ) {
314            independent = rv.add ( factor ( i )->_rv() );
315            dim += factor ( i )->dimension();
316            bdm_assert_debug ( independent, "eprod:: given components are not independent." );
317        };
318
319        //
320        int cumdim = 0;
321        int dimi = 0;
322        int i;
323        for ( i = 0; i < no_factors(); i++ ) {
324            dls ( i ) = new datalink;
325            if ( isnamed() ) { // rvs are complete
326                dls ( i )->set_connection ( factor ( i )->_rv() , rv );
327            } else { //rvs are not reliable
328                dimi = factor ( i )->dimension();
329                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
330                cumdim += dimi;
331            }
332        }
333
334    }
335};
336
337class eprod: public eprod_base {
338protected:
339    Array<shared_ptr<epdf> > factors;
340public:
341    const epdf* factor(int i) const {
342        return factors(i).get();
343    }
344    const int no_factors() const {
345        return factors.length();
346    }
347    void set_parameters ( const Array<shared_ptr<epdf> > &epdfs0) {
348        factors = epdfs0;
349    }
350    void from_setting(const Setting &set) {
351        UI::get(factors,set,"pdfs",UI::compulsory);
352    }
353};
354UIREGISTER(eprod);
355
356//! similar to eprod but used only internally -- factors are external pointers
357class eprod_internal: public eprod_base {
358protected:
359    Array<epdf* > factors;
360    const epdf* factor(int i) const {
361        return factors(i);
362    }
363    const int no_factors() const {
364        return factors.length();
365    }
366public:
367    void set_parameters ( const Array<epdf *> &epdfs0) {
368        factors = epdfs0;
369    }
370};
371
372/*! \brief Mixture of pdfs with constant weights, all pdfs are of equal RV and RVC
373
374*/
375class mmix : public pdf {
376protected:
377    //! Component (pdfs)
378    Array<shared_ptr<pdf> > Coms;
379    //!weights of the components
380    vec w;
381public:
382    //!Default constructor
383    mmix() : Coms ( 0 ) { };
384
385    double evallogcond ( const vec &dt, const vec &cond ) {
386        double ll = 0.0;
387        for ( int i = 0; i < Coms.length(); i++ ) {
388            ll += Coms ( i )->evallogcond ( dt, cond );
389        }
390        return ll;
391    }
392
393    vec samplecond ( const vec &cond );
394
395    //! Load from structure with elements:
396    //!  \code
397    //! { class='mmix';
398    //!   pdfs = (..., ...);     // list of pdfs in the mixture
399    //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
400    //! }
401    //! \endcode
402    //!@}
403    void from_setting ( const Setting &set );
404    void        to_setting  (Setting  &set) const;
405    virtual void validate();
406};
407SHAREDPTR ( mmix );
408UIREGISTER ( mmix );
409
410
411//! Base class for all BM running as parallel update of internal BMs
412
413class ProdBMBase : public BM {
414protected :
415    Array<vec_from_vec> bm_yt;
416    Array<vec_from_2vec> bm_cond;
417    class eprod_bm : public eprod_base {
418        ProdBMBase & pb;
419    public :
420        eprod_bm(ProdBMBase &pb0): pb(pb0) {}
421        const epdf* factor(int i ) const {
422            return &(pb.bm(i)->posterior());
423        }
424        const int no_factors() const {
425            return pb.no_bms();
426        }
427    } est;
428public:
429    ProdBMBase():est(*this) {}
430    virtual BM* bm(int i) NOT_IMPLEMENTED(NULL);
431    virtual int no_bms() const {
432        return 0;
433    }
434    const epdf& posterior() const {
435        return est;
436    }
437    void set_prior(const epdf *pri) {
438        const eprod_base* ep=dynamic_cast<const eprod_base*>(pri);
439        if (ep) {
440            bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+
441                       num2str(no_bms()) + "BMs");
442            for (int i=0; i<no_bms(); i++) {
443                bm(i)->set_prior(ep->factor(i));
444            }
445        }
446    }
447
448    void validate() {
449        est.validate();
450        BM::validate();
451        // set links
452        bm_yt.set_length(no_bms());
453        bm_cond.set_length(no_bms());
454
455        //
456
457        for (int i=0; i<no_bms(); i++) {
458            yrv.add(bm(i)->_yrv());
459            rvc.add(bm(i)->_rvc());
460        }
461        rvc=rvc.subt(yrv);
462
463        dimy = yrv._dsize();
464        dimc = rvc._dsize();
465
466        for (int i=0; i<no_bms(); i++) {
467            bm_yt(i).connect(bm(i)->_yrv(), yrv);
468            bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc);
469        }
470    }
471    void bayes(const vec &dt, const vec &cond) {
472        ll=0;
473        for(int i=0; i<no_bms(); i++) {
474            bm_yt(i).update(dt);
475            bm_cond(i).update(dt,cond);
476            bm(i)->bayes(bm_yt(i), bm_cond(i));
477        }
478    }
479
480};
481
482class ProdBM: public ProdBMBase {
483protected:
484    Array<shared_ptr<BM> > BMs;
485public:
486    virtual BM* bm(int i) {
487        return BMs(i).get();
488    }
489    virtual int no_bms() const {
490        return BMs.length();
491    }
492    void from_setting(const Setting &set) {
493        BM::from_setting(set);
494        UI::get(BMs,set,"BMs");
495    }
496    void to_setting(Setting &set) const {
497        BM::to_setting(set);
498        UI::save(BMs,set,"BMs");
499    }
500
501};
502UIREGISTER(ProdBM);
503
504//! \brief class for on-line model selection
505class ModelComparator: public ProdBM {
506protected:
507    multiBM weights;
508public:
509    void bayes(const vec &yt, const vec &cond) {
510        vec w_nn(no_bms());
511        for (int i=0; i<no_bms(); i++) {
512            bm(i)->bayes(yt,cond);
513            w_nn(i) += bm(i)->_ll();
514        }
515        vec w=exp(w_nn-max(w_nn));
516        weights.bayes(w/sum(w));
517    }
518    void validate() {
519        ProdBM::validate();
520        weights.validate();
521    }
522    void from_setting(const Setting& set) {
523        ProdBM::from_setting(set);
524        UI::get(weights.frg, set, "frg",UI::optional);
525    }
526    void to_setting(Setting& set) const {
527        ProdBM::to_setting(set);
528        UI::save(weights.frg, set, "frg");
529    }
530};
531
532}
533#endif //MX_H
Note: See TracBrowser for help on using the browser.