00001
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015
00016
00017 #include "../estim/mixtures.h"
00018
00019 namespace bdm {
00020 using std::string;
00021
//! Identifiers of the merging methods that can be selected via merger_base::set_method()
enum MERGER_METHOD {
	ARITHMETIC = 1,
	GEOMETRIC = 2,
	LOGNORMAL = 3
};
00024
00044 class merger_base : public epdf {
00045 protected:
00047 Array<shared_ptr<mpdf> > mpdfs;
00048
00050 Array<datalink_m2e*> dls;
00051
00053 Array<RV> rvzs;
00054
00056 Array<datalink_m2e*> zdls;
00057
00059 int Npoints;
00060
00062 int Nsources;
00063
00065 MERGER_METHOD METHOD;
00067 static const MERGER_METHOD DFLT_METHOD;
00068
00070 double beta;
00072 static const double DFLT_beta;
00073
00075 eEmp eSmp;
00076
00078 bool DBG;
00079
00081 it_file* dbg_file;
00082 public:
00085
00087 merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
00088 }
00089
00091 merger_base ( const Array<shared_ptr<mpdf> > &S );
00092
00094 void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
00095 mpdfs = Sources;
00096 Nsources = mpdfs.length();
00097
00098 dls.set_size ( Sources.length() );
00099 rvzs.set_size ( Sources.length() );
00100 zdls.set_size ( Sources.length() );
00101
00102 rv = get_composite_rv ( mpdfs, false );
00103
00104 RV rvc;
00105
00106 for ( int i = 0; i < mpdfs.length(); i++ ) {
00107 RV rvx = mpdfs ( i )->_rvc().subt ( rv );
00108 rvc.add ( rvx );
00109 }
00110
00111
00112 rv.add ( rvc );
00113
00114 dim = rv._dsize();
00115
00116
00117 RV xytmp;
00118 for ( int i = 0; i < mpdfs.length(); i++ ) {
00119
00120 dls ( i ) = new datalink_m2e;
00121 dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
00122
00123
00124 xytmp = mpdfs ( i )->_rv();
00125 xytmp.add ( mpdfs ( i )->_rvc() );
00126
00127 rvzs ( i ) = rv.subt ( xytmp );
00128
00129 zdls ( i ) = new datalink_m2e;
00130 zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
00131 };
00132 }
00134 void set_support ( const Array<vec> &XYZ, const int dimsize ) {
00135 set_support ( XYZ, dimsize*ones_i ( XYZ.length() ) );
00136 }
00138 void set_support ( const Array<vec> &XYZ, const ivec &gridsize ) {
00139 int dim = XYZ.length();
00140 Npoints = prod ( gridsize );
00141 eSmp.set_parameters ( Npoints, false );
00142 Array<vec> &samples = eSmp._samples();
00143 eSmp._w() = ones ( Npoints ) / Npoints;
00144
00145 ivec ind = zeros_i ( dim );
00146 vec smpi ( dim );
00147 vec steps = zeros ( dim );
00148
00149 for ( int j = 0; j < dim; j++ ) {
00150 smpi ( j ) = XYZ ( j ) ( 0 );
00151 it_assert ( gridsize ( j ) != 0.0, "Zeros in gridsize!" );
00152 steps ( j ) = ( XYZ ( j ) ( 1 ) - smpi ( j ) ) / gridsize ( j );
00153 }
00154
00155 for ( int i = 0; i < Npoints; i++ ) {
00156
00157 samples ( i ) = smpi;
00158
00159 for ( int j = 0; j < dim; j++ ) {
00160 if ( ind ( j ) == gridsize ( j ) - 1 ) {
00161 ind ( j ) = 0;
00162 smpi ( j ) = XYZ ( j ) ( 0 );
00163
00164 if ( i < Npoints - 1 ) {
00165 ind ( j + 1 ) ++;
00166 smpi ( j + 1 ) += steps ( j + 1 );
00167 break;
00168 }
00169
00170 } else {
00171 ind ( j ) ++;
00172 smpi ( j ) += steps ( j );
00173 break;
00174 }
00175 }
00176 }
00177 }
00179 void set_debug_file ( const string fname ) {
00180 if ( DBG ) delete dbg_file;
00181 dbg_file = new it_file ( fname );
00182 if ( dbg_file ) DBG = true;
00183 }
00185 void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
00186 METHOD = MTH;
00187 beta = beta0;
00188 }
00190 void set_support ( const epdf &overall, int N ) {
00191 eSmp.set_statistics ( overall, N );
00192 Npoints = N;
00193 }
00194
00196 virtual ~merger_base() {
00197 for ( int i = 0; i < Nsources; i++ ) {
00198 delete dls ( i );
00199 delete zdls ( i );
00200 }
00201 if ( DBG ) delete dbg_file;
00202 };
00204
00207
00209 virtual void merge () {
00210 validate();
00211
00212
00213 bool OK = true;
00214 for ( int i = 0; i < mpdfs.length(); i++ ) {
00215 OK &= ( rvzs ( i )._dsize() == 0 );
00216 OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 );
00217 }
00218
00219 if ( OK ) {
00220 mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
00221
00222 vec emptyvec ( 0 );
00223 for ( int i = 0; i < mpdfs.length(); i++ ) {
00224 for ( int j = 0; j < eSmp._w().length(); j++ ) {
00225 lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
00226 }
00227 }
00228
00229 vec w_nn = merge_points ( lW );
00230 vec wtmp = exp ( w_nn - max ( w_nn ) );
00231
00232 eSmp._w() = wtmp / sum ( wtmp );
00233 } else {
00234 it_error ( "Sources are not compatible - use merger_mix" );
00235 }
00236 };
00237
00238
00240 vec merge_points ( mat &lW );
00241
00242
00245 vec mean() const {
00246 const Vec<double> &w = eSmp._w();
00247 const Array<vec> &S = eSmp._samples();
00248 vec tmp = zeros ( dim );
00249 for ( int i = 0; i < Npoints; i++ ) {
00250 tmp += w ( i ) * S ( i );
00251 }
00252 return tmp;
00253 }
00254 mat covariance() const {
00255 const vec &w = eSmp._w();
00256 const Array<vec> &S = eSmp._samples();
00257
00258 vec mea = mean();
00259
00260
00261
00262 mat Tmp = zeros ( dim, dim );
00263 for ( int i = 0; i < Npoints; i++ ) {
00264 Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
00265 }
00266 return Tmp - outer_product ( mea, mea );
00267 }
00268 vec variance() const {
00269 const vec &w = eSmp._w();
00270 const Array<vec> &S = eSmp._samples();
00271
00272 vec tmp = zeros ( dim );
00273 for ( int i = 0; i < Nsources; i++ ) {
00274 tmp += w ( i ) * pow ( S ( i ), 2 );
00275 }
00276 return tmp - pow ( mean(), 2 );
00277 }
00279
00282
00284 eEmp& _Smp() {
00285 return eSmp;
00286 }
00287
00289 void from_setting ( const Setting& set ) {
00290
00291
00292 string meth_str;
00293 UI::get<string> ( meth_str, set, "method", UI::compulsory );
00294 if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
00295 set_method ( ARITHMETIC );
00296 else {
00297 if ( !strcmp ( meth_str.c_str(), "geometric" ) )
00298 set_method ( GEOMETRIC );
00299 else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
00300 set_method ( LOGNORMAL );
00301 set.lookupValue ( "beta", beta );
00302 }
00303 }
00304 string dbg_file;
00305 if ( UI::get ( dbg_file, set, "dbg_file" ) )
00306 set_debug_file ( dbg_file );
00307
00308 }
00309
00310 void validate() {
00311 it_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
00312 it_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
00313 it_assert ( isnamed(), "mergers must be named" );
00314 }
00316 };
00317 UIREGISTER ( merger_base );
00318 SHAREDPTR ( merger_base );
00319
00321 class merger_mix : public merger_base {
00322 protected:
00324 MixEF Mix;
00326 int Ncoms;
00328 double effss_coef;
00330 int stop_niter;
00331
00333 static const int DFLT_Ncoms;
00335 static const double DFLT_effss_coef;
00336
00337 public:
00340 merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
00341
00342 merger_mix ( const Array<shared_ptr<mpdf> > &S ):
00343 Ncoms(0), effss_coef(0), stop_niter(0) {
00344 set_sources ( S );
00345 }
00346
00348 void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
00349 merger_base::set_sources ( S );
00350 Nsources = S.length();
00351 }
00352
00354 void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
00355 Ncoms = Ncoms0;
00356 effss_coef = effss_coef0;
00357 }
00359
00362
00364 void merge ();
00365
00367 vec sample () const {
00368 return Mix.posterior().sample();
00369 }
00371 double evallog ( const vec &dt ) const {
00372 vec dtf = ones ( dt.length() + 1 );
00373 dtf.set_subvector ( 0, dt );
00374 return Mix.logpred ( dtf );
00375 }
00377
00381 MixEF& _Mix() {
00382 return Mix;
00383 }
00385 emix* proposal() {
00386 emix* tmp = Mix.epredictor();
00387 tmp->set_rv ( rv );
00388 return tmp;
00389 }
00391 void from_setting ( const Setting& set ) {
00392 merger_base::from_setting ( set );
00393 set.lookupValue ( "ncoms", Ncoms );
00394 set.lookupValue ( "effss_coef", effss_coef );
00395 set.lookupValue ( "stop_niter", stop_niter );
00396 }
00397
00399
00400 };
00401 UIREGISTER ( merger_mix );
00402 SHAREDPTR ( merger_mix );
00403
00404 }
00405
00406 #endif // MERGER_H