00001 
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015 
00016 
00017 #include "../estim/mixtures.h"
00018 
00019 namespace bdm {
00020 using std::string;
00021 
//! Selects how the per-source densities are combined by merger_base
//! (the combining code itself is defined outside this header).
enum MERGER_METHOD {
	ARITHMETIC = 1, //!< "arithmetic" in configuration files
	GEOMETRIC  = 2, //!< "geometric" in configuration files
	LOGNORMAL  = 3  //!< "lognormal" in configuration files; additionally uses merger_base::beta
};
00024 
00044 class merger_base : public epdf {
00045 protected:
00047         Array<shared_ptr<mpdf> > mpdfs;
00048 
00050         Array<datalink_m2e*> dls;
00051 
00053         Array<RV> rvzs;
00054 
00056         Array<datalink_m2e*> zdls;
00057 
00059         int Npoints;
00060 
00062         int Nsources;
00063 
00065         MERGER_METHOD METHOD;
00067         static const MERGER_METHOD DFLT_METHOD;
00068 
00070         double beta;
00072         static const double DFLT_beta;
00073 
00075         eEmp eSmp;
00076 
00078         bool DBG;
00079 
00081         it_file* dbg_file;
00082 public:
00085 
00087         merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
00088         }
00089 
00091         merger_base ( const Array<shared_ptr<mpdf> > &S );
00092 
00094         void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
00095                 mpdfs = Sources;
00096                 Nsources = mpdfs.length();
00097                 
00098                 dls.set_size ( Sources.length() );
00099                 rvzs.set_size ( Sources.length() );
00100                 zdls.set_size ( Sources.length() );
00101 
00102                 rv = get_composite_rv ( mpdfs,  false );
00103 
00104                 RV rvc;
00105                 
00106                 for ( int i = 0; i < mpdfs.length(); i++ ) {
00107                         RV rvx = mpdfs ( i )->_rvc().subt ( rv );
00108                         rvc.add ( rvx ); 
00109                 }
00110 
00111                 
00112                 rv.add ( rvc );
00113                 
00114                 dim = rv._dsize();
00115 
00116                 
00117                 RV xytmp;
00118                 for ( int i = 0; i < mpdfs.length(); i++ ) {
00119                         
00120                         dls ( i ) = new datalink_m2e;
00121                         dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
00122 
00123                         
00124                         xytmp = mpdfs ( i )->_rv();
00125                         xytmp.add ( mpdfs ( i )->_rvc() );
00126                         
00127                         rvzs ( i ) = rv.subt ( xytmp );
00128                         
00129                         zdls ( i ) = new datalink_m2e;
00130                         zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
00131                 };
00132         }
00134         void set_support ( const Array<vec> &XYZ, const int dimsize ) {
00135                 set_support ( XYZ, dimsize*ones_i ( XYZ.length() ) );
00136         }
00138         void set_support ( const Array<vec> &XYZ, const ivec &gridsize ) {
00139                 int dim = XYZ.length();  
00140                 Npoints = prod ( gridsize );
00141                 eSmp.set_parameters ( Npoints, false );
00142                 Array<vec> &samples = eSmp._samples();
00143                 eSmp._w() = ones ( Npoints ) / Npoints; 
00144                 
00145                 ivec ind = zeros_i ( dim );    
00146                 vec smpi ( dim );          
00147                 vec steps = zeros ( dim );        
00148                 
00149                 for ( int j = 0; j < dim; j++ ) {
00150                         smpi ( j ) = XYZ ( j ) ( 0 ); 
00151                         it_assert ( gridsize ( j ) != 0.0, "Zeros in gridsize!" );
00152                         steps ( j ) = ( XYZ ( j ) ( 1 ) - smpi ( j ) ) / gridsize ( j );
00153                 }
00154                 
00155                 for ( int i = 0; i < Npoints; i++ ) {
00156                         
00157                         samples ( i ) = smpi;
00158                         
00159                         for ( int j = 0; j < dim; j++ ) {
00160                                 if ( ind ( j ) == gridsize ( j ) - 1 ) { 
00161                                         ind ( j ) = 0; 
00162                                         smpi ( j ) = XYZ ( j ) ( 0 );
00163 
00164                                         if ( i < Npoints - 1 ) {
00165                                                 ind ( j + 1 ) ++; 
00166                                                 smpi ( j + 1 ) += steps ( j + 1 );
00167                                                 break;
00168                                         }
00169 
00170                                 } else {
00171                                         ind ( j ) ++;
00172                                         smpi ( j ) += steps ( j );
00173                                         break;
00174                                 }
00175                         }
00176                 }
00177         }
00179         void set_debug_file ( const string fname ) {
00180                 if ( DBG ) delete dbg_file;
00181                 dbg_file = new it_file ( fname );
00182                 if ( dbg_file ) DBG = true;
00183         }
00185         void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
00186                 METHOD = MTH;
00187                 beta = beta0;
00188         }
00190         void set_support ( const epdf &overall, int N ) {
00191                 eSmp.set_statistics ( overall, N );
00192                 Npoints = N;
00193         }
00194 
00196         virtual ~merger_base() {
00197                 for ( int i = 0; i < Nsources; i++ ) {
00198                         delete dls ( i );
00199                         delete zdls ( i );
00200                 }
00201                 if ( DBG ) delete dbg_file;
00202         };
00204 
00207 
00209         virtual void merge () {
00210                 validate();
00211 
00212                 
00213                 bool OK = true;
00214                 for ( int i = 0; i < mpdfs.length(); i++ ) {
00215                         OK &= ( rvzs ( i )._dsize() == 0 ); 
00216                         OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); 
00217                 }
00218 
00219                 if ( OK ) {
00220                         mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
00221 
00222                         vec emptyvec ( 0 );
00223                         for ( int i = 0; i < mpdfs.length(); i++ ) {
00224                                 for ( int j = 0; j < eSmp._w().length(); j++ ) {
00225                                         lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
00226                                 }
00227                         }
00228 
00229                         vec w_nn = merge_points ( lW );
00230                         vec wtmp = exp ( w_nn - max ( w_nn ) );
00231                         
00232                         eSmp._w() = wtmp / sum ( wtmp );
00233                 } else {
00234                         it_error ( "Sources are not compatible - use merger_mix" );
00235                 }
00236         };
00237 
00238 
00240         vec merge_points ( mat &lW );
00241 
00242 
00245         vec mean() const {
00246                 const Vec<double> &w = eSmp._w();
00247                 const Array<vec> &S = eSmp._samples();
00248                 vec tmp = zeros ( dim );
00249                 for ( int i = 0; i < Npoints; i++ ) {
00250                         tmp += w ( i ) * S ( i );
00251                 }
00252                 return tmp;
00253         }
00254         mat covariance() const {
00255                 const vec &w = eSmp._w();
00256                 const Array<vec> &S = eSmp._samples();
00257 
00258                 vec mea = mean();
00259 
00260 
00261 
00262                 mat Tmp = zeros ( dim, dim );
00263                 for ( int i = 0; i < Npoints; i++ ) {
00264                         Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
00265                 }
00266                 return Tmp - outer_product ( mea, mea );
00267         }
00268         vec variance() const {
00269                 const vec &w = eSmp._w();
00270                 const Array<vec> &S = eSmp._samples();
00271 
00272                 vec tmp = zeros ( dim );
00273                 for ( int i = 0; i < Nsources; i++ ) {
00274                         tmp += w ( i ) * pow ( S ( i ), 2 );
00275                 }
00276                 return tmp - pow ( mean(), 2 );
00277         }
00279 
00282 
00284         eEmp& _Smp() {
00285                 return eSmp;
00286         }
00287 
00289         void from_setting ( const Setting& set ) {
00290                 
00291                 
00292                 string meth_str;
00293                 UI::get<string> ( meth_str, set, "method", UI::compulsory );
00294                 if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
00295                         set_method ( ARITHMETIC );
00296                 else {
00297                         if ( !strcmp ( meth_str.c_str(), "geometric" ) )
00298                                 set_method ( GEOMETRIC );
00299                         else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
00300                                 set_method ( LOGNORMAL );
00301                                 set.lookupValue ( "beta", beta );
00302                         }
00303                 }
00304                 string dbg_file;
00305                 if ( UI::get ( dbg_file, set, "dbg_file" ) )
00306                         set_debug_file ( dbg_file );
00307                 
00308         }
00309 
00310         void validate() {
00311                 it_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
00312                 it_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
00313                 it_assert ( isnamed(), "mergers must be named" );
00314         }
00316 };
00317 UIREGISTER ( merger_base );
00318 SHAREDPTR ( merger_base );
00319 
00321 class merger_mix : public merger_base {
00322 protected:
00324         MixEF Mix;
00326         int Ncoms;
00328         double effss_coef;
00330         int stop_niter;
00331 
00333         static const int DFLT_Ncoms;
00335         static const double DFLT_effss_coef;
00336 
00337 public:
00340         merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
00341 
00342         merger_mix ( const Array<shared_ptr<mpdf> > &S ):
00343                 Ncoms(0), effss_coef(0), stop_niter(0) {
00344                 set_sources ( S );
00345         }
00346 
00348         void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
00349                 merger_base::set_sources ( S );
00350                 Nsources = S.length();
00351         }
00352 
00354         void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
00355                 Ncoms = Ncoms0;
00356                 effss_coef = effss_coef0;
00357         }
00359 
00362 
00364         void merge ();
00365 
00367         vec sample () const {
00368                 return Mix.posterior().sample();
00369         }
00371         double evallog ( const vec &dt ) const {
00372                 vec dtf = ones ( dt.length() + 1 );
00373                 dtf.set_subvector ( 0, dt );
00374                 return Mix.logpred ( dtf );
00375         }
00377 
00381         MixEF& _Mix() {
00382                 return Mix;
00383         }
00385         emix* proposal() {
00386                 emix* tmp = Mix.epredictor();
00387                 tmp->set_rv ( rv );
00388                 return tmp;
00389         }
00391         void from_setting ( const Setting& set ) {
00392                 merger_base::from_setting ( set );
00393                 set.lookupValue ( "ncoms", Ncoms );
00394                 set.lookupValue ( "effss_coef", effss_coef );
00395                 set.lookupValue ( "stop_niter", stop_niter );
00396         }
00397 
00399 
00400 };
00401 UIREGISTER ( merger_mix );
00402 SHAREDPTR ( merger_mix );
00403 
00404 }
00405 
00406 #endif // MERGER_H