00001 
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015 
00016 
00017 #include "../estim/mixtures.h"
00018 
00019 namespace bdm
00020 {
00021 using std::string;
00022 
//! Identifiers of the merging methods selectable through merger_base::set_method().
//! The numeric values are fixed explicitly so they stay stable if entries are reordered.
enum MERGER_METHOD {
        ARITHMETIC = 1, //!< selected by the "arithmetic" method name in from_setting()
        GEOMETRIC  = 2, //!< selected by the "geometric" method name in from_setting()
        LOGNORMAL  = 3  //!< selected by the "lognormal" method name in from_setting(); the only one that reads "beta"
};
00025 
00045 class merger_base : public compositepdf, public epdf
00046 {
00047         protected:
00049                 Array<datalink_m2e*> dls;
00051                 Array<RV> rvzs;
00053                 Array<datalink_m2e*> zdls;
00055                 int Npoints;
00057                 int Nsources;
00058 
00060                 MERGER_METHOD METHOD;
00062                 static const MERGER_METHOD DFLT_METHOD;
00063                 
00065                 double beta;
00067                 static const double DFLT_beta;
00068                 
00070                 eEmp eSmp;
00071 
00073                 bool DBG;
00074 
00076                 it_file* dbg_file;
00077         public:
00080 
00082                 merger_base () : compositepdf() {DBG = false;dbg_file = NULL;}
00083 
00085                 merger_base (const Array<mpdf*> &S, bool own=false);
00086 
00088                 void set_sources (const Array<mpdf*> &Sources, bool own) {
00089                         compositepdf::set_elements (Sources,own);
00090                         Nsources=mpdfs.length();
00091                         
00092                         dls.set_size (Sources.length());
00093                         rvzs.set_size (Sources.length());
00094                         zdls.set_size (Sources.length());
00095 
00096                         rv = getrv ( false);
00097                         RV rvc; setrvc (rv, rvc);  
00098                         
00099                         rv.add (rvc);
00100                         
00101                         dim = rv._dsize();
00102 
00103                         
00104                         RV xytmp;
00105                         for (int i = 0;i < mpdfs.length();i++) {
00106                                 
00107                                 dls (i) = new datalink_m2e;
00108                                 dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
00109 
00110                                 
00111                                 xytmp = mpdfs (i)->_rv();
00112                                 xytmp.add (mpdfs (i)->_rvc());
00113                                 
00114                                 rvzs (i) = rv.subt (xytmp);
00115                                 
00116                                 zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
00117                         };
00118                 }
00120                 void set_support (const Array<vec> &XYZ, const int dimsize) {
00121                         set_support(XYZ,dimsize*ones_i(XYZ.length()));
00122                 }
00124                 void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
00125                         int dim = XYZ.length();  
00126                         Npoints = prod (gridsize); 
00127                         eSmp.set_parameters (Npoints, false);
00128                         Array<vec> &samples = eSmp._samples();
00129                         eSmp._w() = ones (Npoints) / Npoints; 
00130                         
00131                         ivec ind = zeros_i (dim);      
00132                         vec smpi (dim);            
00133                         vec steps =zeros(dim);            
00134                         
00135                         for (int j = 0; j < dim; j++) {
00136                                 smpi (j) = XYZ (j) (0);  
00137                                 it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
00138                                 steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
00139                         }
00140                         
00141                         for (int i = 0; i < Npoints; i++) {
00142                                 
00143                                 samples(i) = smpi; 
00144                                 
00145                                 for (int j = 0;j < dim;j++) {
00146                                         if (ind (j) == gridsize (j) - 1) { 
00147                                                 ind (j) = 0; 
00148                                                 smpi(j) = XYZ(j)(0);
00149                                                 
00150                                                 if (i<Npoints-1) {
00151                                                         ind (j + 1) ++; 
00152                                                         smpi(j+1) += steps(j+1);
00153                                                         break;
00154                                                 }
00155                                                 
00156                                         } else {
00157                                                 ind (j) ++; 
00158                                                 smpi(j) +=steps(j);
00159                                                 break;
00160                                         }
00161                                 }
00162                         }
00163                 }
00165                 void set_debug_file (const string fname) {
00166                         if (DBG) delete dbg_file;
00167                         dbg_file = new it_file (fname);
00168                         if (dbg_file) DBG = true;
00169                 }
00171                 void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
00172                         METHOD = MTH;
00173                         beta = beta0;
00174                 }
00176                 void set_support (const epdf &overall, int N) {
00177                         eSmp.set_statistics (&overall, N);
00178                         Npoints = N;
00179                 }
00180 
00182                 virtual ~merger_base() {
00183                         for (int i = 0; i < Nsources; i++) {
00184                                 delete dls (i);
00185                                 delete zdls (i);
00186                         }
00187                         if (DBG) delete dbg_file;
00188                 };
00190 
00193 
00195                 virtual void merge () {
00196                         validate();
00197 
00198                         
00199                         bool OK = true;
00200                         for (int i = 0;i < mpdfs.length(); i++) {
00201                                 OK &= (rvzs (i)._dsize() == 0); 
00202                                 OK &= (mpdfs (i)->_rvc()._dsize() == 0); 
00203                         }
00204 
00205                         if (OK) {
00206                                 mat lW = zeros (mpdfs.length(), eSmp._w().length());
00207 
00208                                 vec emptyvec (0);
00209                                 for (int i = 0; i < mpdfs.length(); i++) {
00210                                         for (int j = 0; j < eSmp._w().length(); j++) {
00211                                                 lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
00212                                         }
00213                                 }
00214 
00215                                 vec w_nn=merge_points (lW);
00216                                 vec wtmp = exp (w_nn-max(w_nn));
00217                                 
00218                                 eSmp._w() = wtmp / sum (wtmp);
00219                         } else {
00220                                 it_error ("Sources are not compatible - use merger_mix");
00221                         }
00222                 };
00223 
00224 
00226                 vec merge_points (mat &lW);
00227 
00228 
00231                 vec mean() const {
00232                         const Vec<double> &w = eSmp._w();
00233                         const Array<vec> &S = eSmp._samples();
00234                         vec tmp = zeros (dim);
00235                         for (int i = 0; i < Npoints; i++) {
00236                                 tmp += w (i) * S (i);
00237                         }
00238                         return tmp;
00239                 }
00240                 mat covariance() const {
00241                         const vec &w = eSmp._w();
00242                         const Array<vec> &S = eSmp._samples();
00243 
00244                         vec mea = mean();
00245 
00246 
00247 
00248                         mat Tmp = zeros (dim, dim);
00249                         for (int i = 0; i < Npoints; i++) {
00250                                 Tmp += w (i) * outer_product (S (i), S (i));
00251                         }
00252                         return Tmp -outer_product (mea, mea);
00253                 }
00254                 vec variance() const {
00255                         const vec &w = eSmp._w();
00256                         const Array<vec> &S = eSmp._samples();
00257 
00258                         vec tmp = zeros (dim);
00259                         for (int i = 0; i < Nsources; i++) {
00260                                 tmp += w (i) * pow (S (i), 2);
00261                         }
00262                         return tmp -pow (mean(), 2);
00263                 }
00265 
00268 
00270                 eEmp& _Smp() {return eSmp;}
00271 
00273                 void from_setting (const Setting& set) {
00274                         
00275                         
00276                         string meth_str;
00277                         UI::get<string> (meth_str, set, "method", UI::compulsory);
00278                         if (!strcmp (meth_str.c_str(), "arithmetic"))
00279                                 set_method (ARITHMETIC);
00280                         else {
00281                                 if (!strcmp (meth_str.c_str(), "geometric"))
00282                                         set_method (GEOMETRIC);
00283                                 else if (!strcmp (meth_str.c_str(), "lognormal")) {
00284                                         set_method (LOGNORMAL);
00285                                         set.lookupValue( "beta",beta);
00286                                 }
00287                         }
00288                         string dbg_file;
00289                         if (UI::get(dbg_file, set, "dbg_file"))
00290                                 set_debug_file(dbg_file);
00291                         
00292                 }
00293 
00294                 void validate() {
00295                         it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
00296                         it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
00297                         it_assert (isnamed(),"mergers must be named");
00298                 }
00300 };
00301 UIREGISTER(merger_base);
00302 
00303 class merger_mix : public merger_base
00304 {
00305         protected:
00307                 MixEF Mix;
00309                 int Ncoms;
00311                 double effss_coef;
00313                 int stop_niter;
00314                 
00316                 static const int DFLT_Ncoms;
00318                 static const double DFLT_effss_coef;
00319 
00320         public:
00323                 merger_mix () {};
00324                 merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
00326                 void set_sources (const Array<mpdf*> &S, bool own) {
00327                         merger_base::set_sources (S,own);
00328                         Nsources = S.length();
00329                 }
00331                 void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
00332                         Ncoms = Ncoms0;
00333                         effss_coef = effss_coef0;
00334                 }
00336 
00339 
00341                 void merge ();
00342 
00344                 vec sample () const { return Mix.posterior().sample();}
00346                 double evallog (const vec &dt) const {
00347                         vec dtf = ones (dt.length() + 1);
00348                         dtf.set_subvector (0, dt);
00349                         return Mix.logpred (dtf);
00350                 }
00352 
00356                 MixEF& _Mix() {return Mix;}
00358                 emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
00360                 void from_setting(const Setting& set){
00361                         merger_base::from_setting(set);
00362                         set.lookupValue("ncoms",Ncoms);
00363                         set.lookupValue("effss_coef",effss_coef);
00364                         set.lookupValue("stop_niter",stop_niter);
00365                 }
00366                 
00368 
00369 };
00370 UIREGISTER(merger_mix);
00371 
00372 }
00373 
00374 #endif // MERGER_H