00001 
00013 #ifndef EMIX_H
00014 #define EMIX_H
00015 
00016 #define LOG2  0.69314718055995  
00017 
00018 #include "exp_family.h"
00019 
00020 namespace bdm {
00021 
00022 
00023 
00036 class mratio: public mpdf {
00037 protected:
00039         const epdf* nom;
00041         epdf* den;
00043         bool destroynom;
00045         datalink_m2e dl;
00046 public:
00049         mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
00050                 
00051                 rvc = nom0->_rv().subt ( rv );
00052                 dimc = rvc._dsize();
00053                 ep = new epdf;
00054                 ep->set_parameters(rv._dsize());
00055                 ep->set_rv(rv);
00056                 
00057                 
00058                 if ( copy ) {it_error ( "todo" ); destroynom=true; }
00059                 else { nom = nom0; destroynom = false; }
00060                 it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
00061                 
00062                 
00063                 den = nom->marginal ( rvc );
00064                 dl.set_connection(rv,rvc,nom0->_rv());
00065         };
00066         double evallogcond ( const vec &val, const vec &cond ) {
00067                 double tmp;
00068                 vec nom_val ( ep->dimension() + dimc );
00069                 dl.pushup_cond ( nom_val,val,cond );
00070                 tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
00071                 it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
00072                 return tmp;
00073         }
00075         void ownnom() {destroynom=true;}
00077         ~mratio() {delete den; if ( destroynom ) {delete nom;}}
00078 };
00079 
00090 class emix : public epdf {
00091 protected:
00093         vec w;
00095         Array<epdf*> Coms;
00097         bool destroyComs;
00098 public:
00100         emix ( ) : epdf ( ) {};
00103         void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
00104 
00105         vec sample() const;
00106         vec mean() const {
00107                 int i; vec mu = zeros ( dim );
00108                 for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
00109                 return mu;
00110         }
00111         vec variance() const {
00112                 
00113                 vec mom2 = zeros ( dim );
00114                 for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
00115                 
00116                 return mom2-pow ( mean(),2 );
00117         }
00118         double evallog ( const vec &val ) const {
00119                 int i;
00120                 double sum = 0.0;
00121                 for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
00122                 if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
00123                 double tmp=log ( sum );
00124                 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00125                 return tmp;
00126         };
00127         vec evallog_m ( const mat &Val ) const {
00128                 vec x=zeros ( Val.cols() );
00129                 for ( int i = 0; i < w.length(); i++ ) {
00130                         x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
00131                 }
00132                 return log ( x );
00133         };
00135         mat evallog_M ( const mat &Val ) const {
00136                 mat X ( w.length(), Val.cols() );
00137                 for ( int i = 0; i < w.length(); i++ ) {
00138                         X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
00139                 }
00140                 return X;
00141         };
00142 
00143         emix* marginal ( const RV &rv ) const;
00144         mratio* condition ( const RV &rv ) const; 
00145 
00146 
00148         vec& _w() {return w;}
00149         virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00151         void ownComs() {destroyComs=true;}
00152 
00154         epdf* _Coms ( int i ) {return Coms ( i );}
00155         void set_rv(const RV &rv){
00156                 epdf::set_rv(rv);
00157                 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00158         }
00159 };
00160 
00161 
00166 class egiwmix : public egiw {
00167 protected:
00169         vec w;
00171         Array<egiw*> Coms;
00173         bool destroyComs;
00174 public:
00176         egiwmix ( ) : egiw ( ) {};
00177 
00180         void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
00181 
00183         vec mean() const;
00184 
00186         vec sample() const;
00187 
00189         vec variance() const;
00190 
00191         
00192         void mean_mat ( mat &M, mat&R ) const {};
00193         double evallog_nn ( const vec &val ) const {return 0;};
00194         double lognc () const {return 0;};
00195         emix* marginal ( const RV &rv ) const;
00196 
00197 
00199         vec& _w() {return w;}
00200         virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00202         void ownComs() {destroyComs=true;}
00203 
00205         egiw* _Coms ( int i ) {return Coms ( i );}
00206 
00207         void set_rv(const RV &rv){
00208                 egiw::set_rv(rv);
00209                 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00210         }
00211 
00213         egiw* approx();
00214 };
00215 
00224 class mprod: public compositepdf, public mpdf {
00225 protected:
00227         Array<epdf*> epdfs;
00229         Array<datalink_m2m*> dls;
00231         epdf dummy;
00232 public:
00235         mprod (){};
00236         mprod (Array<mpdf*> mFacs ){set_elements( mFacs );};
00237         void set_elements(Array<mpdf*> mFacs , bool own=false) {
00238                 
00239                 compositepdf::set_elements(mFacs,own);
00240                 dls.set_size(mFacs.length());
00241                 epdfs.set_size(mFacs.length());
00242                                 
00243                 ep=&dummy;
00244                 RV rv=getrv ( true );
00245                 set_rv ( rv );
00246                 dummy.set_parameters ( rv._dsize() );
00247                 setrvc ( ep->_rv(),rvc );
00248                 
00249                 for ( int i = 0;i < mpdfs.length(); i++ ) {
00250                         dls ( i ) = new datalink_m2m;
00251                         dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
00252                 }
00253 
00254                 for ( int i=0; i<mpdfs.length(); i++ ) {
00255                         epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
00256                 }
00257         };
00258 
00259         double evallogcond ( const vec &val, const vec &cond ) {
00260                 int i;
00261                 double res = 0.0;
00262                 for ( i = mpdfs.length() - 1;i >= 0;i-- ) {
00263                         
00264 
00265 
00266 
00267 
00268                         res += mpdfs ( i )->evallogcond (
00269                                    dls ( i )->pushdown ( val ),
00270                                    dls ( i )->get_cond ( val, cond )
00271                                );
00272                 }
00273                 return res;
00274         }
00275         vec evallogcond_m(const mat &Dt, const vec &cond) {
00276                 vec tmp(Dt.cols());
00277                 for(int i=0;i<Dt.cols(); i++){
00278                         tmp(i) = evallogcond(Dt.get_col(i),cond);
00279                 }
00280                 return tmp;
00281         };
00282         vec evallogcond_m(const Array<vec> &Dt, const vec &cond) {
00283                 vec tmp(Dt.length());
00284                 for(int i=0;i<Dt.length(); i++){
00285                         tmp(i) = evallogcond(Dt(i),cond);
00286                 }
00287                 return tmp;             
00288         };
00289 
00290 
00291         
00292         vec samplecond ( const vec &cond ) {
00294                 vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
00295                 vec smpi;
00296                 
00297                 for ( int i = ( mpdfs.length() - 1 );i >= 0;i-- ) {
00298                         if ( mpdfs ( i )->dimensionc() ) {
00299                                 mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); 
00300                         }
00301                         smpi = epdfs ( i )->sample();
00302                         
00303                         dls ( i )->pushup ( smp, smpi );
00304                         
00305                 }
00306                 return smp;
00307         }
00308         mat samplecond ( const vec &cond,  int N ) {
00309                 mat Smp ( dimension(),N );
00310                 for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
00311                 return Smp;
00312         }
00313 
00314         ~mprod() {};
00322         void from_setting(const Setting &set){
00323                 Array<mpdf*> Atmp; 
00324                 UI::get(Atmp,set, "mpdfs");
00325                 set_elements(Atmp,true);
00326         }
00327         
00328 };
00329 UIREGISTER(mprod);
00330 
00332 class eprod: public epdf {
00333 protected:
00335         Array<const epdf*> epdfs;
00337         Array<datalink*> dls;
00338 public:
00339         eprod () : epdfs ( 0 ),dls ( 0 ) {};
00340         void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
00341                 epdfs=epdfs0;
00342                 dls.set_length ( epdfs.length() );
00343 
00344                 bool independent=true;
00345                 if ( named ) {
00346                         for ( int i=0;i<epdfs.length();i++ ) {
00347                                 independent=rv.add ( epdfs ( i )->_rv() );
00348                                 it_assert_debug ( independent==true, "eprod:: given components are not independent." );
00349                         }
00350                         dim=rv._dsize();
00351                 }
00352                 else {
00353                         dim =0; for ( int i=0;i<epdfs.length();i++ ) {
00354                                 dim+=epdfs ( i )->dimension();
00355                         }
00356                 }
00357                 
00358                 int cumdim=0;
00359                 int dimi=0;
00360                 int i;
00361                 for ( i=0;i<epdfs.length();i++ ) {
00362                         dls ( i ) = new datalink;
00363                         if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
00364                         else {
00365                                 dimi = epdfs ( i )->dimension();
00366                                 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
00367                                 cumdim+=dimi;
00368                         }
00369                 }
00370         }
00371 
00372         vec mean() const {
00373                 vec tmp ( dim );
00374                 for ( int i=0;i<epdfs.length();i++ ) {
00375                         vec pom = epdfs ( i )->mean();
00376                         dls ( i )->pushup ( tmp, pom );
00377                 }
00378                 return tmp;
00379         }
00380         vec variance() const {
00381                 vec tmp ( dim ); 
00382                 for ( int i=0;i<epdfs.length();i++ ) {
00383                         vec pom = epdfs ( i )->mean();
00384                         dls ( i )->pushup ( tmp, pow ( pom,2 ) );
00385                 }
00386                 return tmp-pow ( mean(),2 );
00387         }
00388         vec sample() const {
00389                 vec tmp ( dim );
00390                 for ( int i=0;i<epdfs.length();i++ ) {
00391                         vec pom = epdfs ( i )->sample();
00392                         dls ( i )->pushup ( tmp, pom );
00393                 }
00394                 return tmp;
00395         }
00396         double evallog ( const vec &val ) const {
00397                 double tmp=0;
00398                 for ( int i=0;i<epdfs.length();i++ ) {
00399                         tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
00400                 }
00401                 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00402                 return tmp;
00403         }
00405         const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
00406 
00408         ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
00409 };
00410 
00411 
00415 class mmix : public mpdf {
00416 protected:
00418         Array<mpdf*> Coms;
00420         emix Epdf;
00421 public:
00423         mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
00425         void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
00426                 Array<epdf*> Eps ( Coms.length() );
00427 
00428                 for ( int i = 0;i < Coms.length();i++ ) {
00429                         Eps ( i ) = & ( Coms ( i )->_epdf() );
00430                 }
00431                 Epdf.set_parameters ( w, Eps );
00432         };
00433 
00434         void condition ( const vec &cond ) {
00435                 for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
00436         };
00437 };
00438 
00439 }
#endif // EMIX_H