00001
00013 #ifndef EMIX_H
00014 #define EMIX_H
00015
//! Natural logarithm of two, ln(2) = 0.693147... (despite the name, this is not a base-2 logarithm)
#define LOG2 0.69314718055995
00017
00018 #include "exp_family.h"
00019
00020 namespace bdm {
00021
00022
00023
00036 class mratio: public mpdf {
00037 protected:
00039 const epdf* nom;
00041 epdf* den;
00043 bool destroynom;
00045 datalink_m2e dl;
00046 public:
00049 mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
00050
00051 rvc = nom0->_rv().subt ( rv );
00052 dimc = rvc._dsize();
00053 ep = new epdf;
00054 ep->set_parameters(rv._dsize());
00055 ep->set_rv(rv);
00056
00057
00058 if ( copy ) {it_error ( "todo" ); destroynom=true; }
00059 else { nom = nom0; destroynom = false; }
00060 it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );
00061
00062
00063 den = nom->marginal ( rvc );
00064 dl.set_connection(rv,rvc,nom0->_rv());
00065 };
00066 double evallogcond ( const vec &val, const vec &cond ) {
00067 double tmp;
00068 vec nom_val ( ep->dimension() + dimc );
00069 dl.pushup_cond ( nom_val,val,cond );
00070 tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
00071 it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
00072 return tmp;
00073 }
00075 void ownnom() {destroynom=true;}
00077 ~mratio() {delete den; if ( destroynom ) {delete nom;}}
00078 };
00079
00090 class emix : public epdf {
00091 protected:
00093 vec w;
00095 Array<epdf*> Coms;
00097 bool destroyComs;
00098 public:
00100 emix ( ) : epdf ( ) {};
00103 void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
00104
00105 vec sample() const;
00106 vec mean() const {
00107 int i; vec mu = zeros ( dim );
00108 for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
00109 return mu;
00110 }
00111 vec variance() const {
00112
00113 vec mom2 = zeros ( dim );
00114 for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
00115
00116 return mom2-pow ( mean(),2 );
00117 }
00118 double evallog ( const vec &val ) const {
00119 int i;
00120 double sum = 0.0;
00121 for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
00122 if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
00123 double tmp=log ( sum );
00124 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00125 return tmp;
00126 };
00127 vec evallog_m ( const mat &Val ) const {
00128 vec x=zeros ( Val.cols() );
00129 for ( int i = 0; i < w.length(); i++ ) {
00130 x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
00131 }
00132 return log ( x );
00133 };
00135 mat evallog_M ( const mat &Val ) const {
00136 mat X ( w.length(), Val.cols() );
00137 for ( int i = 0; i < w.length(); i++ ) {
00138 X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
00139 }
00140 return X;
00141 };
00142
00143 emix* marginal ( const RV &rv ) const;
00144 mratio* condition ( const RV &rv ) const;
00145
00146
00148 vec& _w() {return w;}
00149 virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00151 void ownComs() {destroyComs=true;}
00152
00154 epdf* _Coms ( int i ) {return Coms ( i );}
00155 void set_rv(const RV &rv){
00156 epdf::set_rv(rv);
00157 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00158 }
00159 };
00160
00161
00166 class egiwmix : public egiw {
00167 protected:
00169 vec w;
00171 Array<egiw*> Coms;
00173 bool destroyComs;
00174 public:
00176 egiwmix ( ) : egiw ( ) {};
00177
00180 void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
00181
00183 vec mean() const;
00184
00186 vec sample() const;
00187
00189 vec variance() const;
00190
00191
00192 void mean_mat ( mat &M, mat&R ) const {};
00193 double evallog_nn ( const vec &val ) const {return 0;};
00194 double lognc () const {return 0;};
00195 emix* marginal ( const RV &rv ) const;
00196
00197
00199 vec& _w() {return w;}
00200 virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00202 void ownComs() {destroyComs=true;}
00203
00205 egiw* _Coms ( int i ) {return Coms ( i );}
00206
00207 void set_rv(const RV &rv){
00208 egiw::set_rv(rv);
00209 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00210 }
00211
00213 egiw* approx();
00214 };
00215
00224 class mprod: public compositepdf, public mpdf {
00225 protected:
00227 Array<epdf*> epdfs;
00229 Array<datalink_m2m*> dls;
00231 epdf dummy;
00232 public:
00235 mprod (){};
00236 mprod (Array<mpdf*> mFacs ){set_elements( mFacs );};
00237 void set_elements(Array<mpdf*> mFacs , bool own=false) {
00238
00239 compositepdf::set_elements(mFacs,own);
00240 dls.set_size(mFacs.length());
00241 epdfs.set_size(mFacs.length());
00242
00243 ep=&dummy;
00244 RV rv=getrv ( true );
00245 set_rv ( rv );
00246 dummy.set_parameters ( rv._dsize() );
00247 setrvc ( ep->_rv(),rvc );
00248
00249 for ( int i = 0;i < mpdfs.length(); i++ ) {
00250 dls ( i ) = new datalink_m2m;
00251 dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
00252 }
00253
00254 for ( int i=0; i<mpdfs.length(); i++ ) {
00255 epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
00256 }
00257 };
00258
00259 double evallogcond ( const vec &val, const vec &cond ) {
00260 int i;
00261 double res = 0.0;
00262 for ( i = mpdfs.length() - 1;i >= 0;i-- ) {
00263
00264
00265
00266
00267
00268 res += mpdfs ( i )->evallogcond (
00269 dls ( i )->pushdown ( val ),
00270 dls ( i )->get_cond ( val, cond )
00271 );
00272 }
00273 return res;
00274 }
00275 vec evallogcond_m(const mat &Dt, const vec &cond) {
00276 vec tmp(Dt.cols());
00277 for(int i=0;i<Dt.cols(); i++){
00278 tmp(i) = evallogcond(Dt.get_col(i),cond);
00279 }
00280 return tmp;
00281 };
00282 vec evallogcond_m(const Array<vec> &Dt, const vec &cond) {
00283 vec tmp(Dt.length());
00284 for(int i=0;i<Dt.length(); i++){
00285 tmp(i) = evallogcond(Dt(i),cond);
00286 }
00287 return tmp;
00288 };
00289
00290
00291
00292 vec samplecond ( const vec &cond ) {
00294 vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
00295 vec smpi;
00296
00297 for ( int i = ( mpdfs.length() - 1 );i >= 0;i-- ) {
00298 if ( mpdfs ( i )->dimensionc() ) {
00299 mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) );
00300 }
00301 smpi = epdfs ( i )->sample();
00302
00303 dls ( i )->pushup ( smp, smpi );
00304
00305 }
00306 return smp;
00307 }
00308 mat samplecond ( const vec &cond, int N ) {
00309 mat Smp ( dimension(),N );
00310 for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
00311 return Smp;
00312 }
00313
00314 ~mprod() {};
00322 void from_setting(const Setting &set){
00323 Array<mpdf*> Atmp;
00324 UI::get(Atmp,set, "mpdfs");
00325 set_elements(Atmp,true);
00326 }
00327
00328 };
00329 UIREGISTER(mprod);
00330
00332 class eprod: public epdf {
00333 protected:
00335 Array<const epdf*> epdfs;
00337 Array<datalink*> dls;
00338 public:
00339 eprod () : epdfs ( 0 ),dls ( 0 ) {};
00340 void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
00341 epdfs=epdfs0;
00342 dls.set_length ( epdfs.length() );
00343
00344 bool independent=true;
00345 if ( named ) {
00346 for ( int i=0;i<epdfs.length();i++ ) {
00347 independent=rv.add ( epdfs ( i )->_rv() );
00348 it_assert_debug ( independent==true, "eprod:: given components are not independent." );
00349 }
00350 dim=rv._dsize();
00351 }
00352 else {
00353 dim =0; for ( int i=0;i<epdfs.length();i++ ) {
00354 dim+=epdfs ( i )->dimension();
00355 }
00356 }
00357
00358 int cumdim=0;
00359 int dimi=0;
00360 int i;
00361 for ( i=0;i<epdfs.length();i++ ) {
00362 dls ( i ) = new datalink;
00363 if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
00364 else {
00365 dimi = epdfs ( i )->dimension();
00366 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
00367 cumdim+=dimi;
00368 }
00369 }
00370 }
00371
00372 vec mean() const {
00373 vec tmp ( dim );
00374 for ( int i=0;i<epdfs.length();i++ ) {
00375 vec pom = epdfs ( i )->mean();
00376 dls ( i )->pushup ( tmp, pom );
00377 }
00378 return tmp;
00379 }
00380 vec variance() const {
00381 vec tmp ( dim );
00382 for ( int i=0;i<epdfs.length();i++ ) {
00383 vec pom = epdfs ( i )->mean();
00384 dls ( i )->pushup ( tmp, pow ( pom,2 ) );
00385 }
00386 return tmp-pow ( mean(),2 );
00387 }
00388 vec sample() const {
00389 vec tmp ( dim );
00390 for ( int i=0;i<epdfs.length();i++ ) {
00391 vec pom = epdfs ( i )->sample();
00392 dls ( i )->pushup ( tmp, pom );
00393 }
00394 return tmp;
00395 }
00396 double evallog ( const vec &val ) const {
00397 double tmp=0;
00398 for ( int i=0;i<epdfs.length();i++ ) {
00399 tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
00400 }
00401 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00402 return tmp;
00403 }
00405 const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
00406
00408 ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
00409 };
00410
00411
00415 class mmix : public mpdf {
00416 protected:
00418 Array<mpdf*> Coms;
00420 emix Epdf;
00421 public:
00423 mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
00425 void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
00426 Array<epdf*> Eps ( Coms.length() );
00427
00428 for ( int i = 0;i < Coms.length();i++ ) {
00429 Eps ( i ) = & ( Coms ( i )->_epdf() );
00430 }
00431 Epdf.set_parameters ( w, Eps );
00432 };
00433
00434 void condition ( const vec &cond ) {
00435 for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
00436 };
00437 };
00438
00439 }
#endif // EMIX_H