00001
00013 #ifndef EMIX_H
00014 #define EMIX_H
00015
//! Natural logarithm of two, ln(2) ~= 0.693147 (despite the name, not a base-2 logarithm).
//! NOTE(review): not referenced in the visible part of this header.
#define LOG2 0.69314718055995
00017
00018 #include "../shared_ptr.h"
00019 #include "exp_family.h"
00020
00021 namespace bdm {
00022
00023
00024
00037 class mratio: public mpdf {
00038 protected:
00040 const epdf* nom;
00042 epdf* den;
00044 bool destroynom;
00046 datalink_m2e dl;
00047 public:
00050 mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
00051
00052 rvc = nom0->_rv().subt ( rv );
00053 dimc = rvc._dsize();
00054 set_ep(shared_ptr<epdf>(new epdf));
00055 e()->set_parameters(rv._dsize());
00056 e()->set_rv(rv);
00057
00058
00059 if ( copy ) {it_error ( "todo" ); destroynom=true; }
00060 else { nom = nom0; destroynom = false; }
00061 it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );
00062
00063
00064 den = nom->marginal ( rvc );
00065 dl.set_connection(rv,rvc,nom0->_rv());
00066 };
00067 double evallogcond ( const vec &val, const vec &cond ) {
00068 double tmp;
00069 vec nom_val ( e()->dimension() + dimc );
00070 dl.pushup_cond ( nom_val,val,cond );
00071 tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
00072 it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
00073 return tmp;
00074 }
00076 void ownnom() {destroynom=true;}
00078 ~mratio() {delete den; if ( destroynom ) {delete nom;}}
00079 };
00080
00091 class emix : public epdf {
00092 protected:
00094 vec w;
00096 Array<epdf*> Coms;
00098 bool destroyComs;
00099 public:
00101 emix ( ) : epdf ( ) {};
00104 void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
00105
00106 vec sample() const;
00107 vec mean() const {
00108 int i; vec mu = zeros ( dim );
00109 for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
00110 return mu;
00111 }
00112 vec variance() const {
00113
00114 vec mom2 = zeros ( dim );
00115 for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
00116
00117 return mom2-pow ( mean(),2 );
00118 }
00119 double evallog ( const vec &val ) const {
00120 int i;
00121 double sum = 0.0;
00122 for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
00123 if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
00124 double tmp=log ( sum );
00125 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00126 return tmp;
00127 };
00128 vec evallog_m ( const mat &Val ) const {
00129 vec x=zeros ( Val.cols() );
00130 for ( int i = 0; i < w.length(); i++ ) {
00131 x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
00132 }
00133 return log ( x );
00134 };
00136 mat evallog_M ( const mat &Val ) const {
00137 mat X ( w.length(), Val.cols() );
00138 for ( int i = 0; i < w.length(); i++ ) {
00139 X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
00140 }
00141 return X;
00142 };
00143
00144 emix* marginal ( const RV &rv ) const;
00145 mratio* condition ( const RV &rv ) const;
00146
00147
00149 vec& _w() {return w;}
00150 virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00152 void ownComs() {destroyComs=true;}
00153
00155 epdf* _Coms ( int i ) {return Coms ( i );}
00156 void set_rv(const RV &rv){
00157 epdf::set_rv(rv);
00158 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00159 }
00160 };
00161
00162
00167 class egiwmix : public egiw {
00168 protected:
00170 vec w;
00172 Array<egiw*> Coms;
00174 bool destroyComs;
00175 public:
00177 egiwmix ( ) : egiw ( ) {};
00178
00181 void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
00182
00184 vec mean() const;
00185
00187 vec sample() const;
00188
00190 vec variance() const;
00191
00192
00193 void mean_mat ( mat &M, mat&R ) const {};
00194 double evallog_nn ( const vec &val ) const {return 0;};
00195 double lognc () const {return 0;};
00196 emix* marginal ( const RV &rv ) const;
00197
00198
00200 vec& _w() {return w;}
00201 virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00203 void ownComs() {destroyComs=true;}
00204
00206 egiw* _Coms ( int i ) {return Coms ( i );}
00207
00208 void set_rv(const RV &rv){
00209 egiw::set_rv(rv);
00210 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00211 }
00212
00214 egiw* approx();
00215 };
00216
00225 class mprod: public compositepdf, public mpdf {
00226 protected:
00228 Array<epdf*> epdfs;
00230 Array<datalink_m2m*> dls;
00231
00233 shared_ptr<epdf> dummy;
00234
00235 public:
00238 mprod():dummy(new epdf()) { }
00239 mprod (Array<mpdf*> mFacs):
00240 dummy(new epdf()) {
00241 set_elements(mFacs);
00242 }
00243
00244 void set_elements(Array<mpdf*> mFacs , bool own=false) {
00245
00246 compositepdf::set_elements(mFacs,own);
00247 dls.set_size(mFacs.length());
00248 epdfs.set_size(mFacs.length());
00249
00250 set_ep(dummy);
00251 RV rv=getrv ( true );
00252 set_rv ( rv );
00253 dummy->set_parameters(rv._dsize());
00254 setrvc(e()->_rv(), rvc);
00255
00256 for ( int i = 0;i < mpdfs.length(); i++ ) {
00257 dls ( i ) = new datalink_m2m;
00258 dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
00259 }
00260
00261 for ( int i=0; i<mpdfs.length(); i++ ) {
00262 epdfs(i) = mpdfs(i)->e();
00263 }
00264 };
00265
00266 double evallogcond ( const vec &val, const vec &cond ) {
00267 int i;
00268 double res = 0.0;
00269 for ( i = mpdfs.length() - 1;i >= 0;i-- ) {
00270
00271
00272
00273
00274
00275 res += mpdfs ( i )->evallogcond (
00276 dls ( i )->pushdown ( val ),
00277 dls ( i )->get_cond ( val, cond )
00278 );
00279 }
00280 return res;
00281 }
00282 vec evallogcond_m(const mat &Dt, const vec &cond) {
00283 vec tmp(Dt.cols());
00284 for(int i=0;i<Dt.cols(); i++){
00285 tmp(i) = evallogcond(Dt.get_col(i),cond);
00286 }
00287 return tmp;
00288 };
00289 vec evallogcond_m(const Array<vec> &Dt, const vec &cond) {
00290 vec tmp(Dt.length());
00291 for(int i=0;i<Dt.length(); i++){
00292 tmp(i) = evallogcond(Dt(i),cond);
00293 }
00294 return tmp;
00295 };
00296
00297
00298
00299 vec samplecond ( const vec &cond ) {
00301 vec smp= std::numeric_limits<double>::infinity() * ones ( e()->dimension() );
00302 vec smpi;
00303
00304 for ( int i = ( mpdfs.length() - 1 );i >= 0;i-- ) {
00305 if ( mpdfs ( i )->dimensionc() ) {
00306 mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) );
00307 }
00308 smpi = epdfs ( i )->sample();
00309
00310 dls ( i )->pushup ( smp, smpi );
00311
00312 }
00313 return smp;
00314 }
00315 mat samplecond ( const vec &cond, int N ) {
00316 mat Smp ( dimension(),N );
00317 for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
00318 return Smp;
00319 }
00320
00321 ~mprod() {};
00329 void from_setting(const Setting &set){
00330 Array<mpdf*> Atmp;
00331 UI::get(Atmp,set, "mpdfs", UI::compulsory);
00332 set_elements(Atmp,true);
00333 }
00334
00335 };
00336 UIREGISTER(mprod);
00337
00339 class eprod: public epdf {
00340 protected:
00342 Array<const epdf*> epdfs;
00344 Array<datalink*> dls;
00345 public:
00346 eprod () : epdfs ( 0 ),dls ( 0 ) {};
00347 void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
00348 epdfs=epdfs0;
00349 dls.set_length ( epdfs.length() );
00350
00351 bool independent=true;
00352 if ( named ) {
00353 for ( int i=0;i<epdfs.length();i++ ) {
00354 independent=rv.add ( epdfs ( i )->_rv() );
00355 it_assert_debug ( independent==true, "eprod:: given components are not independent." );
00356 }
00357 dim=rv._dsize();
00358 }
00359 else {
00360 dim =0; for ( int i=0;i<epdfs.length();i++ ) {
00361 dim+=epdfs ( i )->dimension();
00362 }
00363 }
00364
00365 int cumdim=0;
00366 int dimi=0;
00367 int i;
00368 for ( i=0;i<epdfs.length();i++ ) {
00369 dls ( i ) = new datalink;
00370 if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
00371 else {
00372 dimi = epdfs ( i )->dimension();
00373 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
00374 cumdim+=dimi;
00375 }
00376 }
00377 }
00378
00379 vec mean() const {
00380 vec tmp ( dim );
00381 for ( int i=0;i<epdfs.length();i++ ) {
00382 vec pom = epdfs ( i )->mean();
00383 dls ( i )->pushup ( tmp, pom );
00384 }
00385 return tmp;
00386 }
00387 vec variance() const {
00388 vec tmp ( dim );
00389 for ( int i=0;i<epdfs.length();i++ ) {
00390 vec pom = epdfs ( i )->mean();
00391 dls ( i )->pushup ( tmp, pow ( pom,2 ) );
00392 }
00393 return tmp-pow ( mean(),2 );
00394 }
00395 vec sample() const {
00396 vec tmp ( dim );
00397 for ( int i=0;i<epdfs.length();i++ ) {
00398 vec pom = epdfs ( i )->sample();
00399 dls ( i )->pushup ( tmp, pom );
00400 }
00401 return tmp;
00402 }
00403 double evallog ( const vec &val ) const {
00404 double tmp=0;
00405 for ( int i=0;i<epdfs.length();i++ ) {
00406 tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
00407 }
00408 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00409 return tmp;
00410 }
00412 const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
00413
00415 ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
00416 };
00417
00418
00422 class mmix : public mpdf {
00423 protected:
00425 Array<mpdf*> Coms;
00426
00428 shared_ptr<emix> iepdf;
00429
00430 public:
00432 mmix():iepdf(new emix()) {
00433 set_ep(iepdf);
00434 }
00435
00437 void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
00438 Array<epdf*> Eps ( Coms.length() );
00439
00440 for ( int i = 0;i < Coms.length();i++ ) {
00441 Eps(i) = Coms(i)->e();
00442 }
00443
00444 iepdf->set_parameters(w, Eps);
00445 }
00446
00447 void condition ( const vec &cond ) {
00448 for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
00449 };
00450 };
00451
00452 }
#endif // EMIX_H