root/library/bdm/stat/emix.h @ 1066

Revision 1066, 13.8 kB (checked in by mido, 14 years ago)

another part - all conditional pdfs untill bdm::euni

  • Property svn:eol-style set to native
RevLine 
[107]1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[394]13#ifndef EMIX_H
14#define EMIX_H
[107]15
[477]16#define LOG2  0.69314718055995
[333]17
[461]18#include "../shared_ptr.h"
[384]19#include "exp_family.h"
[107]20
[286]21namespace bdm {
[107]22
[182]23//this comes first because it is used inside emix!
24
25/*! \brief Class representing ratio of two densities
26which arise e.g. by applying the Bayes rule.
27It represents density in the form:
28\f[
29f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
30\f]
31where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.
32
33In particular this type of arise by conditioning of a mixture model.
34
[211]35At present the only supported operation is evallogcond().
[182]36 */
[693]37class mratio: public pdf {
[192]38protected:
[1064]39    //! Nominator in the form of pdf
40    const epdf* nom;
[504]41
[1064]42    //!Denominator in the form of epdf
43    shared_ptr<epdf> den;
[504]44
[1064]45    //!flag for destructor
46    bool destroynom;
47    //!datalink between conditional and nom
48    datalink_m2e dl;
[192]49public:
[1064]50    //!Default constructor. By default, the given epdf is not copied!
51    //! It is assumed that this function will be used only temporarily.
52    mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
53        // adjust rv and rvc
[675]54
[1064]55        set_rv ( rv );
56        dim = rv._dsize();
[675]57
[1064]58        rvc = nom0->_rv().subt ( rv );
59        dimc = rvc._dsize();
[737]60
[1064]61        //prepare data structures
62        if ( copy ) {
63            bdm_error ( "todo" );
64            // destroynom = true;
65        } else {
66            nom = nom0;
67            destroynom = false;
68        }
69        bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
[477]70
[1064]71        // build denominator
72        den = nom->marginal ( rvc );
73        dl.set_connection ( rv, rvc, nom0->_rv() );
74    };
[675]75
[1064]76    double evallogcond ( const vec &val, const vec &cond ) {
77        double tmp;
78        vec nom_val ( dimension() + dimc );
79        dl.pushup_cond ( nom_val, val, cond );
80        tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
81        return tmp;
82    }
[766]83
[1064]84    //! Returns a sample from the density conditioned on \c cond, \f$x \sim epdf(rv|cond)\f$. \param cond is numeric value of \c rv
85    virtual vec samplecond ( const vec &cond ) NOT_IMPLEMENTED(0);
[766]86
[1064]87    //! Object takes ownership of nom and will destroy it
88    void ownnom() {
89        destroynom = true;
90    }
91    //! Default destructor
92    ~mratio() {
93        if ( destroynom ) {
94            delete nom;
95        }
96    }
[550]97
[766]98
[550]99private:
[1064]100    // not implemented
101    mratio ( const mratio & );
102    mratio &operator= ( const mratio & );
[182]103};
104
[886]105class emix; //forward
106
[1066]107//! \brief Base class (interface) for mixtures
[886]108class emix_base : public epdf {
[1064]109protected:
110    //! reference to vector of weights
111    vec &w;
112    //! function returning ith component
113    virtual const epdf * component(const int &i) const=0;
114
115    virtual int no_coms() const=0;
116
117public:
118
119    emix_base(vec &w0): w(w0) {}
120
121    void validate ();
122
123    vec sample() const;
124
125    vec mean() const;
126
127    vec variance() const;
128
129    double evallog ( const vec &val ) const;
130
131    vec evallog_mat ( const mat &Val ) const;
132
133    //! Auxiliary function that returns pdflog for each component
134    mat evallog_coms ( const mat &Val ) const;
135
136    shared_ptr<epdf> marginal ( const RV &rv ) const;
137    //! Update already existing marginal density  \c target
138    void marginal ( const RV &rv, emix &target ) const;
139    shared_ptr<pdf> condition ( const RV &rv ) const;
140
141    //Access methods
142    //! returns a reference to the internal weights. Use with Care!
143    vec& _w() {
144        return w;
145    }
146
147    const vec& _w() const {
148        return w;
149    }
150    //!access
151    const epdf* _com(int i) const {
152        return component(i);
153    }
154
[886]155};
156
[107]157/*!
158* \brief Mixture of epdfs
159
160Density function:
161\f[
162f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
163\f]
164where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
165
166*/
[886]167class emix : public emix_base {
[162]168protected:
[1064]169    //! weights of the components
170    vec weights;
[559]171
[1064]172    //! Component (epdfs)
173    Array<shared_ptr<epdf> > Coms;
[504]174
[162]175public:
[1064]176    //! Default constructor
177    emix ( ) : emix_base ( weights) { }
[559]178
[1064]179    const epdf* component(const int &i) const {
180        return Coms(i).get();
181    }
182    void validate();
[107]183
184
[1064]185    int no_coms() const {
186        return Coms.length();
187    }
188
[1066]189    /*! Create object from the following structure
190
191    \code
192    class = 'emix';
193
194    pdfs = { list of any bdm::pdf offsprings };   % pdfs in the mixture, bdm::pdf::from_setting
195    weights = [... ];                             % vector of weights of pdfs in the mixture
196    --- inherited fields ---
197    bdm::emix_base::from_setting
198    \endcode
199
200    */
[1064]201    void from_setting ( const Setting &set );
[1066]202
[1064]203    void to_setting  (Setting  &set) const;
204
205    void set_rv ( const RV &rv ) {
206        epdf::set_rv ( rv );
207        for ( int i = 0; i < no_coms(); i++ ) {
208            Coms( i )->set_rv ( rv );
209        }
210    }
211
212    Array<shared_ptr<epdf> >& _Coms ( ) {
213        return Coms;
214    }
[333]215};
[886]216SHAREDPTR ( emix );
217UIREGISTER ( emix );
[333]218
[886]219
[115]220/*! \brief Chain rule decomposition of epdf
221
[145]222Probability density in the form of Chain-rule decomposition:
223\[
224f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
225\]
226Note that
[115]227*/
[693]228class mprod: public pdf {
[507]229private:
[1064]230    Array<shared_ptr<pdf> > pdfs;
[507]231
[1064]232    //! Data link for each pdfs
233    Array<shared_ptr<datalink_m2m> > dls;
[461]234
[162]235public:
[1064]236    //! \brief Default constructor
237    mprod() { }
[507]238
[1064]239    /*!\brief Constructor from list of mFacs
240    */
241    mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
242        set_elements ( mFacs );
243    }
244    //! Set internal \c pdfs from given values
245    void set_elements ( const Array<shared_ptr<pdf> > &mFacs );
[477]246
[1064]247    double evallogcond ( const vec &val, const vec &cond );
[395]248
[1064]249    vec evallogcond_mat ( const mat &Dt, const vec &cond );
[395]250
[1064]251    vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond );
[739]252
[1064]253    //TODO smarter...
254    vec samplecond ( const vec &cond ) {
255        //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
256        vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
257        vec smpi;
258        // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
259        for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
260            // generate contribution of this pdf
261            smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) );
262            // copy contribution of this pdf into smp
263            dls ( i )->pushup ( smp, smpi );
264        }
265        return smp;
266    }
[162]267
[1064]268    //! Load from structure with elements:
269    //!  \code
270    //! { class='mprod';
271    //!   pdfs = (..., ...);     // list of pdfs in the order of chain rule
272    //! }
273    //! \endcode
274    //!@}
275    void from_setting ( const Setting &set ) ;
[1066]276    void     to_setting  (Setting  &set) const;
[956]277
278
[107]279};
[477]280UIREGISTER ( mprod );
[529]281SHAREDPTR ( mprod );
[107]282
[886]283
[1066]284//! \brief Base class (interface) for bdm::eprod
[886]285class eprod_base: public epdf {
[168]286protected:
[1064]287    //! Array of indices
288    Array<datalink*> dls;
289    //! interface for a factor
[979]290public:
[1064]291    virtual const epdf* factor(int i) const NOT_IMPLEMENTED(NULL);
292    //!number of factors
293    virtual const int no_factors() const NOT_IMPLEMENTED(0);
294    //! Default constructor
295    eprod_base () :  dls (0) {};
296    //! Set internal
297    vec mean() const;
[886]298
[1064]299    vec variance() const;
[886]300
[1064]301    vec sample() const;
[886]302
[1064]303    double evallog ( const vec &val ) const;
[886]304
[1064]305    //!Destructor
306    ~eprod_base() {
307        for ( int i = 0; i < dls.length(); i++ ) {
308            delete dls ( i );
309        }
310    }
311    void validate() {
312        epdf::validate();
313        dls.set_length ( no_factors() );
314
315        bool independent = true;
316        dim = 0;
317        for ( int i = 0; i < no_factors(); i++ ) {
318            independent = rv.add ( factor ( i )->_rv() );
319            dim += factor ( i )->dimension();
320            bdm_assert_debug ( independent, "eprod:: given components are not independent." );
321        };
322
323        //
324        int cumdim = 0;
325        int dimi = 0;
326        int i;
327        for ( i = 0; i < no_factors(); i++ ) {
328            dls ( i ) = new datalink;
329            if ( isnamed() ) { // rvs are complete
330                dls ( i )->set_connection ( factor ( i )->_rv() , rv );
331            } else { //rvs are not reliable
332                dimi = factor ( i )->dimension();
333                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
334                cumdim += dimi;
335            }
336        }
337
338    }
[886]339};
[168]340
[1066]341//! \brief Product of independent epdfs. For dependent pdfs, use bdm::mprod.
[1064]342class eprod: public eprod_base {
343protected:
344    Array<shared_ptr<epdf> > factors;
345public:
346    const epdf* factor(int i) const {
347        return factors(i).get();
348    }
349    const int no_factors() const {
350        return factors.length();
351    }
352    void set_parameters ( const Array<shared_ptr<epdf> > &epdfs0) {
353        factors = epdfs0;
354    }
[1066]355
356    /*! Create object from the following structure
357
358    \code
359    class = 'eprod';
360    pdfs = { list of any bdm::epdf offsprings };   % pdfs in the product, bdm::epdf::from_setting
361    --- inherited fields ---
362    bdm::eprod_base::from_setting
363    \endcode
364
365    */
[1064]366    void from_setting(const Setting &set) {
367        UI::get(factors,set,"pdfs",UI::compulsory);
368    }
[886]369};
[944]370UIREGISTER(eprod);
[204]371
[1066]372//! \brief Internal class similar to eprod - factors are external pointers. To be used only internally! 
[1064]373class eprod_internal: public eprod_base {
374protected:
375    Array<epdf* > factors;
376    const epdf* factor(int i) const {
377        return factors(i);
378    }
379    const int no_factors() const {
380        return factors.length();
381    }
382public:
383    void set_parameters ( const Array<epdf *> &epdfs0) {
384        factors = epdfs0;
385    }
[168]386};
387
[693]388/*! \brief Mixture of pdfs with constant weights, all pdfs are of equal RV and RVC
[124]389
390*/
[693]391class mmix : public pdf {
[488]392protected:
[1064]393    //! Component (pdfs)
394    Array<shared_ptr<pdf> > Coms;
395    //!weights of the components
396    vec w;
[488]397public:
[1064]398    //!Default constructor
399    mmix() : Coms ( 0 ) { };
[461]400
[1064]401    double evallogcond ( const vec &dt, const vec &cond ) {
402        double ll = 0.0;
403        for ( int i = 0; i < Coms.length(); i++ ) {
404            ll += Coms ( i )->evallogcond ( dt, cond );
405        }
406        return ll;
407    }
[488]408
[1064]409    vec samplecond ( const vec &cond );
[488]410
[1064]411    //! Load from structure with elements:
412    //!  \code
413    //! { class='mmix';
414    //!   pdfs = (..., ...);     // list of pdfs in the mixture
415    //!   weights = ( 0.5, 0.5 ); // weights of pdfs in the mixture
416    //! }
417    //! \endcode
418    //!@}
419    void from_setting ( const Setting &set );
[1066]420    void     to_setting  (Setting  &set) const;
[1064]421    virtual void validate();
[488]422};
[737]423SHAREDPTR ( mmix );
[711]424UIREGISTER ( mmix );
[488]425
[979]426
[1066]427//! \brief Base class for all BM running as parallel update of internal BMs
[979]428class ProdBMBase : public BM {
[1064]429protected :
430    Array<vec_from_vec> bm_yt;
431    Array<vec_from_2vec> bm_cond;
432    class eprod_bm : public eprod_base {
433        ProdBMBase & pb;
434    public :
435        eprod_bm(ProdBMBase &pb0): pb(pb0) {}
436        const epdf* factor(int i ) const {
437            return &(pb.bm(i)->posterior());
438        }
439        const int no_factors() const {
440            return pb.no_bms();
441        }
442    } est;
443public:
444    ProdBMBase():est(*this) {}
445    virtual BM* bm(int i) NOT_IMPLEMENTED(NULL);
446    virtual int no_bms() const {
447        return 0;
448    }
449    const epdf& posterior() const {
450        return est;
451    }
452    void set_prior(const epdf *pri) {
453        const eprod_base* ep=dynamic_cast<const eprod_base*>(pri);
454        if (ep) {
455            bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+
456                       num2str(no_bms()) + "BMs");
457            for (int i=0; i<no_bms(); i++) {
458                bm(i)->set_prior(ep->factor(i));
459            }
460        }
461    }
462
463    void validate() {
464        est.validate();
465        BM::validate();
466        // set links
467        bm_yt.set_length(no_bms());
468        bm_cond.set_length(no_bms());
469
470        //
471
472        for (int i=0; i<no_bms(); i++) {
473            yrv.add(bm(i)->_yrv());
474            rvc.add(bm(i)->_rvc());
475        }
476        rvc=rvc.subt(yrv);
477
478        dimy = yrv._dsize();
479        dimc = rvc._dsize();
480
481        for (int i=0; i<no_bms(); i++) {
482            bm_yt(i).connect(bm(i)->_yrv(), yrv);
483            bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc);
484        }
485    }
486    void bayes(const vec &dt, const vec &cond) {
487        ll=0;
488        for(int i=0; i<no_bms(); i++) {
489            bm_yt(i).update(dt);
490            bm_cond(i).update(dt,cond);
491            bm(i)->bayes(bm_yt(i), bm_cond(i));
492        }
493    }
494
[979]495};
496
[1064]497class ProdBM: public ProdBMBase {
498protected:
499    Array<shared_ptr<BM> > BMs;
500public:
501    virtual BM* bm(int i) {
502        return BMs(i).get();
503    }
504    virtual int no_bms() const {
505        return BMs.length();
506    }
507    void from_setting(const Setting &set) {
508        BM::from_setting(set);
509        UI::get(BMs,set,"BMs");
510    }
511    void to_setting(Setting &set) const {
512        BM::to_setting(set);
513        UI::save(BMs,set,"BMs");
514    }
515
[979]516};
517UIREGISTER(ProdBM);
518
[1064]519//! \brief class for on-line model selection
520class ModelComparator: public ProdBM {
521protected:
522    multiBM weights;
523public:
524    void bayes(const vec &yt, const vec &cond) {
525        vec w_nn(no_bms());
526        for (int i=0; i<no_bms(); i++) {
527            bm(i)->bayes(yt,cond);
528            w_nn(i) += bm(i)->_ll();
529        }
530        vec w=exp(w_nn-max(w_nn));
531        weights.bayes(w/sum(w));
532    }
533    void validate() {
534        ProdBM::validate();
535        weights.validate();
536    }
537    void from_setting(const Setting& set) {
538        ProdBM::from_setting(set);
539        UI::get(weights.frg, set, "frg",UI::optional);
540    }
541    void to_setting(Setting& set) const {
542        ProdBM::to_setting(set);
543        UI::save(weights.frg, set, "frg");
544    }
[1020]545};
546
[254]547}
[107]548#endif //MX_H
Note: See TracBrowser for help on using the browser.