emix.h
Go to the documentation of this file.00001 00013 #ifndef EMIX_H 00014 #define EMIX_H 00015 00016 #define LOG2 0.69314718055995 00017 00018 #include "../shared_ptr.h" 00019 #include "exp_family.h" 00020 00021 namespace bdm { 00022 00023 //this comes first because it is used inside emix! 00024 00038 class mratio: public pdf { 00039 protected: 00041 const epdf* nom; 00042 00044 shared_ptr<epdf> den; 00045 00047 bool destroynom; 00049 datalink_m2e dl; 00050 public: 00053 mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) { 00054 // adjust rv and rvc 00055 00056 set_rv ( rv ); 00057 dim = rv._dsize(); 00058 00059 rvc = nom0->_rv().subt ( rv ); 00060 dimc = rvc._dsize(); 00061 00062 //prepare data structures 00063 if ( copy ) { 00064 bdm_error ( "todo" ); 00065 // destroynom = true; 00066 } else { 00067 nom = nom0; 00068 destroynom = false; 00069 } 00070 bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" ); 00071 00072 // build denominator 00073 den = nom->marginal ( rvc ); 00074 dl.set_connection ( rv, rvc, nom0->_rv() ); 00075 }; 00076 00077 double evallogcond ( const vec &val, const vec &cond ) { 00078 double tmp; 00079 vec nom_val ( dimension() + dimc ); 00080 dl.pushup_cond ( nom_val, val, cond ); 00081 tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) ); 00082 return tmp; 00083 } 00084 00086 virtual vec samplecond ( const vec &cond ) NOT_IMPLEMENTED(0); 00087 00089 void ownnom() { 00090 destroynom = true; 00091 } 00093 ~mratio() { 00094 if ( destroynom ) { 00095 delete nom; 00096 } 00097 } 00098 00099 00100 private: 00101 // not implemented 00102 mratio ( const mratio & ); 00103 mratio &operator= ( const mratio & ); 00104 }; 00105 00106 class emix; //forward 00107 00109 class emix_base : public epdf { 00110 protected: 00112 vec &w; 00114 virtual const epdf * component(const int &i) const=0; 00115 00116 virtual int no_coms() const=0; 00117 00118 public: 00119 00120 emix_base(vec &w0): w(w0) {} 00121 00122 
void validate (); 00123 00124 vec sample() const; 00125 00126 vec mean() const; 00127 00128 vec variance() const; 00129 00130 double evallog ( const vec &val ) const; 00131 00132 vec evallog_mat ( const mat &Val ) const; 00133 00135 mat evallog_coms ( const mat &Val ) const; 00136 00137 shared_ptr<epdf> marginal ( const RV &rv ) const; 00139 void marginal ( const RV &rv, emix &target ) const; 00140 shared_ptr<pdf> condition ( const RV &rv ) const; 00141 00142 //Access methods 00144 vec& _w() { 00145 return w; 00146 } 00147 00148 const vec& _w() const { 00149 return w; 00150 } 00152 const epdf* _com(int i) const { 00153 return component(i); 00154 } 00155 00156 }; 00157 00168 class emix : public emix_base { 00169 protected: 00171 vec weights; 00172 00174 Array<shared_ptr<epdf> > Coms; 00175 00176 public: 00178 emix ( ) : emix_base ( weights) { } 00179 00180 const epdf* component(const int &i) const { 00181 return Coms(i).get(); 00182 } 00183 void validate(); 00184 00185 00186 int no_coms() const { 00187 return Coms.length(); 00188 } 00189 00202 void from_setting ( const Setting &set ); 00203 00204 void to_setting (Setting &set) const; 00205 00206 void set_rv ( const RV &rv ) { 00207 epdf::set_rv ( rv ); 00208 for ( int i = 0; i < no_coms(); i++ ) { 00209 Coms( i )->set_rv ( rv ); 00210 } 00211 } 00212 00213 Array<shared_ptr<epdf> >& _Coms ( ) { 00214 return Coms; 00215 } 00216 }; 00217 SHAREDPTR ( emix ); 00218 UIREGISTER ( emix ); 00219 00220 00229 class mprod: public pdf { 00230 private: 00231 Array<shared_ptr<pdf> > pdfs; 00232 00234 Array<shared_ptr<datalink_m2m> > dls; 00235 00236 public: 00238 mprod() { } 00239 00242 mprod ( const Array<shared_ptr<pdf> > &mFacs ) { 00243 set_elements ( mFacs ); 00244 } 00246 void set_elements ( const Array<shared_ptr<pdf> > &mFacs ); 00247 00248 double evallogcond ( const vec &val, const vec &cond ); 00249 00250 vec evallogcond_mat ( const mat &Dt, const vec &cond ); 00251 00252 vec evallogcond_mat ( const Array<vec> &Dt, const 
vec &cond ); 00253 00254 //TODO smarter... 00255 vec samplecond ( const vec &cond ) { 00257 vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() ); 00258 vec smpi; 00259 // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated! 00260 for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) { 00261 // generate contribution of this pdf 00262 smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) ); 00263 // copy contribution of this pdf into smp 00264 dls ( i )->pushup ( smp, smpi ); 00265 } 00266 return smp; 00267 } 00268 00277 void from_setting ( const Setting &set ) ; 00278 void to_setting (Setting &set) const; 00279 }; 00280 UIREGISTER ( mprod ); 00281 SHAREDPTR ( mprod ); 00282 00283 00285 class eprod_base: public epdf { 00286 protected: 00288 Array<datalink*> dls; 00290 public: 00291 virtual const epdf* factor(int i) const NOT_IMPLEMENTED(NULL); 00293 virtual const int no_factors() const NOT_IMPLEMENTED(0); 00295 eprod_base () : dls (0) {}; 00296 eprod_base (const eprod_base &ep0) : dls (ep0.dls) {}; 00298 vec mean() const; 00299 00300 vec variance() const; 00301 00302 vec sample() const; 00303 00304 double evallog ( const vec &val ) const; 00305 00307 ~eprod_base() { 00308 for ( int i = 0; i < dls.length(); i++ ) { 00309 delete dls ( i ); 00310 } 00311 } 00312 void validate() { 00313 epdf::validate(); 00314 dls.set_length ( no_factors() ); 00315 00316 bool independent = true; 00317 dim = 0; 00318 rv = RV(); 00319 for ( int i = 0; i < no_factors(); i++ ) { 00320 independent = rv.add ( factor ( i )->_rv() ); 00321 dim += factor ( i )->dimension(); 00322 bdm_assert_debug ( independent, "eprod:: given components are not independent." 
); 00323 }; 00324 00325 // 00326 int cumdim = 0; 00327 int dimi = 0; 00328 int i; 00329 for ( i = 0; i < no_factors(); i++ ) { 00330 if (!dls(i)){ 00331 dls ( i ) = new datalink; 00332 } 00333 if ( isnamed() ) { // rvs are complete 00334 dls ( i )->set_connection ( factor ( i )->_rv() , rv ); 00335 } else { //rvs are not reliable 00336 dimi = factor ( i )->dimension(); 00337 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) ); 00338 cumdim += dimi; 00339 } 00340 } 00341 00342 } 00343 }; 00344 00346 class eprod: public eprod_base { 00347 protected: 00348 Array<shared_ptr<epdf> > factors; 00349 public: 00350 const epdf* factor(int i) const { 00351 return factors(i).get(); 00352 } 00353 const int no_factors() const { 00354 return factors.length(); 00355 } 00356 void set_parameters ( const Array<shared_ptr<epdf> > &epdfs0) { 00357 factors = epdfs0; 00358 } 00359 00370 void from_setting(const Setting &set) { 00371 UI::get(factors,set,"pdfs",UI::compulsory); 00372 } 00373 }; 00374 UIREGISTER(eprod); 00375 00377 class eprod_internal: public eprod_base { 00378 protected: 00379 Array<epdf* > factors; 00380 const epdf* factor(int i) const { 00381 return factors(i); 00382 } 00383 const int no_factors() const { 00384 return factors.length(); 00385 } 00386 public: 00387 void set_parameters ( const Array<epdf *> &epdfs0) { 00388 factors = epdfs0; 00389 } 00390 }; 00391 00395 class mmix : public pdf { 00396 protected: 00398 Array<shared_ptr<pdf> > Coms; 00400 vec w; 00401 public: 00403 mmix() : Coms ( 0 ) { }; 00404 00405 double evallogcond ( const vec &dt, const vec &cond ) { 00406 double ll = 0.0; 00407 for ( int i = 0; i < Coms.length(); i++ ) { 00408 ll += Coms ( i )->evallogcond ( dt, cond ); 00409 } 00410 return ll; 00411 } 00412 00413 vec samplecond ( const vec &cond ); 00414 00418 00434 void from_setting ( const Setting &set ); 00435 void to_setting (Setting &set) const; 00436 virtual void validate(); 00437 }; 00438 SHAREDPTR ( mmix ); 00439 
UIREGISTER ( mmix ); 00440 00441 00443 class ProdBMBase : public BM { 00444 protected : 00445 Array<vec_from_vec> bm_yt; 00446 Array<vec_from_2vec> bm_cond; 00447 00449 class eprod_bm : public eprod_base { 00450 ProdBMBase & pb; 00451 public : 00452 eprod_bm(ProdBMBase &pb0): pb(pb0) {} 00453 const epdf* factor(int i ) const { 00454 return &(pb.bm(i)->posterior()); 00455 } 00456 const int no_factors() const { 00457 return pb.no_bms(); 00458 } 00459 } est; 00460 public: 00461 ProdBMBase():BM(),est(*this) {} 00462 ProdBMBase(const ProdBMBase &p0):BM(p0),bm_yt(p0.bm_yt),bm_cond(p0.bm_cond),est(*this) { 00463 est.validate(); 00464 } 00465 virtual BM* bm(int i) NOT_IMPLEMENTED(NULL); 00466 virtual int no_bms() const { 00467 return 0; 00468 } 00469 const epdf& posterior() const { 00470 return est; 00471 } 00472 void set_prior(const epdf *pri) { 00473 const eprod_base* ep=dynamic_cast<const eprod_base*>(pri); 00474 if (ep) { 00475 bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+ 00476 num2str(no_bms()) + "BMs"); 00477 for (int i=0; i<no_bms(); i++) { 00478 bm(i)->set_prior(ep->factor(i)); 00479 } 00480 } 00481 } 00482 00483 void validate() { 00484 est.validate(); 00485 BM::validate(); 00486 // set links 00487 bm_yt.set_length(no_bms()); 00488 bm_cond.set_length(no_bms()); 00489 00490 // 00491 00492 for (int i=0; i<no_bms(); i++) { 00493 yrv.add(bm(i)->_yrv()); 00494 rvc.add(bm(i)->_rvc()); 00495 } 00496 rvc=rvc.subt(yrv); 00497 00498 dimy = yrv._dsize(); 00499 dimc = rvc._dsize(); 00500 00501 for (int i=0; i<no_bms(); i++) { 00502 bm_yt(i).set_length(bm(i)->dimensiony()); 00503 bm_yt(i).connect(bm(i)->_yrv(), yrv); 00504 bm_cond(i).set_length(bm(i)->dimensionc()); 00505 bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc); 00506 } 00507 } 00508 void bayes(const vec &dt, const vec &cond) { 00509 ll=0; 00510 for(int i=0; i<no_bms(); i++) { 00511 bm_yt(i).update(dt); 00512 bm_cond(i).update(dt,cond); 00513 
bm(i)->bayes(bm_yt(i), bm_cond(i)); 00514 ll+=bm(i)->_ll(); 00515 } 00516 } 00517 vec samplepred( const vec &cond) { 00518 vec samp=zeros(dimy); 00519 00520 for(int i=0; i<no_bms(); i++) { 00521 bm_cond(i).update(samp,cond); 00522 vec yi=bm(i)->samplepred(bm_cond(i)); 00523 bm_yt(i)._dl().pushup(samp,yi); 00524 } 00525 return samp; 00526 } 00527 00528 }; 00529 00530 class ProdBM: public ProdBMBase { 00531 protected: 00532 Array<shared_ptr<BM> > BMs; 00533 public: 00534 ProdBM():ProdBMBase(),BMs(){}; 00535 ProdBM(const ProdBM &p0):ProdBMBase(p0),BMs(p0.BMs){est.validate();}; 00536 ProdBM* _copy() const {return new ProdBM(*this);} 00537 virtual BM* bm(int i) { 00538 return BMs(i).get(); 00539 } 00540 virtual int no_bms() const { 00541 return BMs.length(); 00542 } 00543 void from_setting(const Setting &set) { 00544 BM::from_setting(set); 00545 UI::get(BMs,set,"BMs"); 00546 } 00547 void to_setting(Setting &set) const { 00548 BM::to_setting(set); 00549 UI::save(BMs,set,"BMs"); 00550 } 00551 }; 00552 UIREGISTER(ProdBM); 00553 00555 class ModelComparator: public ProdBM { 00556 protected: 00557 multiBM weights; 00558 public: 00559 void bayes(const vec &yt, const vec &cond) { 00560 vec w_nn(no_bms()); 00561 for (int i=0; i<no_bms(); i++) { 00562 bm(i)->bayes(yt,cond); 00563 w_nn(i) += bm(i)->_ll(); 00564 } 00565 vec w=exp(w_nn-max(w_nn)); 00566 weights.bayes(w/sum(w)); 00567 } 00568 void validate() { 00569 ProdBM::validate(); 00570 weights.validate(); 00571 } 00572 00583 void from_setting(const Setting& set) { 00584 ProdBM::from_setting(set); 00585 UI::get(weights.frg, set, "frg",UI::optional); 00586 } 00587 00588 void to_setting(Setting& set) const { 00589 ProdBM::to_setting(set); 00590 UI::save(weights.frg, set, "frg"); 00591 } 00592 }; 00593 00594 } 00595 #endif //MX_H
Generated on 2 Dec 2013 for mixpp by Doxygen
