mixpp: emix.h Source File

emix.h

Go to the documentation of this file.
00001
00013 #ifndef EMIX_H
00014 #define EMIX_H
00015 
00016 #define LOG2  0.69314718055995
00017 
00018 #include "../shared_ptr.h"
00019 #include "exp_family.h"
00020
00021 namespace bdm {
00022
00023 //this comes first because it is used inside emix!
00024
00038 class mratio: public pdf {
00039 protected:
00041     const epdf* nom;
00042
00044     shared_ptr<epdf> den;
00045
00047     bool destroynom;
00049     datalink_m2e dl;
00050 public:
00053     mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
00054         // adjust rv and rvc
00055
00056         set_rv ( rv );
00057         dim = rv._dsize();
00058
00059         rvc = nom0->_rv().subt ( rv );
00060         dimc = rvc._dsize();
00061
00062         //prepare data structures
00063         if ( copy ) {
00064             bdm_error ( "todo" );
00065             // destroynom = true;
00066         } else {
00067             nom = nom0;
00068             destroynom = false;
00069         }
00070         bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
00071
00072         // build denominator
00073         den = nom->marginal ( rvc );
00074         dl.set_connection ( rv, rvc, nom0->_rv() );
00075     };
00076
00077     double evallogcond ( const vec &val, const vec &cond ) {
00078         double tmp;
00079         vec nom_val ( dimension() + dimc );
00080         dl.pushup_cond ( nom_val, val, cond );
00081         tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
00082         return tmp;
00083     }
00084
00086     virtual vec samplecond ( const vec &cond ) NOT_IMPLEMENTED(0);
00087
00089     void ownnom() {
00090         destroynom = true;
00091     }
00093     ~mratio() {
00094         if ( destroynom ) {
00095             delete nom;
00096         }
00097     }
00098
00099
00100 private:
00101     // not implemented
00102     mratio ( const mratio & );
00103     mratio &operator= ( const mratio & );
00104 };
00105
00106 class emix; //forward
00107
00109 class emix_base : public epdf {
00110 protected:
00112     vec &w;
00114     virtual const epdf * component(const int &i) const=0;
00115
00116     virtual int no_coms() const=0;
00117
00118 public:
00119
00120     emix_base(vec &w0): w(w0) {}
00121
00122     void validate ();
00123
00124     vec sample() const;
00125
00126     vec mean() const;
00127
00128     vec variance() const;
00129
00130     double evallog ( const vec &val ) const;
00131
00132     vec evallog_mat ( const mat &Val ) const;
00133
00135     mat evallog_coms ( const mat &Val ) const;
00136
00137     shared_ptr<epdf> marginal ( const RV &rv ) const;
00139     void marginal ( const RV &rv, emix &target ) const;
00140     shared_ptr<pdf> condition ( const RV &rv ) const;
00141
00142     //Access methods
00144     vec& _w() {
00145         return w;
00146     }
00147
00148     const vec& _w() const {
00149         return w;
00150     }
00152     const epdf* _com(int i) const {
00153         return component(i);
00154     }
00155
00156 };
00157
00168 class emix : public emix_base {
00169 protected:
00171     vec weights;
00172
00174     Array<shared_ptr<epdf> > Coms;
00175
00176 public:
00178     emix ( ) : emix_base ( weights) { }
00179
00180     const epdf* component(const int &i) const {
00181         return Coms(i).get();
00182     }
00183     void validate();
00184
00185
00186     int no_coms() const {
00187         return Coms.length();
00188     }
00189
00202     void from_setting ( const Setting &set );
00203
00204     void to_setting  (Setting  &set) const;
00205
00206     void set_rv ( const RV &rv ) {
00207         epdf::set_rv ( rv );
00208         for ( int i = 0; i < no_coms(); i++ ) {
00209             Coms( i )->set_rv ( rv );
00210         }
00211     }
00212
00213     Array<shared_ptr<epdf> >& _Coms ( ) {
00214         return Coms;
00215     }
00216 };
00217 SHAREDPTR ( emix );
00218 UIREGISTER ( emix );
00219
00220
00229 class mprod: public pdf {
00230 private:
00231     Array<shared_ptr<pdf> > pdfs;
00232
00234     Array<shared_ptr<datalink_m2m> > dls;
00235
00236 public:
00238     mprod() { }
00239
00242     mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
00243         set_elements ( mFacs );
00244     }
00246     void set_elements ( const Array<shared_ptr<pdf> > &mFacs );
00247
00248     double evallogcond ( const vec &val, const vec &cond );
00249
00250     vec evallogcond_mat ( const mat &Dt, const vec &cond );
00251
00252     vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond );
00253
00254     //TODO smarter...
00255     vec samplecond ( const vec &cond ) {
00257         vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
00258         vec smpi;
00259         // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
00260         for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
00261             // generate contribution of this pdf
00262             smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) );
00263             // copy contribution of this pdf into smp
00264             dls ( i )->pushup ( smp, smpi );
00265         }
00266         return smp;
00267     }
00268
00277     void from_setting ( const Setting &set ) ;
00278     void to_setting  (Setting  &set) const;
00279 };
00280 UIREGISTER ( mprod );
00281 SHAREDPTR ( mprod );
00282
00283
00285 class eprod_base: public epdf {
00286 protected:
00288     Array<datalink*> dls;
00290 public:
00291     virtual const epdf* factor(int i) const NOT_IMPLEMENTED(NULL);
00293     virtual const int no_factors() const NOT_IMPLEMENTED(0);
00295         eprod_base () :  dls (0) {};
00296         eprod_base (const eprod_base &ep0) :  dls (ep0.dls) {};
00298     vec mean() const;
00299
00300     vec variance() const;
00301
00302     vec sample() const;
00303
00304     double evallog ( const vec &val ) const;
00305
00307     ~eprod_base() {
00308         for ( int i = 0; i < dls.length(); i++ ) {
00309             delete dls ( i );
00310         }
00311     }
00312     void validate() {
00313         epdf::validate();
00314         dls.set_length ( no_factors() );
00315
00316         bool independent = true;
00317         dim = 0;
00318                 rv = RV();
00319         for ( int i = 0; i < no_factors(); i++ ) {
00320             independent = rv.add ( factor ( i )->_rv() );
00321             dim += factor ( i )->dimension();
00322             bdm_assert_debug ( independent, "eprod:: given components are not independent." );
00323         };
00324
00325         //
00326         int cumdim = 0;
00327         int dimi = 0;
00328         int i;
00329         for ( i = 0; i < no_factors(); i++ ) {
00330                         if (!dls(i)){
00331                                 dls ( i ) = new datalink;
00332                         }
00333             if ( isnamed() ) { // rvs are complete
00334                 dls ( i )->set_connection ( factor ( i )->_rv() , rv );
00335             } else { //rvs are not reliable
00336                 dimi = factor ( i )->dimension();
00337                 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
00338                 cumdim += dimi;
00339             }
00340         }
00341
00342     }
00343 };
00344
00346 class eprod: public eprod_base {
00347 protected:
00348     Array<shared_ptr<epdf> > factors;
00349 public:
00350     const epdf* factor(int i) const {
00351         return factors(i).get();
00352     }
00353     const int no_factors() const {
00354         return factors.length();
00355     }
00356     void set_parameters ( const Array<shared_ptr<epdf> > &epdfs0) {
00357         factors = epdfs0;
00358     }
00359
00370     void from_setting(const Setting &set) {
00371         UI::get(factors,set,"pdfs",UI::compulsory);
00372     }
00373 };
00374 UIREGISTER(eprod);
00375
00377 class eprod_internal: public eprod_base {
00378 protected:
00379     Array<epdf* > factors;
00380     const epdf* factor(int i) const {
00381         return factors(i);
00382     }
00383     const int no_factors() const {
00384         return factors.length();
00385     }
00386 public:
00387     void set_parameters ( const Array<epdf *> &epdfs0) {
00388         factors = epdfs0;
00389     }
00390 };
00391
00395 class mmix : public pdf {
00396 protected:
00398     Array<shared_ptr<pdf> > Coms;
00400     vec w;
00401 public:
00403     mmix() : Coms ( 0 ) { };
00404
00405     double evallogcond ( const vec &dt, const vec &cond ) {
00406         double ll = 0.0;
00407         for ( int i = 0; i < Coms.length(); i++ ) {
00408             ll += Coms ( i )->evallogcond ( dt, cond );
00409         }
00410         return ll;
00411     }
00412
00413     vec samplecond ( const vec &cond );
00414
00418
00434     void from_setting ( const Setting &set );
00435     void     to_setting  (Setting  &set) const;
00436     virtual void validate();
00437 };
00438 SHAREDPTR ( mmix );
00439 UIREGISTER ( mmix );
00440
00441
00443 class ProdBMBase : public BM {
00444 protected :
00445     Array<vec_from_vec> bm_yt;
00446     Array<vec_from_2vec> bm_cond;
00447
00449     class eprod_bm : public eprod_base {
00450         ProdBMBase & pb;
00451     public :
00452                 eprod_bm(ProdBMBase &pb0): pb(pb0) {}
00453                 const epdf* factor(int i ) const {
00454             return &(pb.bm(i)->posterior());
00455         }
00456         const int no_factors() const {
00457             return pb.no_bms();
00458         }
00459     } est;
00460 public:
00461         ProdBMBase():BM(),est(*this) {}
00462         ProdBMBase(const ProdBMBase &p0):BM(p0),bm_yt(p0.bm_yt),bm_cond(p0.bm_cond),est(*this) {
00463                 est.validate();
00464         }
00465         virtual BM* bm(int i) NOT_IMPLEMENTED(NULL);
00466     virtual int no_bms() const {
00467         return 0;
00468     }
00469     const epdf& posterior() const {
00470         return est;
00471     }
00472     void set_prior(const epdf *pri) {
00473         const eprod_base* ep=dynamic_cast<const eprod_base*>(pri);
00474         if (ep) {
00475             bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+
00476                        num2str(no_bms()) + "BMs");
00477             for (int i=0; i<no_bms(); i++) {
00478                 bm(i)->set_prior(ep->factor(i));
00479             }
00480         }
00481     }
00482
00483     void validate() {
00484         est.validate();
00485         BM::validate();
00486         // set links
00487         bm_yt.set_length(no_bms());
00488         bm_cond.set_length(no_bms());
00489
00490         //
00491
00492         for (int i=0; i<no_bms(); i++) {
00493             yrv.add(bm(i)->_yrv());
00494             rvc.add(bm(i)->_rvc());
00495         }
00496         rvc=rvc.subt(yrv);
00497
00498         dimy = yrv._dsize();
00499         dimc = rvc._dsize();
00500
00501         for (int i=0; i<no_bms(); i++) {
00502                         bm_yt(i).set_length(bm(i)->dimensiony());
00503             bm_yt(i).connect(bm(i)->_yrv(), yrv);
00504                         bm_cond(i).set_length(bm(i)->dimensionc());
00505                         bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc);
00506         }
00507     }
00508     void bayes(const vec &dt, const vec &cond) {
00509         ll=0;
00510         for(int i=0; i<no_bms(); i++) {
00511             bm_yt(i).update(dt);
00512             bm_cond(i).update(dt,cond);
00513             bm(i)->bayes(bm_yt(i), bm_cond(i));
00514                         ll+=bm(i)->_ll();
00515         }
00516     }
00517     vec samplepred( const vec &cond) {
00518                 vec samp=zeros(dimy);
00519
00520                 for(int i=0; i<no_bms(); i++) {
00521                         bm_cond(i).update(samp,cond);
00522                         vec yi=bm(i)->samplepred(bm_cond(i));
00523                         bm_yt(i)._dl().pushup(samp,yi);
00524                 }
00525                 return samp;
00526         }
00527
00528 };
00529
00530 class ProdBM: public ProdBMBase {
00531 protected:
00532     Array<shared_ptr<BM> > BMs;
00533 public:
00534         ProdBM():ProdBMBase(),BMs(){};
00535         ProdBM(const ProdBM &p0):ProdBMBase(p0),BMs(p0.BMs){est.validate();};
00536     ProdBM* _copy() const {return new ProdBM(*this);}
00537     virtual BM* bm(int i) {
00538         return BMs(i).get();
00539     }
00540     virtual int no_bms() const {
00541         return BMs.length();
00542     }
00543     void from_setting(const Setting &set) {
00544         BM::from_setting(set);
00545         UI::get(BMs,set,"BMs");
00546     }
00547     void to_setting(Setting &set) const {
00548         BM::to_setting(set);
00549         UI::save(BMs,set,"BMs");
00550     }
00551 };
00552 UIREGISTER(ProdBM);
00553
00555 class ModelComparator: public ProdBM {
00556 protected:
00557     multiBM weights;
00558 public:
00559     void bayes(const vec &yt, const vec &cond) {
00560         vec w_nn(no_bms());
00561         for (int i=0; i<no_bms(); i++) {
00562             bm(i)->bayes(yt,cond);
00563             w_nn(i) += bm(i)->_ll();
00564         }
00565         vec w=exp(w_nn-max(w_nn));
00566         weights.bayes(w/sum(w));
00567     }
00568     void validate() {
00569         ProdBM::validate();
00570         weights.validate();
00571     }
00572
00583     void from_setting(const Setting& set) {
00584         ProdBM::from_setting(set);
00585         UI::get(weights.frg, set, "frg",UI::optional);
00586     }
00587
00588     void to_setting(Setting& set) const {
00589         ProdBM::to_setting(set);
00590         UI::save(weights.frg, set, "frg");
00591     }
00592 };
00593
00594 }
00595 #endif //EMIX_H

Generated on 2 Dec 2013 for mixpp by  doxygen 1.4.7