00001 
00013 #ifndef MX_H
00014 #define MX_H
00015 
00016 #define LOG2  0.69314718055995  
00017 
00018 #include "libBM.h"
00019 #include "libEF.h"
00020 
00021 
00022 namespace bdm {
00023 
00024 
00025 
00038 class mratio: public mpdf {
00039 protected:
00041         const epdf* nom;
00043         epdf* den;
00045         bool destroynom;
00047         datalink_m2e dl;
00048 public:
00051         mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( ), dl ( ) {
00052                 
00053                 rvc = nom0->_rv().subt ( rv );
00054                 dimc = rvc._dsize();
00055                 ep = new epdf;
00056                 ep->set_parameters(rv._dsize());
00057                 ep->set_rv(rv);
00058                 
00059                 
00060                 if ( copy ) {it_error ( "todo" ); destroynom=true; }
00061                 else { nom = nom0; destroynom = false; }
00062                 it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );               
00063                 
00064                 
00065                 den = nom->marginal ( rvc );
00066                 dl.set_connection(rv,rvc,nom0->_rv());
00067         };
00068         double evallogcond ( const vec &val, const vec &cond ) {
00069                 double tmp;
00070                 vec nom_val ( ep->dimension() + dimc );
00071                 dl.pushup_cond ( nom_val,val,cond );
00072                 tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
00073                 it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );
00074                 return tmp;
00075         }
00077         void ownnom() {destroynom=true;}
00079         ~mratio() {delete den; if ( destroynom ) {delete nom;}}
00080 };
00081 
00092 class emix : public epdf {
00093 protected:
00095         vec w;
00097         Array<epdf*> Coms;
00099         bool destroyComs;
00100 public:
00102         emix ( ) : epdf ( ) {};
00105         void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=false );
00106 
00107         vec sample() const;
00108         vec mean() const {
00109                 int i; vec mu = zeros ( dim );
00110                 for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
00111                 return mu;
00112         }
00113         vec variance() const {
00114                 
00115                 vec mom2 = zeros ( dim );
00116                 for ( int i = 0;i < w.length();i++ ) {mom2 += w ( i ) * (Coms(i)->variance() + pow ( Coms ( i )->mean(),2 )); }
00117                 
00118                 return mom2-pow ( mean(),2 );
00119         }
00120         double evallog ( const vec &val ) const {
00121                 int i;
00122                 double sum = 0.0;
00123                 for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );}
00124                 if ( sum==0.0 ) {sum=std::numeric_limits<double>::epsilon();}
00125                 double tmp=log ( sum );
00126                 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00127                 return tmp;
00128         };
00129         vec evallog_m ( const mat &Val ) const {
00130                 vec x=zeros ( Val.cols() );
00131                 for ( int i = 0; i < w.length(); i++ ) {
00132                         x+= w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) );
00133                 }
00134                 return log ( x );
00135         };
00137         mat evallog_M ( const mat &Val ) const {
00138                 mat X ( w.length(), Val.cols() );
00139                 for ( int i = 0; i < w.length(); i++ ) {
00140                         X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
00141                 }
00142                 return X;
00143         };
00144 
00145         emix* marginal ( const RV &rv ) const;
00146         mratio* condition ( const RV &rv ) const; 
00147 
00148 
00150         vec& _w() {return w;}
00151         virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00153         void ownComs() {destroyComs=true;}
00154 
00156         epdf* _Coms ( int i ) {return Coms ( i );}
00157         void set_rv(const RV &rv){
00158                 epdf::set_rv(rv);
00159                 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00160         }
00161 };
00162 
00163 
00168 class egiwmix : public egiw {
00169 protected:
00171         vec w;
00173         Array<egiw*> Coms;
00175         bool destroyComs;
00176 public:
00178         egiwmix ( ) : egiw ( ) {};
00179 
00182         void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy=false );
00183 
00185         vec mean() const;
00186 
00188         vec sample() const;
00189 
00191         vec variance() const;
00192 
00193         
00194         void mean_mat ( mat &M, mat&R ) const {};
00195         double evallog_nn ( const vec &val ) const {return 0;};
00196         double lognc () const {return 0;};
00197         emix* marginal ( const RV &rv ) const;
00198 
00199 
00201         vec& _w() {return w;}
00202         virtual ~egiwmix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00204         void ownComs() {destroyComs=true;}
00205 
00207         egiw* _Coms ( int i ) {return Coms ( i );}
00208 
00209         void set_rv(const RV &rv){
00210                 egiw::set_rv(rv);
00211                 for(int i=0;i<Coms.length();i++){Coms(i)->set_rv(rv);}
00212         }
00213 
00215         egiw* approx();
00216 };
00217 
00226 class mprod: public compositepdf, public mpdf {
00227 protected:
00229         Array<epdf*> epdfs;
00231         Array<datalink_m2m*> dls;
00233         epdf dummy;
00234 public:
00237         mprod ( Array<mpdf*> mFacs ) : compositepdf ( mFacs ), mpdf (), epdfs ( n ), dls ( n ) {
00238                 ep=&dummy;
00239                 RV rv=getrv ( true );
00240                 set_rv ( rv );dummy.set_parameters ( rv._dsize() );
00241                 setrvc ( ep->_rv(),rvc );
00242                 
00243                 for ( int i = 0;i < n;i++ ) {
00244                         dls ( i ) = new datalink_m2m;
00245                         dls(i)->set_connection( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
00246                 }
00247 
00248                 for ( int i=0;i<n;i++ ) {
00249                         epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
00250                 }
00251         };
00252 
00253         double evallogcond ( const vec &val, const vec &cond ) {
00254                 int i;
00255                 double res = 0.0;
00256                 for ( i = n - 1;i >= 0;i-- ) {
00257                         
00258 
00259 
00260 
00261 
00262                         res += mpdfs ( i )->evallogcond (
00263                                    dls ( i )->pushdown ( val ),
00264                                    dls ( i )->get_cond ( val, cond )
00265                                );
00266                 }
00267                 return res;
00268         }
00269         
00270         vec samplecond ( const vec &cond ) {
00272                 vec smp= std::numeric_limits<double>::infinity() * ones ( ep->dimension() );
00273                 vec smpi;
00274                 
00275                 for ( int i = ( n - 1 );i >= 0;i-- ) {
00276                         if ( mpdfs ( i )->dimensionc() ) {
00277                                 mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) ); 
00278                         }
00279                         smpi = epdfs ( i )->sample();
00280                         
00281                         dls ( i )->pushup ( smp, smpi );
00282                         
00283                 }
00284                 return smp;
00285         }
00286         mat samplecond ( const vec &cond,  int N ) {
00287                 mat Smp ( dimension(),N );
00288                 for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond ) );}
00289                 return Smp;
00290         }
00291 
00292         ~mprod() {};
00293 };
00294 
00296 class eprod: public epdf {
00297 protected:
00299         Array<const epdf*> epdfs;
00301         Array<datalink*> dls;
00302 public:
00303         eprod () : epdfs ( 0 ),dls ( 0 ) {};
00304         void set_parameters ( const Array<const epdf*> &epdfs0, bool named=true ) {
00305                 epdfs=epdfs0;
00306                 dls.set_length ( epdfs.length() );
00307 
00308                 bool independent=true;
00309                 if ( named ) {
00310                         for ( int i=0;i<epdfs.length();i++ ) {
00311                                 independent=rv.add ( epdfs ( i )->_rv() );
00312                                 it_assert_debug ( independent==true, "eprod:: given components are not independent." );
00313                         }
00314                         dim=rv._dsize();
00315                 }
00316                 else {
00317                         dim =0; for ( int i=0;i<epdfs.length();i++ ) {
00318                                 dim+=epdfs ( i )->dimension();
00319                         }
00320                 }
00321                 
00322                 int cumdim=0;
00323                 int dimi=0;
00324                 int i;
00325                 for ( i=0;i<epdfs.length();i++ ) {
00326                         dls ( i ) = new datalink;
00327                         if ( named ) {dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );}
00328                         else {
00329                                 dimi = epdfs ( i )->dimension();
00330                                 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim,cumdim+dimi-1 ) );
00331                                 cumdim+=dimi;
00332                         }
00333                 }
00334         }
00335 
00336         vec mean() const {
00337                 vec tmp ( dim );
00338                 for ( int i=0;i<epdfs.length();i++ ) {
00339                         vec pom = epdfs ( i )->mean();
00340                         dls ( i )->pushup ( tmp, pom );
00341                 }
00342                 return tmp;
00343         }
00344         vec variance() const {
00345                 vec tmp ( dim ); 
00346                 for ( int i=0;i<epdfs.length();i++ ) {
00347                         vec pom = epdfs ( i )->mean();
00348                         dls ( i )->pushup ( tmp, pow ( pom,2 ) );
00349                 }
00350                 return tmp-pow ( mean(),2 );
00351         }
00352         vec sample() const {
00353                 vec tmp ( dim );
00354                 for ( int i=0;i<epdfs.length();i++ ) {
00355                         vec pom = epdfs ( i )->sample();
00356                         dls ( i )->pushup ( tmp, pom );
00357                 }
00358                 return tmp;
00359         }
00360         double evallog ( const vec &val ) const {
00361                 double tmp=0;
00362                 for ( int i=0;i<epdfs.length();i++ ) {
00363                         tmp+=epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
00364                 }
00365                 it_assert_debug ( std::isfinite ( tmp ),"Infinite" );
00366                 return tmp;
00367         }
00369         const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
00370 
00372         ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
00373 };
00374 
00375 
00379 class mmix : public mpdf {
00380 protected:
00382         Array<mpdf*> Coms;
00384         emix Epdf;
00385 public:
00387         mmix ( ) : mpdf ( ), Epdf () {ep = &Epdf;};
00389         void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
00390                 Array<epdf*> Eps ( Coms.length() );
00391 
00392                 for ( int i = 0;i < Coms.length();i++ ) {
00393                         Eps ( i ) = & ( Coms ( i )->_epdf() );
00394                 }
00395                 Epdf.set_parameters ( w, Eps );
00396         };
00397 
00398         void condition ( const vec &cond ) {
00399                 for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
00400         };
00401 };
00402 
00403 }
00404 #endif //MX_H