00001
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015
00016
00017 #include "../estim/mixtures.h"
00018 #include "discrete.h"
00019
00020 namespace bdm {
00021 using std::string;
00022
//! Merging methods understood by merger_base (selected via set_method() or from_setting()).
enum MERGER_METHOD {
	ARITHMETIC = 1, //!< selected by method = "arithmetic"
	GEOMETRIC,      //!< selected by method = "geometric"  (value 2)
	LOGNORMAL       //!< selected by method = "lognormal"  (value 3)
};
00025
00045 class merger_base : public epdf {
00046 protected:
00048 Array<shared_ptr<mpdf> > mpdfs;
00049
00051 Array<datalink_m2e*> dls;
00052
00054 Array<RV> rvzs;
00055
00057 Array<datalink_m2e*> zdls;
00058
00060 int Npoints;
00061
00063 int Nsources;
00064
00066 MERGER_METHOD METHOD;
00068 static const MERGER_METHOD DFLT_METHOD;
00069
00071 double beta;
00073 static const double DFLT_beta;
00074
00076 eEmp eSmp;
00077
00079 bool DBG;
00080
00082 it_file* dbg_file;
00083 public:
00086
00088 merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
00089 }
00090
00092 merger_base ( const Array<shared_ptr<mpdf> > &S );
00093
00095 void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
00096 mpdfs = Sources;
00097 Nsources = mpdfs.length();
00098
00099 dls.set_size ( Sources.length() );
00100 rvzs.set_size ( Sources.length() );
00101 zdls.set_size ( Sources.length() );
00102
00103 rv = get_composite_rv ( mpdfs, false );
00104
00105 RV rvc;
00106
00107 for ( int i = 0; i < mpdfs.length(); i++ ) {
00108 RV rvx = mpdfs ( i )->_rvc().subt ( rv );
00109 rvc.add ( rvx );
00110 }
00111
00112
00113 rv.add ( rvc );
00114
00115 dim = rv._dsize();
00116
00117
00118 RV xytmp;
00119 for ( int i = 0; i < mpdfs.length(); i++ ) {
00120
00121 dls ( i ) = new datalink_m2e;
00122 dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
00123
00124
00125 xytmp = mpdfs ( i )->_rv();
00126 xytmp.add ( mpdfs ( i )->_rvc() );
00127
00128 rvzs ( i ) = rv.subt ( xytmp );
00129
00130 zdls ( i ) = new datalink_m2e;
00131 zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
00132 };
00133 }
00135 void set_support ( rectangular_support &Sup) {
00136 Npoints = Sup.points();
00137 eSmp.set_parameters ( Npoints, false );
00138 Array<vec> &samples = eSmp._samples();
00139 eSmp._w() = ones ( Npoints ) / Npoints;
00140
00141 samples(0)=Sup.first_vec();
00142 for (int j=1; j < Npoints; j++ ) {
00143 samples ( j ) = Sup.next_vec();
00144 }
00145 }
00147 void set_support ( discrete_support &Sup) {
00148 Npoints = Sup.points();
00149 eSmp.set_parameters(Sup._Spoints());
00150 }
00152 void set_debug_file ( const string fname ) {
00153 if ( DBG ) delete dbg_file;
00154 dbg_file = new it_file ( fname );
00155 if ( dbg_file ) DBG = true;
00156 }
00158 void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
00159 METHOD = MTH;
00160 beta = beta0;
00161 }
00163 void set_support ( const epdf &overall, int N ) {
00164 eSmp.set_statistics ( overall, N );
00165 Npoints = N;
00166 }
00167
00169 virtual ~merger_base() {
00170 for ( int i = 0; i < Nsources; i++ ) {
00171 delete dls ( i );
00172 delete zdls ( i );
00173 }
00174 if ( DBG ) delete dbg_file;
00175 };
00177
00180
00182 virtual void merge () {
00183 validate();
00184
00185
00186 bool OK = true;
00187 for ( int i = 0; i < mpdfs.length(); i++ ) {
00188 OK &= ( rvzs ( i )._dsize() == 0 );
00189 OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 );
00190 }
00191
00192 if ( OK ) {
00193 mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
00194
00195 vec emptyvec ( 0 );
00196 for ( int i = 0; i < mpdfs.length(); i++ ) {
00197 for ( int j = 0; j < eSmp._w().length(); j++ ) {
00198 lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
00199 }
00200 }
00201
00202 vec w_nn = merge_points ( lW );
00203 vec wtmp = exp ( w_nn - max ( w_nn ) );
00204
00205 eSmp._w() = wtmp / sum ( wtmp );
00206 } else {
00207 bdm_error ( "Sources are not compatible - use merger_mix" );
00208 }
00209 };
00210
00211
00213 vec merge_points ( mat &lW );
00214
00215
00218 vec mean() const {
00219 const Vec<double> &w = eSmp._w();
00220 const Array<vec> &S = eSmp._samples();
00221 vec tmp = zeros ( dim );
00222 for ( int i = 0; i < Npoints; i++ ) {
00223 tmp += w ( i ) * S ( i );
00224 }
00225 return tmp;
00226 }
00227 mat covariance() const {
00228 const vec &w = eSmp._w();
00229 const Array<vec> &S = eSmp._samples();
00230
00231 vec mea = mean();
00232
00233
00234
00235 mat Tmp = zeros ( dim, dim );
00236 for ( int i = 0; i < Npoints; i++ ) {
00237 Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
00238 }
00239 return Tmp - outer_product ( mea, mea );
00240 }
00241 vec variance() const {
00242 const vec &w = eSmp._w();
00243 const Array<vec> &S = eSmp._samples();
00244
00245 vec tmp = zeros ( dim );
00246 for ( int i = 0; i < Nsources; i++ ) {
00247 tmp += w ( i ) * pow ( S ( i ), 2 );
00248 }
00249 return tmp - pow ( mean(), 2 );
00250 }
00252
00255
00257 eEmp& _Smp() {
00258 return eSmp;
00259 }
00260
00262 void from_setting ( const Setting& set ) {
00263
00264
00265 string meth_str;
00266 UI::get<string> ( meth_str, set, "method", UI::compulsory );
00267 if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
00268 set_method ( ARITHMETIC );
00269 else {
00270 if ( !strcmp ( meth_str.c_str(), "geometric" ) )
00271 set_method ( GEOMETRIC );
00272 else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
00273 set_method ( LOGNORMAL );
00274 set.lookupValue ( "beta", beta );
00275 }
00276 }
00277 string dbg_file;
00278 if ( UI::get ( dbg_file, set, "dbg_file" ) )
00279 set_debug_file ( dbg_file );
00280
00281 }
00282
00283 void validate() {
00284 bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
00285 bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
00286 bdm_assert ( isnamed(), "mergers must be named" );
00287 }
00289 };
00290 UIREGISTER ( merger_base );
00291 SHAREDPTR ( merger_base );
00292
00294 class merger_mix : public merger_base {
00295 protected:
00297 MixEF Mix;
00299 int Ncoms;
00301 double effss_coef;
00303 int stop_niter;
00304
00306 static const int DFLT_Ncoms;
00308 static const double DFLT_effss_coef;
00309
00310 public:
00313 merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
00314
00315 merger_mix ( const Array<shared_ptr<mpdf> > &S ):
00316 Ncoms(0), effss_coef(0), stop_niter(0) {
00317 set_sources ( S );
00318 }
00319
00321 void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
00322 merger_base::set_sources ( S );
00323 Nsources = S.length();
00324 }
00325
00327 void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
00328 Ncoms = Ncoms0;
00329 effss_coef = effss_coef0;
00330 }
00332
00335
00337 void merge ();
00338
00340 vec sample () const {
00341 return Mix.posterior().sample();
00342 }
00344 double evallog ( const vec &dt ) const {
00345 vec dtf = ones ( dt.length() + 1 );
00346 dtf.set_subvector ( 0, dt );
00347 return Mix.logpred ( dtf );
00348 }
00350
00354 MixEF& _Mix() {
00355 return Mix;
00356 }
00358 emix* proposal() {
00359 emix* tmp = Mix.epredictor();
00360 tmp->set_rv ( rv );
00361 return tmp;
00362 }
00364 void from_setting ( const Setting& set ) {
00365 merger_base::from_setting ( set );
00366 set.lookupValue ( "ncoms", Ncoms );
00367 set.lookupValue ( "effss_coef", effss_coef );
00368 set.lookupValue ( "stop_niter", stop_niter );
00369 }
00370
00372
00373 };
00374 UIREGISTER ( merger_mix );
00375 SHAREDPTR ( merger_mix );
00376
00377 }
00378
#endif // MERGER_H