00001
00013 #ifndef MX_H
00014 #define MX_H
00015
00016 #include "libBM.h"
00017 #include "libEF.h"
00018
00019
00020 using namespace itpp;
00021
00022
00023
00036 class mratio: public mpdf {
00037 protected:
00039 const epdf* nom;
00041 epdf* den;
00043 bool destroynom;
00045 datalink_m2e dl;
00046 public:
00049 mratio ( const epdf* nom0, const RV &rv, bool copy=false ) :mpdf ( rv,nom0->_rv().subt ( rv ) ), dl ( rv,rvc,nom0->_rv() ) {
00050 if ( copy ) {
00051
00052 it_error ( "todo" );
00053 destroynom=true;
00054 }
00055 else {
00056 nom = nom0;
00057 destroynom = false;
00058 }
00059 it_assert_debug ( rvc.length() >0,"Makes no sense to use this object!" );
00060 den = nom->marginal ( rvc );
00061 };
00062 double evalcond ( const vec &val, const vec &cond ) {
00063 vec nom_val ( rv.count() +rvc.count() );
00064 dl.fill_val_cond ( nom_val,val,cond );
00065 return exp ( nom->evalpdflog ( nom_val ) - den->evalpdflog ( cond ) );
00066 }
00068 void ownnom() {destroynom=true;}
00070 ~mratio() {delete den; if ( destroynom ) {delete nom;}}
00071 };
00072
00083 class emix : public epdf {
00084 protected:
00086 vec w;
00088 Array<epdf*> Coms;
00090 bool destroyComs;
00091 public:
00093 emix ( const RV &rv ) : epdf ( rv ) {};
00096 void set_parameters ( const vec &w, const Array<epdf*> &Coms, bool copy=true );
00097
00098 vec sample() const;
00099 vec mean() const {
00100 int i; vec mu = zeros ( rv.count() );
00101 for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
00102 return mu;
00103 }
00104 double evalpdflog ( const vec &val ) const {
00105 int i;
00106 double sum = 0.0;
00107 for ( i = 0;i < w.length();i++ ) {sum += w ( i ) * exp ( Coms ( i )->evalpdflog ( val ) );}
00108 return log ( sum );
00109 };
00110 vec evalpdflog_m ( const mat &Val ) const {
00111 vec x=zeros ( Val.cols() );
00112 for ( int i = 0; i < w.length(); i++ ) {
00113 x+= w ( i ) *exp ( Coms ( i )->evalpdflog_m ( Val ) );
00114 }
00115 return log ( x );
00116 };
00117 mat evalpdflog_M ( const mat &Val ) const {
00118 mat X ( w.length(), Val.cols() );
00119 for ( int i = 0; i < w.length(); i++ ) {
00120 X.set_row ( i, w ( i ) *exp ( Coms ( i )->evalpdflog_m ( Val ) ) );
00121 }
00122 return X;
00123 };
00124
00125 emix* marginal ( const RV &rv ) const;
00126 mratio* condition ( const RV &rv ) const;
00127
00128
00130 vec& _w() {return w;}
00131 virtual ~emix() {if ( destroyComs ) {for ( int i=0;i<Coms.length();i++ ) {delete Coms ( i );}}}
00133 void ownComs() {destroyComs=true;}
00134
00136 epdf* _Coms ( int i ) {return Coms ( i );}
00137 };
00138
00147 class mprod: public compositepdf, public mpdf {
00148 protected:
00150 Array<epdf*> epdfs;
00152 Array<datalink_m2m*> dls;
00153 public:
00156 mprod ( Array<mpdf*> mFacs ) : compositepdf ( mFacs ), mpdf ( getrv ( true ),RV() ), epdfs ( n ), dls ( n ) {
00157 setrvc ( rv,rvc );
00158
00159 for ( int i = 0;i < n;i++ ) {
00160 dls ( i ) = new datalink_m2m ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv, rvc );
00161 }
00162
00163 for ( int i=0;i<n;i++ ) {
00164 epdfs ( i ) =& ( mpdfs ( i )->_epdf() );
00165 }
00166 };
00167
00168 double evalcond ( const vec &val, const vec &cond ) {
00169 int i;
00170 double res = 1.0;
00171 for ( i = n - 1;i >= 0;i-- ) {
00172
00173
00174
00175
00176
00177 res *= mpdfs ( i )->evalcond (
00178 dls ( i )->get_val ( val ),
00179 dls ( i )->get_cond ( val, cond )
00180 );
00181 }
00182 return res;
00183 }
00184 vec samplecond ( const vec &cond, double &ll ) {
00186 vec smp= std::numeric_limits<double>::infinity() * ones ( rv.count() );
00187 vec smpi;
00188 ll = 0;
00189
00190 for ( int i = ( n - 1 );i >= 0;i-- ) {
00191 if ( mpdfs ( i )->_rvc().count() ) {
00192 mpdfs ( i )->condition ( dls ( i )->get_cond ( smp ,cond ) );
00193 }
00194 smpi = epdfs ( i )->sample();
00195
00196 dls ( i )->fill_val ( smp, smpi );
00197
00198 ll+=epdfs ( i )->evalpdflog ( smpi );
00199 }
00200 return smp;
00201 }
00202 mat samplecond ( const vec &cond, vec &ll, int N ) {
00203 mat Smp ( rv.count(),N );
00204 for ( int i=0;i<N;i++ ) {Smp.set_col ( i,samplecond ( cond,ll ( i ) ) );}
00205 return Smp;
00206 }
00207
00208 ~mprod() {};
00209 };
00210
00212 class eprod: public epdf {
00213 protected:
00215 Array<const epdf*> epdfs;
00217 Array<datalink_e2e*> dls;
00218 public:
00219 eprod ( const Array<const epdf*> epdfs0 ) : epdf ( RV() ),epdfs ( epdfs0 ),dls ( epdfs.length() ) {
00220 bool independent=true;
00221 for ( int i=0;i<epdfs.length();i++ ) {
00222 independent=rv.add ( epdfs ( i )->_rv() );
00223 it_assert_debug ( independent==true, "eprod:: given components are not independent ." );
00224 }
00225 for ( int i=0;i<epdfs.length();i++ ) {
00226 dls ( i ) = new datalink_e2e ( epdfs ( i )->_rv() , rv );
00227 }
00228 }
00229
00230 vec mean() const {
00231 vec tmp ( rv.count() );
00232 for ( int i=0;i<epdfs.length();i++ ) {
00233 vec pom = epdfs ( i )->mean();
00234 dls ( i )->fill_val ( tmp, pom );
00235 }
00236 return tmp;
00237 }
00238 vec sample() const {
00239 vec tmp ( rv.count() );
00240 for ( int i=0;i<epdfs.length();i++ ) {
00241 vec pom = epdfs ( i )->sample();
00242 dls ( i )->fill_val ( tmp, pom );
00243 }
00244 return tmp;
00245 }
00246 double evalpdflog ( const vec &val ) const {
00247 double tmp=0;
00248 for ( int i=0;i<epdfs.length();i++ ) {
00249 tmp+=epdfs ( i )->evalpdflog ( dls ( i )->get_val ( val ) );
00250 }
00251 return tmp;
00252 }
00254 const epdf* operator () ( int i ) const {it_assert_debug ( i<epdfs.length(),"wrong index" );return epdfs ( i );}
00255
00257 ~eprod() {for ( int i=0;i<epdfs.length();i++ ) {delete dls ( i );}}
00258 };
00259
00260
00264 class mmix : public mpdf {
00265 protected:
00267 Array<mpdf*> Coms;
00269 emix Epdf;
00270 public:
00272 mmix ( RV &rv, RV &rvc ) : mpdf ( rv, rvc ), Epdf ( rv ) {ep = &Epdf;};
00274 void set_parameters ( const vec &w, const Array<mpdf*> &Coms ) {
00275 Array<epdf*> Eps ( Coms.length() );
00276
00277 for ( int i = 0;i < Coms.length();i++ ) {
00278 Eps ( i ) = & ( Coms ( i )->_epdf() );
00279 }
00280 Epdf.set_parameters ( w, Eps );
00281 };
00282
00283 void condition ( const vec &cond ) {
00284 for ( int i = 0;i < Coms.length();i++ ) {Coms ( i )->condition ( cond );}
00285 };
00286 };
00287
00288 #endif //MX_H