00001
00013 #ifndef EMIX_H
00014 #define EMIX_H
00015
00016 #define LOG2 0.69314718055995
00017
00018 #include "../shared_ptr.h"
00019 #include "exp_family.h"
00020
00021 namespace bdm {
00022
00023
00024
00037 class mratio: public mpdf {
00038 protected:
00040 const epdf* nom;
00041
00043 shared_ptr<epdf> den;
00044
00046 bool destroynom;
00048 datalink_m2e dl;
00050 epdf iepdf;
00051 public:
00054 mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
00055
00056 rvc = nom0->_rv().subt ( rv );
00057 dimc = rvc._dsize();
00058 set_ep ( iepdf );
00059 iepdf.set_parameters ( rv._dsize() );
00060 iepdf.set_rv ( rv );
00061
00062
00063 if ( copy ) {
00064 bdm_error ( "todo" );
00065
00066 } else {
00067 nom = nom0;
00068 destroynom = false;
00069 }
00070 bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
00071
00072
00073 den = nom->marginal ( rvc );
00074 dl.set_connection ( rv, rvc, nom0->_rv() );
00075 };
00076 double evallogcond ( const vec &val, const vec &cond ) {
00077 double tmp;
00078 vec nom_val ( dimension() + dimc );
00079 dl.pushup_cond ( nom_val, val, cond );
00080 tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
00081 return tmp;
00082 }
00084 void ownnom() {
00085 destroynom = true;
00086 }
00088 ~mratio() {
00089 if ( destroynom ) {
00090 delete nom;
00091 }
00092 }
00093
00094 private:
00095
00096 mratio ( const mratio & );
00097 mratio &operator=( const mratio & );
00098 };
00099
00110 class emix : public epdf {
00111 protected:
00113 vec w;
00114
00116 Array<shared_ptr<epdf> > Coms;
00117
00118 public:
00120 emix ( ) : epdf ( ) { }
00121
00128 void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
00129
00130 vec sample() const;
00131 vec mean() const {
00132 int i;
00133 vec mu = zeros ( dim );
00134 for ( i = 0; i < w.length(); i++ ) {
00135 mu += w ( i ) * Coms ( i )->mean();
00136 }
00137 return mu;
00138 }
00139 vec variance() const {
00140
00141 vec mom2 = zeros ( dim );
00142 for ( int i = 0; i < w.length(); i++ ) {
00143 mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
00144 }
00145
00146 return mom2 - pow ( mean(), 2 );
00147 }
00148 double evallog ( const vec &val ) const {
00149 int i;
00150 double sum = 0.0;
00151 for ( i = 0; i < w.length(); i++ ) {
00152 sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
00153 }
00154 if ( sum == 0.0 ) {
00155 sum = std::numeric_limits<double>::epsilon();
00156 }
00157 double tmp = log ( sum );
00158 bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
00159 return tmp;
00160 };
00161 vec evallog_m ( const mat &Val ) const {
00162 vec x = zeros ( Val.cols() );
00163 for ( int i = 0; i < w.length(); i++ ) {
00164 x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
00165 }
00166 return log ( x );
00167 };
00169 mat evallog_M ( const mat &Val ) const {
00170 mat X ( w.length(), Val.cols() );
00171 for ( int i = 0; i < w.length(); i++ ) {
00172 X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
00173 }
00174 return X;
00175 };
00176
00177 shared_ptr<epdf> marginal ( const RV &rv ) const;
00179 void marginal ( const RV &rv, emix &target ) const;
00180 shared_ptr<mpdf> condition ( const RV &rv ) const;
00181
00182
00184 vec& _w() {
00185 return w;
00186 }
00187
00189 shared_ptr<epdf> _Coms ( int i ) {
00190 return Coms ( i );
00191 }
00192
00193 void set_rv ( const RV &rv ) {
00194 epdf::set_rv ( rv );
00195 for ( int i = 0; i < Coms.length(); i++ ) {
00196 Coms ( i )->set_rv ( rv );
00197 }
00198 }
00199 };
00200 SHAREDPTR( emix );
00201
00206 class egiwmix : public egiw {
00207 protected:
00209 vec w;
00211 Array<egiw*> Coms;
00213 bool destroyComs;
00214 public:
00216 egiwmix ( ) : egiw ( ) {};
00217
00220 void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
00221
00223 vec mean() const;
00224
00226 vec sample() const;
00227
00229 vec variance() const;
00230
00231
00232 void mean_mat ( mat &M, mat&R ) const {};
00233 double evallog_nn ( const vec &val ) const {
00234 return 0;
00235 };
00236 double lognc () const {
00237 return 0;
00238 }
00239
00240 shared_ptr<epdf> marginal ( const RV &rv ) const;
00241 void marginal ( const RV &rv, emix &target ) const;
00242
00243
00245 vec& _w() {
00246 return w;
00247 }
00248 virtual ~egiwmix() {
00249 if ( destroyComs ) {
00250 for ( int i = 0; i < Coms.length(); i++ ) {
00251 delete Coms ( i );
00252 }
00253 }
00254 }
00256 void ownComs() {
00257 destroyComs = true;
00258 }
00259
00261 egiw* _Coms ( int i ) {
00262 return Coms ( i );
00263 }
00264
00265 void set_rv ( const RV &rv ) {
00266 egiw::set_rv ( rv );
00267 for ( int i = 0; i < Coms.length(); i++ ) {
00268 Coms ( i )->set_rv ( rv );
00269 }
00270 }
00271
00273 egiw* approx();
00274 };
00275
00284 class mprod: public mpdf {
00285 private:
00286 Array<shared_ptr<mpdf> > mpdfs;
00287
00289 Array<shared_ptr<datalink_m2m> > dls;
00290
00291 protected:
00293 epdf iepdf;
00294
00295 public:
00297 mprod() { }
00298
00301 mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
00302 set_elements ( mFacs );
00303 }
00305 void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
00306
00307 double evallogcond ( const vec &val, const vec &cond ) {
00308 int i;
00309 double res = 0.0;
00310 for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
00311
00312
00313
00314
00315
00316 res += mpdfs ( i )->evallogcond (
00317 dls ( i )->pushdown ( val ),
00318 dls ( i )->get_cond ( val, cond )
00319 );
00320 }
00321 return res;
00322 }
00323 vec evallogcond_m ( const mat &Dt, const vec &cond ) {
00324 vec tmp ( Dt.cols() );
00325 for ( int i = 0; i < Dt.cols(); i++ ) {
00326 tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
00327 }
00328 return tmp;
00329 };
00330 vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
00331 vec tmp ( Dt.length() );
00332 for ( int i = 0; i < Dt.length(); i++ ) {
00333 tmp ( i ) = evallogcond ( Dt ( i ), cond );
00334 }
00335 return tmp;
00336 };
00337
00338
00339
00340 vec samplecond ( const vec &cond ) {
00342 vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
00343 vec smpi;
00344
00345 for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
00346
00347 smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));
00348
00349 dls ( i )->pushup ( smp, smpi );
00350 }
00351 return smp;
00352 }
00353
00361 void from_setting ( const Setting &set ) {
00362 Array<shared_ptr<mpdf> > atmp;
00363 UI::get ( atmp, set, "mpdfs", UI::compulsory );
00364 set_elements ( atmp );
00365 }
00366 };
00367 UIREGISTER ( mprod );
00368 SHAREDPTR ( mprod );
00369
00371 class eprod: public epdf {
00372 protected:
00374 Array<const epdf*> epdfs;
00376 Array<datalink*> dls;
00377 public:
00379 eprod () : epdfs ( 0 ), dls ( 0 ) {};
00381 void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
00382 epdfs = epdfs0;
00383 dls.set_length ( epdfs.length() );
00384
00385 bool independent = true;
00386 if ( named ) {
00387 for ( int i = 0; i < epdfs.length(); i++ ) {
00388 independent = rv.add ( epdfs ( i )->_rv() );
00389 bdm_assert_debug ( independent, "eprod:: given components are not independent." );
00390 }
00391 dim = rv._dsize();
00392 } else {
00393 dim = 0;
00394 for ( int i = 0; i < epdfs.length(); i++ ) {
00395 dim += epdfs ( i )->dimension();
00396 }
00397 }
00398
00399 int cumdim = 0;
00400 int dimi = 0;
00401 int i;
00402 for ( i = 0; i < epdfs.length(); i++ ) {
00403 dls ( i ) = new datalink;
00404 if ( named ) {
00405 dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
00406 } else {
00407 dimi = epdfs ( i )->dimension();
00408 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
00409 cumdim += dimi;
00410 }
00411 }
00412 }
00413
00414 vec mean() const {
00415 vec tmp ( dim );
00416 for ( int i = 0; i < epdfs.length(); i++ ) {
00417 vec pom = epdfs ( i )->mean();
00418 dls ( i )->pushup ( tmp, pom );
00419 }
00420 return tmp;
00421 }
00422 vec variance() const {
00423 vec tmp ( dim );
00424 for ( int i = 0; i < epdfs.length(); i++ ) {
00425 vec pom = epdfs ( i )->mean();
00426 dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
00427 }
00428 return tmp - pow ( mean(), 2 );
00429 }
00430 vec sample() const {
00431 vec tmp ( dim );
00432 for ( int i = 0; i < epdfs.length(); i++ ) {
00433 vec pom = epdfs ( i )->sample();
00434 dls ( i )->pushup ( tmp, pom );
00435 }
00436 return tmp;
00437 }
00438 double evallog ( const vec &val ) const {
00439 double tmp = 0;
00440 for ( int i = 0; i < epdfs.length(); i++ ) {
00441 tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
00442 }
00443 bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
00444 return tmp;
00445 }
00447 const epdf* operator () ( int i ) const {
00448 bdm_assert_debug ( i < epdfs.length(), "wrong index" );
00449 return epdfs ( i );
00450 }
00451
00453 ~eprod() {
00454 for ( int i = 0; i < epdfs.length(); i++ ) {
00455 delete dls ( i );
00456 }
00457 }
00458 };
00459
00460
00464 class mmix : public mpdf {
00465 protected:
00467 Array<shared_ptr<mpdf> > Coms;
00469 vec w;
00471 epdf dummy_epdf;
00472 public:
00474 mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf); }
00475
00477 void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
00479 Coms = Coms0;
00480 w=w0;
00481
00482 if (Coms0.length()>0){
00483 set_rv(Coms(0)->_rv());
00484 dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
00485 set_rvc(Coms(0)->_rvc());
00486 dimc = rvc._dsize();
00487 }
00488 }
00489 double evallogcond (const vec &dt, const vec &cond) {
00490 double ll=0.0;
00491 for (int i=0;i<Coms.length();i++){
00492 ll+=Coms(i)->evallogcond(dt,cond);
00493 }
00494 return ll;
00495 }
00496
00497 vec samplecond (const vec &cond);
00498
00499 };
00500
00501 }
#endif // EMIX_H