00001
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015
00016
00017 #include "../estim/mixtures.h"
00018
00019 namespace bdm
00020 {
00021 using std::string;
00022
//! Identifiers of the available merging methods (selected via merger_base::set_method()).
//! The numeric values are fixed explicitly.
enum MERGER_METHOD {
	ARITHMETIC = 1, //!< chosen by the setting string "arithmetic"
	GEOMETRIC = 2,  //!< chosen by the setting string "geometric"
	LOGNORMAL = 3   //!< chosen by the setting string "lognormal"
};
00025
00045 class merger_base : public compositepdf, public epdf
00046 {
00047 protected:
00049 Array<datalink_m2e*> dls;
00051 Array<RV> rvzs;
00053 Array<datalink_m2e*> zdls;
00055 int Npoints;
00057 int Nsources;
00058
00060 MERGER_METHOD METHOD;
00062 static const MERGER_METHOD DFLT_METHOD;
00063
00065 double beta;
00067 static const double DFLT_beta;
00068
00070 eEmp eSmp;
00071
00073 bool DBG;
00074
00076 it_file* dbg_file;
00077 public:
00080
00082 merger_base () : compositepdf() {DBG = false;dbg_file = NULL;}
00083
00085 merger_base (const Array<mpdf*> &S, bool own=false);
00086
00088 void set_sources (const Array<mpdf*> &Sources, bool own) {
00089 compositepdf::set_elements (Sources,own);
00090 Nsources=mpdfs.length();
00091
00092 dls.set_size (Sources.length());
00093 rvzs.set_size (Sources.length());
00094 zdls.set_size (Sources.length());
00095
00096 rv = getrv ( false);
00097 RV rvc; setrvc (rv, rvc);
00098
00099 rv.add (rvc);
00100
00101 dim = rv._dsize();
00102
00103
00104 RV xytmp;
00105 for (int i = 0;i < mpdfs.length();i++) {
00106
00107 dls (i) = new datalink_m2e;
00108 dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
00109
00110
00111 xytmp = mpdfs (i)->_rv();
00112 xytmp.add (mpdfs (i)->_rvc());
00113
00114 rvzs (i) = rv.subt (xytmp);
00115
00116 zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
00117 };
00118 }
00120 void set_support (const Array<vec> &XYZ, const int dimsize) {
00121 set_support(XYZ,dimsize*ones_i(XYZ.length()));
00122 }
00124 void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
00125 int dim = XYZ.length();
00126 Npoints = prod (gridsize);
00127 eSmp.set_parameters (Npoints, false);
00128 Array<vec> &samples = eSmp._samples();
00129 eSmp._w() = ones (Npoints) / Npoints;
00130
00131 ivec ind = zeros_i (dim);
00132 vec smpi (dim);
00133 vec steps =zeros(dim);
00134
00135 for (int j = 0; j < dim; j++) {
00136 smpi (j) = XYZ (j) (0);
00137 it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
00138 steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
00139 }
00140
00141 for (int i = 0; i < Npoints; i++) {
00142
00143 samples(i) = smpi;
00144
00145 for (int j = 0;j < dim;j++) {
00146 if (ind (j) == gridsize (j) - 1) {
00147 ind (j) = 0;
00148 smpi(j) = XYZ(j)(0);
00149
00150 if (i<Npoints-1) {
00151 ind (j + 1) ++;
00152 smpi(j+1) += steps(j+1);
00153 break;
00154 }
00155
00156 } else {
00157 ind (j) ++;
00158 smpi(j) +=steps(j);
00159 break;
00160 }
00161 }
00162 }
00163 }
00165 void set_debug_file (const string fname) {
00166 if (DBG) delete dbg_file;
00167 dbg_file = new it_file (fname);
00168 if (dbg_file) DBG = true;
00169 }
00171 void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
00172 METHOD = MTH;
00173 beta = beta0;
00174 }
00176 void set_support (const epdf &overall, int N) {
00177 eSmp.set_statistics (&overall, N);
00178 Npoints = N;
00179 }
00180
00182 virtual ~merger_base() {
00183 for (int i = 0; i < Nsources; i++) {
00184 delete dls (i);
00185 delete zdls (i);
00186 }
00187 if (DBG) delete dbg_file;
00188 };
00190
00193
00195 virtual void merge () {
00196 validate();
00197
00198
00199 bool OK = true;
00200 for (int i = 0;i < mpdfs.length(); i++) {
00201 OK &= (rvzs (i)._dsize() == 0);
00202 OK &= (mpdfs (i)->_rvc()._dsize() == 0);
00203 }
00204
00205 if (OK) {
00206 mat lW = zeros (mpdfs.length(), eSmp._w().length());
00207
00208 vec emptyvec (0);
00209 for (int i = 0; i < mpdfs.length(); i++) {
00210 for (int j = 0; j < eSmp._w().length(); j++) {
00211 lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
00212 }
00213 }
00214
00215 vec w_nn=merge_points (lW);
00216 vec wtmp = exp (w_nn-max(w_nn));
00217
00218 eSmp._w() = wtmp / sum (wtmp);
00219 } else {
00220 it_error ("Sources are not compatible - use merger_mix");
00221 }
00222 };
00223
00224
00226 vec merge_points (mat &lW);
00227
00228
00231 vec mean() const {
00232 const Vec<double> &w = eSmp._w();
00233 const Array<vec> &S = eSmp._samples();
00234 vec tmp = zeros (dim);
00235 for (int i = 0; i < Npoints; i++) {
00236 tmp += w (i) * S (i);
00237 }
00238 return tmp;
00239 }
00240 mat covariance() const {
00241 const vec &w = eSmp._w();
00242 const Array<vec> &S = eSmp._samples();
00243
00244 vec mea = mean();
00245
00246
00247
00248 mat Tmp = zeros (dim, dim);
00249 for (int i = 0; i < Npoints; i++) {
00250 Tmp += w (i) * outer_product (S (i), S (i));
00251 }
00252 return Tmp -outer_product (mea, mea);
00253 }
00254 vec variance() const {
00255 const vec &w = eSmp._w();
00256 const Array<vec> &S = eSmp._samples();
00257
00258 vec tmp = zeros (dim);
00259 for (int i = 0; i < Nsources; i++) {
00260 tmp += w (i) * pow (S (i), 2);
00261 }
00262 return tmp -pow (mean(), 2);
00263 }
00265
00268
00270 eEmp& _Smp() {return eSmp;}
00271
00273 void from_setting (const Setting& set) {
00274
00275
00276 string meth_str;
00277 UI::get<string> (meth_str, set, "method", UI::compulsory);
00278 if (!strcmp (meth_str.c_str(), "arithmetic"))
00279 set_method (ARITHMETIC);
00280 else {
00281 if (!strcmp (meth_str.c_str(), "geometric"))
00282 set_method (GEOMETRIC);
00283 else if (!strcmp (meth_str.c_str(), "lognormal")) {
00284 set_method (LOGNORMAL);
00285 set.lookupValue( "beta",beta);
00286 }
00287 }
00288 string dbg_file;
00289 if (UI::get(dbg_file, set, "dbg_file"))
00290 set_debug_file(dbg_file);
00291
00292 }
00293
00294 void validate() {
00295 it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
00296 it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
00297 it_assert (isnamed(),"mergers must be named");
00298 }
00300 };
00301 UIREGISTER(merger_base);
00302
00303 class merger_mix : public merger_base
00304 {
00305 protected:
00307 MixEF Mix;
00309 int Ncoms;
00311 double effss_coef;
00313 int stop_niter;
00314
00316 static const int DFLT_Ncoms;
00318 static const double DFLT_effss_coef;
00319
00320 public:
00323 merger_mix () {};
00324 merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
00326 void set_sources (const Array<mpdf*> &S, bool own) {
00327 merger_base::set_sources (S,own);
00328 Nsources = S.length();
00329 }
00331 void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
00332 Ncoms = Ncoms0;
00333 effss_coef = effss_coef0;
00334 }
00336
00339
00341 void merge ();
00342
00344 vec sample () const { return Mix.posterior().sample();}
00346 double evallog (const vec &dt) const {
00347 vec dtf = ones (dt.length() + 1);
00348 dtf.set_subvector (0, dt);
00349 return Mix.logpred (dtf);
00350 }
00352
00356 MixEF& _Mix() {return Mix;}
00358 emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
00360 void from_setting(const Setting& set){
00361 merger_base::from_setting(set);
00362 set.lookupValue("ncoms",Ncoms);
00363 set.lookupValue("effss_coef",effss_coef);
00364 set.lookupValue("stop_niter",stop_niter);
00365 }
00366
00368
00369 };
00370 UIREGISTER(merger_mix);
00371
00372 }
00373
00374 #endif // MERGER_H