00001 
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015 
00016 
00017 #include "../estim/mixtures.h"
00018 #include "discrete.h"
00019 
00020 namespace bdm {
00021 using std::string;
00022 
//! Available strategies for merging the sources; selected via merger_base::set_method().
enum MERGER_METHOD {
	ARITHMETIC = 1, //!< arithmetic merging
	GEOMETRIC = 2,  //!< geometric merging
	LOGNORMAL = 3   //!< log-normal merging (uses the beta parameter)
};
00025 
00045 class merger_base : public epdf {
00046 protected:
00048         Array<shared_ptr<mpdf> > mpdfs;
00049 
00051         Array<datalink_m2e*> dls;
00052 
00054         Array<RV> rvzs;
00055 
00057         Array<datalink_m2e*> zdls;
00058 
00060         int Npoints;
00061 
00063         int Nsources;
00064 
00066         MERGER_METHOD METHOD;
00068         static const MERGER_METHOD DFLT_METHOD;
00069 
00071         double beta;
00073         static const double DFLT_beta;
00074 
00076         eEmp eSmp;
00077 
00079         bool DBG;
00080 
00082         it_file* dbg_file;
00083 public:
00086 
00088         merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
00089         }
00090 
00092         merger_base ( const Array<shared_ptr<mpdf> > &S );
00093 
00095         void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
00096                 mpdfs = Sources;
00097                 Nsources = mpdfs.length();
00098                 
00099                 dls.set_size ( Sources.length() );
00100                 rvzs.set_size ( Sources.length() );
00101                 zdls.set_size ( Sources.length() );
00102 
00103                 rv = get_composite_rv ( mpdfs,  false );
00104 
00105                 RV rvc;
00106                 
00107                 for ( int i = 0; i < mpdfs.length(); i++ ) {
00108                         RV rvx = mpdfs ( i )->_rvc().subt ( rv );
00109                         rvc.add ( rvx ); 
00110                 }
00111 
00112                 
00113                 rv.add ( rvc );
00114                 
00115                 dim = rv._dsize();
00116 
00117                 
00118                 RV xytmp;
00119                 for ( int i = 0; i < mpdfs.length(); i++ ) {
00120                         
00121                         dls ( i ) = new datalink_m2e;
00122                         dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
00123 
00124                         
00125                         xytmp = mpdfs ( i )->_rv();
00126                         xytmp.add ( mpdfs ( i )->_rvc() );
00127                         
00128                         rvzs ( i ) = rv.subt ( xytmp );
00129                         
00130                         zdls ( i ) = new datalink_m2e;
00131                         zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
00132                 };
00133         }
00135         void set_support ( rectangular_support &Sup) {
00136                 Npoints = Sup.points();
00137                 eSmp.set_parameters ( Npoints, false );
00138                 Array<vec> &samples = eSmp._samples();
00139                 eSmp._w() = ones ( Npoints ) / Npoints; 
00140                 
00141                 samples(0)=Sup.first_vec();
00142                 for (int j=1; j < Npoints; j++ ) {
00143                         samples ( j ) = Sup.next_vec();
00144                 }
00145         }
00147         void set_support ( discrete_support &Sup) {
00148                 Npoints = Sup.points();
00149                 eSmp.set_parameters(Sup._Spoints());
00150         }
00152         void set_debug_file ( const string fname ) {
00153                 if ( DBG ) delete dbg_file;
00154                 dbg_file = new it_file ( fname );
00155                 if ( dbg_file ) DBG = true;
00156         }
00158         void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
00159                 METHOD = MTH;
00160                 beta = beta0;
00161         }
00163         void set_support ( const epdf &overall, int N ) {
00164                 eSmp.set_statistics ( overall, N );
00165                 Npoints = N;
00166         }
00167 
00169         virtual ~merger_base() {
00170                 for ( int i = 0; i < Nsources; i++ ) {
00171                         delete dls ( i );
00172                         delete zdls ( i );
00173                 }
00174                 if ( DBG ) delete dbg_file;
00175         };
00177 
00180 
00182         virtual void merge () {
00183                 validate();
00184 
00185                 
00186                 bool OK = true;
00187                 for ( int i = 0; i < mpdfs.length(); i++ ) {
00188                         OK &= ( rvzs ( i )._dsize() == 0 ); 
00189                         OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); 
00190                 }
00191 
00192                 if ( OK ) {
00193                         mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
00194 
00195                         vec emptyvec ( 0 );
00196                         for ( int i = 0; i < mpdfs.length(); i++ ) {
00197                                 for ( int j = 0; j < eSmp._w().length(); j++ ) {
00198                                         lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
00199                                 }
00200                         }
00201 
00202                         vec w_nn = merge_points ( lW );
00203                         vec wtmp = exp ( w_nn - max ( w_nn ) );
00204                         
00205                         eSmp._w() = wtmp / sum ( wtmp );
00206                 } else {
00207                         bdm_error ( "Sources are not compatible - use merger_mix" );
00208                 }
00209         };
00210 
00211 
00213         vec merge_points ( mat &lW );
00214 
00215 
00218         vec mean() const {
00219                 const Vec<double> &w = eSmp._w();
00220                 const Array<vec> &S = eSmp._samples();
00221                 vec tmp = zeros ( dim );
00222                 for ( int i = 0; i < Npoints; i++ ) {
00223                         tmp += w ( i ) * S ( i );
00224                 }
00225                 return tmp;
00226         }
00227         mat covariance() const {
00228                 const vec &w = eSmp._w();
00229                 const Array<vec> &S = eSmp._samples();
00230 
00231                 vec mea = mean();
00232 
00233 
00234 
00235                 mat Tmp = zeros ( dim, dim );
00236                 for ( int i = 0; i < Npoints; i++ ) {
00237                         Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
00238                 }
00239                 return Tmp - outer_product ( mea, mea );
00240         }
00241         vec variance() const {
00242                 const vec &w = eSmp._w();
00243                 const Array<vec> &S = eSmp._samples();
00244 
00245                 vec tmp = zeros ( dim );
00246                 for ( int i = 0; i < Nsources; i++ ) {
00247                         tmp += w ( i ) * pow ( S ( i ), 2 );
00248                 }
00249                 return tmp - pow ( mean(), 2 );
00250         }
00252 
00255 
00257         eEmp& _Smp() {
00258                 return eSmp;
00259         }
00260 
00262         void from_setting ( const Setting& set ) {
00263                 
00264                 
00265                 string meth_str;
00266                 UI::get<string> ( meth_str, set, "method", UI::compulsory );
00267                 if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
00268                         set_method ( ARITHMETIC );
00269                 else {
00270                         if ( !strcmp ( meth_str.c_str(), "geometric" ) )
00271                                 set_method ( GEOMETRIC );
00272                         else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
00273                                 set_method ( LOGNORMAL );
00274                                 set.lookupValue ( "beta", beta );
00275                         }
00276                 }
00277                 string dbg_file;
00278                 if ( UI::get ( dbg_file, set, "dbg_file" ) )
00279                         set_debug_file ( dbg_file );
00280                 
00281         }
00282 
00283         void validate() {
00284                 bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
00285                 bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
00286                 bdm_assert ( isnamed(), "mergers must be named" );
00287         }
00289 };
00290 UIREGISTER ( merger_base );
00291 SHAREDPTR ( merger_base );
00292 
00294 class merger_mix : public merger_base {
00295 protected:
00297         MixEF Mix;
00299         int Ncoms;
00301         double effss_coef;
00303         int stop_niter;
00304 
00306         static const int DFLT_Ncoms;
00308         static const double DFLT_effss_coef;
00309 
00310 public:
00313         merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
00314 
00315         merger_mix ( const Array<shared_ptr<mpdf> > &S ):
00316                 Ncoms(0), effss_coef(0), stop_niter(0) {
00317                 set_sources ( S );
00318         }
00319 
00321         void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
00322                 merger_base::set_sources ( S );
00323                 Nsources = S.length();
00324         }
00325 
00327         void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
00328                 Ncoms = Ncoms0;
00329                 effss_coef = effss_coef0;
00330         }
00332 
00335 
00337         void merge ();
00338 
00340         vec sample () const {
00341                 return Mix.posterior().sample();
00342         }
00344         double evallog ( const vec &dt ) const {
00345                 vec dtf = ones ( dt.length() + 1 );
00346                 dtf.set_subvector ( 0, dt );
00347                 return Mix.logpred ( dtf );
00348         }
00350 
00354         MixEF& _Mix() {
00355                 return Mix;
00356         }
00358         emix* proposal() {
00359                 emix* tmp = Mix.epredictor();
00360                 tmp->set_rv ( rv );
00361                 return tmp;
00362         }
00364         void from_setting ( const Setting& set ) {
00365                 merger_base::from_setting ( set );
00366                 set.lookupValue ( "ncoms", Ncoms );
00367                 set.lookupValue ( "effss_coef", effss_coef );
00368                 set.lookupValue ( "stop_niter", stop_niter );
00369         }
00370 
00372 
00373 };
00374 UIREGISTER ( merger_mix );
00375 SHAREDPTR ( merger_mix );
00376 
00377 }
00378 
00379 #endif // MERGER_H