00001 
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015 
00016 
00017 #include "../estim/mixtures.h"
00018 
00019 namespace bdm
00020 {
00021 using std::string;
00022 
//! Identifiers of the merging methods accepted by merger_base::set_method().
enum MERGER_METHOD {
        ARITHMETIC = 1, //!< selected by the "arithmetic" keyword in settings
        GEOMETRIC = 2,  //!< selected by the "geometric" keyword in settings
        LOGNORMAL = 3   //!< selected by the "lognormal" keyword in settings (also reads "beta")
};
00025 
00045 class merger_base : public compositepdf, public epdf
00046 {
00047         protected:
00049                 Array<datalink_m2e*> dls;
00051                 Array<RV> rvzs;
00053                 Array<datalink_m2e*> zdls;
00055                 int Npoints;
00057                 int Nsources;
00058 
00060                 MERGER_METHOD METHOD;
00062                 static const MERGER_METHOD DFLT_METHOD;
00063                 
00065                 double beta;
00067                 static const double DFLT_beta;
00068                 
00070                 eEmp eSmp;
00071 
00073                 bool DBG;
00074 
00076                 it_file* dbg_file;
00077         public:
00080 
00082                 merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
00084                 merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
00086                 void set_sources (const Array<mpdf*> &Sources, bool own) {
00087                         compositepdf::set_elements (Sources,own);
00088                         Nsources=mpdfs.length();
00089                         
00090                         dls.set_size (Sources.length());
00091                         rvzs.set_size (Sources.length());
00092                         zdls.set_size (Sources.length());
00093 
00094                         rv = getrv ( false);
00095                         RV rvc; setrvc (rv, rvc);  
00096                         
00097                         rv.add (rvc);
00098                         
00099                         dim = rv._dsize();
00100 
00101                         
00102                         RV xytmp;
00103                         for (int i = 0;i < mpdfs.length();i++) {
00104                                 
00105                                 dls (i) = new datalink_m2e;
00106                                 dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
00107 
00108                                 
00109                                 xytmp = mpdfs (i)->_rv();
00110                                 xytmp.add (mpdfs (i)->_rvc());
00111                                 
00112                                 rvzs (i) = rv.subt (xytmp);
00113                                 
00114                                 zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
00115                         };
00116                 }
00118                 void set_support (const Array<vec> &XYZ, const int dimsize) {
00119                         set_support(XYZ,dimsize*ones_i(XYZ.length()));
00120                 }
00122                 void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
00123                         int dim = XYZ.length();  
00124                         Npoints = prod (gridsize); 
00125                         eSmp.set_parameters (Npoints, false);
00126                         Array<vec> &samples = eSmp._samples();
00127                         eSmp._w() = ones (Npoints) / Npoints; 
00128                         
00129                         ivec ind = zeros_i (dim);      
00130                         vec smpi (dim);            
00131                         vec steps =zeros(dim);            
00132                         
00133                         for (int j = 0; j < dim; j++) {
00134                                 smpi (j) = XYZ (j) (0);  
00135                                 it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
00136                                 steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
00137                         }
00138                         
00139                         for (int i = 0; i < Npoints; i++) {
00140                                 
00141                                 samples(i) = smpi; 
00142                                 
00143                                 for (int j = 0;j < dim;j++) {
00144                                         if (ind (j) == gridsize (j) - 1) { 
00145                                                 ind (j) = 0; 
00146                                                 smpi(j) = XYZ(j)(0);
00147                                                 
00148                                                 if (i<Npoints-1) {
00149                                                         ind (j + 1) ++; 
00150                                                         smpi(j+1) += steps(j+1);
00151                                                         break;
00152                                                 }
00153                                                 
00154                                         } else {
00155                                                 ind (j) ++; 
00156                                                 smpi(j) +=steps(j);
00157                                                 break;
00158                                         }
00159                                 }
00160                         }
00161                 }
00163                 void set_debug_file (const string fname) {
00164                         if (DBG) delete dbg_file;
00165                         dbg_file = new it_file (fname);
00166                         if (dbg_file) DBG = true;
00167                 }
00169                 void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
00170                         METHOD = MTH;
00171                         beta = beta0;
00172                 }
00174                 void set_support (const epdf &overall, int N) {
00175                         eSmp.set_statistics (&overall, N);
00176                         Npoints = N;
00177                 }
00178 
00180                 virtual ~merger_base() {
00181                         for (int i = 0; i < Nsources; i++) {
00182                                 delete dls (i);
00183                                 delete zdls (i);
00184                         }
00185                         if (DBG) delete dbg_file;
00186                 };
00188 
00191 
00193                 virtual void merge () {
00194                         validate();
00195 
00196                         
00197                         bool OK = true;
00198                         for (int i = 0;i < mpdfs.length(); i++) {
00199                                 OK &= (rvzs (i)._dsize() == 0); 
00200                                 OK &= (mpdfs (i)->_rvc()._dsize() == 0); 
00201                         }
00202 
00203                         if (OK) {
00204                                 mat lW = zeros (mpdfs.length(), eSmp._w().length());
00205 
00206                                 vec emptyvec (0);
00207                                 for (int i = 0; i < mpdfs.length(); i++) {
00208                                         for (int j = 0; j < eSmp._w().length(); j++) {
00209                                                 lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
00210                                         }
00211                                 }
00212 
00213                                 vec w_nn=merge_points (lW);
00214                                 vec wtmp = exp (w_nn-max(w_nn));
00215                                 
00216                                 eSmp._w() = wtmp / sum (wtmp);
00217                         } else {
00218                                 it_error ("Sources are not compatible - use merger_mix");
00219                         }
00220                 };
00221 
00222 
00224                 vec merge_points (mat &lW);
00225 
00226 
00229                 vec mean() const {
00230                         const Vec<double> &w = eSmp._w();
00231                         const Array<vec> &S = eSmp._samples();
00232                         vec tmp = zeros (dim);
00233                         for (int i = 0; i < Npoints; i++) {
00234                                 tmp += w (i) * S (i);
00235                         }
00236                         return tmp;
00237                 }
00238                 mat covariance() const {
00239                         const vec &w = eSmp._w();
00240                         const Array<vec> &S = eSmp._samples();
00241 
00242                         vec mea = mean();
00243 
00244                         cout << sum (w) << "," << w*w << endl;
00245 
00246                         mat Tmp = zeros (dim, dim);
00247                         for (int i = 0; i < Npoints; i++) {
00248                                 Tmp += w (i) * outer_product (S (i), S (i));
00249                         }
00250                         return Tmp -outer_product (mea, mea);
00251                 }
00252                 vec variance() const {
00253                         const vec &w = eSmp._w();
00254                         const Array<vec> &S = eSmp._samples();
00255 
00256                         vec tmp = zeros (dim);
00257                         for (int i = 0; i < Nsources; i++) {
00258                                 tmp += w (i) * pow (S (i), 2);
00259                         }
00260                         return tmp -pow (mean(), 2);
00261                 }
00263 
00266 
00268                 eEmp& _Smp() {return eSmp;}
00269 
00271                 void from_setting (const Setting& set) {
00272                         
00273                         
00274                         string meth_str;
00275                         UI::get<string> (meth_str, set, "method");
00276                         if (!strcmp (meth_str.c_str(), "arithmetic"))
00277                                 set_method (ARITHMETIC);
00278                         else {
00279                                 if (!strcmp (meth_str.c_str(), "geometric"))
00280                                         set_method (GEOMETRIC);
00281                                 else if (!strcmp (meth_str.c_str(), "lognormal")) {
00282                                         set_method (LOGNORMAL);
00283                                         set.lookupValue( "beta",beta);
00284                                 }
00285                         }
00286                         if (set.exists("dbg_file")){ 
00287                                 string dbg_file;
00288                                 UI::get<string> (dbg_file, set, "dbg_file");
00289                                 set_debug_file(dbg_file);
00290                         }
00291                         
00292                 }
00293 
00294                 void validate() {
00295                         it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
00296                         it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
00297                         it_assert (isnamed(),"mergers must be named");
00298                 }
00300 };
00301 UIREGISTER(merger_base);
00302 
00303 class merger_mix : public merger_base
00304 {
00305         protected:
00307                 MixEF Mix;
00309                 int Ncoms;
00311                 double effss_coef;
00313                 int stop_niter;
00314                 
00316                 static const int DFLT_Ncoms;
00318                 static const double DFLT_effss_coef;
00319 
00320         public:
00323                 merger_mix () {};
00324                 merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
00326                 void set_sources (const Array<mpdf*> &S, bool own) {
00327                         merger_base::set_sources (S,own);
00328                         Nsources = S.length();
00329                 }
00331                 void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
00332                         Ncoms = Ncoms0;
00333                         effss_coef = effss_coef0;
00334                 }
00336 
00339 
00341                 void merge ();
00342 
00344                 vec sample () const { return Mix.posterior().sample();}
00346                 double evallog (const vec &dt) const {
00347                         vec dtf = ones (dt.length() + 1);
00348                         dtf.set_subvector (0, dt);
00349                         return Mix.logpred (dtf);
00350                 }
00352 
00356                 MixEF& _Mix() {return Mix;}
00358                 emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
00360                 void from_setting(const Setting& set){
00361                         merger_base::from_setting(set);
00362                         set.lookupValue("ncoms",Ncoms);
00363                         set.lookupValue("effss_coef",effss_coef);
00364                         set.lookupValue("stop_niter",stop_niter);
00365                 }
00366                 
00368 
00369 };
00370 UIREGISTER(merger_mix);
00371 
00372 }
00373 
00374 #endif // MERGER_H