00001 
00013 #ifndef EMIX_H
00014 #define EMIX_H
00015 
//! Natural logarithm of 2, \f$ \ln 2 \f$ (given to 14 significant digits).
#define LOG2  0.69314718055995
00017 
#include <cmath>
#include <limits>

#include "../shared_ptr.h"
#include "exp_family.h"
00020 
00021 namespace bdm {
00022 
00023 
00024 
00037 class mratio: public mpdf {
00038 protected:
00040         const epdf* nom;
00041 
00043         shared_ptr<epdf> den;
00044 
00046         bool destroynom;
00048         datalink_m2e dl;
00050         epdf iepdf;
00051 public:
00054         mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
00055                 
00056                 rvc = nom0->_rv().subt ( rv );
00057                 dimc = rvc._dsize();
00058                 set_ep ( iepdf );
00059                 iepdf.set_parameters ( rv._dsize() );
00060                 iepdf.set_rv ( rv );
00061 
00062                 
00063                 if ( copy ) {
00064                         bdm_error ( "todo" );
00065                         
00066                 } else {
00067                         nom = nom0;
00068                         destroynom = false;
00069                 }
00070                 bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );
00071 
00072                 
00073                 den = nom->marginal ( rvc );
00074                 dl.set_connection ( rv, rvc, nom0->_rv() );
00075         };
00076         double evallogcond ( const vec &val, const vec &cond ) {
00077                 double tmp;
00078                 vec nom_val ( dimension() + dimc );
00079                 dl.pushup_cond ( nom_val, val, cond );
00080                 tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
00081                 return tmp;
00082         }
00084         void ownnom() {
00085                 destroynom = true;
00086         }
00088         ~mratio() {
00089                 if ( destroynom ) {
00090                         delete nom;
00091                 }
00092         }
00093 
00094 private:
00095         
00096         mratio ( const mratio & );
00097         mratio &operator=( const mratio & );
00098 };
00099 
00110 class emix : public epdf {
00111 protected:
00113         vec w;
00114 
00116         Array<shared_ptr<epdf> > Coms;
00117 
00118 public:
00120         emix ( ) : epdf ( ) { }
00121 
00128         void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );
00129 
00130         vec sample() const;
00131         vec mean() const {
00132                 int i;
00133                 vec mu = zeros ( dim );
00134                 for ( i = 0; i < w.length(); i++ ) {
00135                         mu += w ( i ) * Coms ( i )->mean();
00136                 }
00137                 return mu;
00138         }
00139         vec variance() const {
00140                 
00141                 vec mom2 = zeros ( dim );
00142                 for ( int i = 0; i < w.length(); i++ ) {
00143                         mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
00144                 }
00145                 
00146                 return mom2 - pow ( mean(), 2 );
00147         }
00148         double evallog ( const vec &val ) const {
00149                 int i;
00150                 double sum = 0.0;
00151                 for ( i = 0; i < w.length(); i++ ) {
00152                         sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
00153                 }
00154                 if ( sum == 0.0 ) {
00155                         sum = std::numeric_limits<double>::epsilon();
00156                 }
00157                 double tmp = log ( sum );
00158                 bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
00159                 return tmp;
00160         };
00161         vec evallog_m ( const mat &Val ) const {
00162                 vec x = zeros ( Val.cols() );
00163                 for ( int i = 0; i < w.length(); i++ ) {
00164                         x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
00165                 }
00166                 return log ( x );
00167         };
00169         mat evallog_M ( const mat &Val ) const {
00170                 mat X ( w.length(), Val.cols() );
00171                 for ( int i = 0; i < w.length(); i++ ) {
00172                         X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
00173                 }
00174                 return X;
00175         };
00176 
00177         shared_ptr<epdf> marginal ( const RV &rv ) const;
00179         void marginal ( const RV &rv, emix &target ) const;
00180         shared_ptr<mpdf> condition ( const RV &rv ) const;
00181 
00182 
00184         vec& _w() {
00185                 return w;
00186         }
00187 
00189         shared_ptr<epdf> _Coms ( int i ) {
00190                 return Coms ( i );
00191         }
00192 
00193         void set_rv ( const RV &rv ) {
00194                 epdf::set_rv ( rv );
00195                 for ( int i = 0; i < Coms.length(); i++ ) {
00196                         Coms ( i )->set_rv ( rv );
00197                 }
00198         }
00199 };
00200 SHAREDPTR( emix );
00201 
00206 class egiwmix : public egiw {
00207 protected:
00209         vec w;
00211         Array<egiw*> Coms;
00213         bool destroyComs;
00214 public:
00216         egiwmix ( ) : egiw ( ) {};
00217 
00220         void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );
00221 
00223         vec mean() const;
00224 
00226         vec sample() const;
00227 
00229         vec variance() const;
00230 
00231         
00232         void mean_mat ( mat &M, mat&R ) const {};
00233         double evallog_nn ( const vec &val ) const {
00234                 return 0;
00235         };
00236         double lognc () const {
00237                 return 0;
00238         }
00239 
00240         shared_ptr<epdf> marginal ( const RV &rv ) const;
00242         void marginal ( const RV &rv, emix &target ) const;
00243 
00244 
00246         vec& _w() {
00247                 return w;
00248         }
00249         virtual ~egiwmix() {
00250                 if ( destroyComs ) {
00251                         for ( int i = 0; i < Coms.length(); i++ ) {
00252                                 delete Coms ( i );
00253                         }
00254                 }
00255         }
00257         void ownComs() {
00258                 destroyComs = true;
00259         }
00260 
00262         egiw* _Coms ( int i ) {
00263                 return Coms ( i );
00264         }
00265 
00266         void set_rv ( const RV &rv ) {
00267                 egiw::set_rv ( rv );
00268                 for ( int i = 0; i < Coms.length(); i++ ) {
00269                         Coms ( i )->set_rv ( rv );
00270                 }
00271         }
00272 
00274         egiw* approx();
00275 };
00276 
00285 class mprod: public mpdf {
00286 private:
00287         Array<shared_ptr<mpdf> > mpdfs;
00288 
00290         Array<shared_ptr<datalink_m2m> > dls;
00291 
00292 protected:
00294         epdf iepdf;
00295 
00296 public:
00298         mprod() { }
00299 
00302         mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
00303                 set_elements ( mFacs );
00304         }
00306         void set_elements (const Array<shared_ptr<mpdf> > &mFacs );
00307 
00308         double evallogcond ( const vec &val, const vec &cond ) {
00309                 int i;
00310                 double res = 0.0;
00311                 for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
00312                         
00313 
00314 
00315 
00316 
00317                         res += mpdfs ( i )->evallogcond (
00318                                    dls ( i )->pushdown ( val ),
00319                                    dls ( i )->get_cond ( val, cond )
00320                                );
00321                 }
00322                 return res;
00323         }
00324         vec evallogcond_m ( const mat &Dt, const vec &cond ) {
00325                 vec tmp ( Dt.cols() );
00326                 for ( int i = 0; i < Dt.cols(); i++ ) {
00327                         tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
00328                 }
00329                 return tmp;
00330         };
00331         vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
00332                 vec tmp ( Dt.length() );
00333                 for ( int i = 0; i < Dt.length(); i++ ) {
00334                         tmp ( i ) = evallogcond ( Dt ( i ), cond );
00335                 }
00336                 return tmp;
00337         };
00338 
00339 
00340         
00341         vec samplecond ( const vec &cond ) {
00343                 vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
00344                 vec smpi;
00345                 
00346                 for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
00347                         
00348                         smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));                        
00349                         
00350                         dls ( i )->pushup ( smp, smpi );
00351                 }
00352                 return smp;
00353         }
00354 
00362         void from_setting ( const Setting &set ) {
00363                 Array<shared_ptr<mpdf> > atmp; 
00364                 UI::get ( atmp, set, "mpdfs", UI::compulsory );
00365                 set_elements ( atmp );
00366         }
00367 };
00368 UIREGISTER ( mprod );
00369 SHAREDPTR ( mprod );
00370 
00372 class eprod: public epdf {
00373 protected:
00375         Array<const epdf*> epdfs;
00377         Array<datalink*> dls;
00378 public:
00380         eprod () : epdfs ( 0 ), dls ( 0 ) {};
00382         void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
00383                 epdfs = epdfs0;
00384                 dls.set_length ( epdfs.length() );
00385 
00386                 bool independent = true;
00387                 if ( named ) {
00388                         for ( int i = 0; i < epdfs.length(); i++ ) {
00389                                 independent = rv.add ( epdfs ( i )->_rv() );
00390                                 bdm_assert_debug ( independent, "eprod:: given components are not independent." );
00391                         }
00392                         dim = rv._dsize();
00393                 } else {
00394                         dim = 0;
00395                         for ( int i = 0; i < epdfs.length(); i++ ) {
00396                                 dim += epdfs ( i )->dimension();
00397                         }
00398                 }
00399                 
00400                 int cumdim = 0;
00401                 int dimi = 0;
00402                 int i;
00403                 for ( i = 0; i < epdfs.length(); i++ ) {
00404                         dls ( i ) = new datalink;
00405                         if ( named ) {
00406                                 dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
00407                         } else {
00408                                 dimi = epdfs ( i )->dimension();
00409                                 dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
00410                                 cumdim += dimi;
00411                         }
00412                 }
00413         }
00414 
00415         vec mean() const {
00416                 vec tmp ( dim );
00417                 for ( int i = 0; i < epdfs.length(); i++ ) {
00418                         vec pom = epdfs ( i )->mean();
00419                         dls ( i )->pushup ( tmp, pom );
00420                 }
00421                 return tmp;
00422         }
00423         vec variance() const {
00424                 vec tmp ( dim ); 
00425                 for ( int i = 0; i < epdfs.length(); i++ ) {
00426                         vec pom = epdfs ( i )->mean();
00427                         dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
00428                 }
00429                 return tmp - pow ( mean(), 2 );
00430         }
00431         vec sample() const {
00432                 vec tmp ( dim );
00433                 for ( int i = 0; i < epdfs.length(); i++ ) {
00434                         vec pom = epdfs ( i )->sample();
00435                         dls ( i )->pushup ( tmp, pom );
00436                 }
00437                 return tmp;
00438         }
00439         double evallog ( const vec &val ) const {
00440                 double tmp = 0;
00441                 for ( int i = 0; i < epdfs.length(); i++ ) {
00442                         tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
00443                 }
00444                 bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
00445                 return tmp;
00446         }
00448         const epdf* operator () ( int i ) const {
00449                 bdm_assert_debug ( i < epdfs.length(), "wrong index" );
00450                 return epdfs ( i );
00451         }
00452 
00454         ~eprod() {
00455                 for ( int i = 0; i < epdfs.length(); i++ ) {
00456                         delete dls ( i );
00457                 }
00458         }
00459 };
00460 
00461 
00465 class mmix : public mpdf {
00466 protected:
00468         Array<shared_ptr<mpdf> > Coms;
00470         vec w;
00472         epdf dummy_epdf;
00473 public:
00475         mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);    }
00476 
00478         void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
00480                 Coms = Coms0;
00481                 w=w0;   
00482 
00483                 if (Coms0.length()>0){
00484                         set_rv(Coms(0)->_rv());
00485                         dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
00486                         set_rvc(Coms(0)->_rvc());
00487                         dimc = rvc._dsize();
00488                 }
00489         }
00490         double evallogcond (const vec &dt, const vec &cond) {
00491                 double ll=0.0;
00492                 for (int i=0;i<Coms.length();i++){
00493                         ll+=Coms(i)->evallogcond(dt,cond);
00494                 }
00495                 return ll;
00496         }
00497 
00498         vec samplecond (const vec &cond);
00499 
00500 };
00501 
00502 }
#endif // EMIX_H