00001
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015
00016
00017 #include "../estim/mixtures.h"
00018
00019 namespace bdm
00020 {
00021 using std::string;
00022
//! Identifiers of the available merging methods (selected via merger_base::set_method()).
enum MERGER_METHOD {
    ARITHMETIC = 1, //!< selected by the string "arithmetic" in from_setting()
    GEOMETRIC = 2,  //!< selected by the string "geometric" in from_setting()
    LOGNORMAL = 3   //!< selected by the string "lognormal" in from_setting()
};
00025
00045 class merger_base : public compositepdf, public epdf
00046 {
00047 protected:
00049 Array<datalink_m2e*> dls;
00051 Array<RV> rvzs;
00053 Array<datalink_m2e*> zdls;
00055 int Npoints;
00057 int Nsources;
00058
00060 MERGER_METHOD METHOD;
00062 double beta;
00063
00065 eEmp eSmp;
00066
00068 bool DBG;
00069
00071 it_file* dbg_file;
00072 public:
00075
00077 merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
00079 merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
00081 void set_sources (const Array<mpdf*> &Sources, bool own) {
00082 compositepdf::set_elements (Sources,own);
00083 Nsources=mpdfs.length();
00084
00085 dls.set_size (Sources.length());
00086 rvzs.set_size (Sources.length());
00087 zdls.set_size (Sources.length());
00088
00089 rv = getrv ( false);
00090 RV rvc; setrvc (rv, rvc);
00091
00092 rv.add (rvc);
00093
00094 dim = rv._dsize();
00095
00096
00097 RV xytmp;
00098 for (int i = 0;i < mpdfs.length();i++) {
00099
00100 dls (i) = new datalink_m2e;
00101 dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
00102
00103
00104 xytmp = mpdfs (i)->_rv();
00105 xytmp.add (mpdfs (i)->_rvc());
00106
00107 rvzs (i) = rv.subt (xytmp);
00108
00109 zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
00110 };
00111 }
00113 void set_support (const Array<vec> &XYZ, const int dimsize) {
00114 set_support(XYZ,dimsize*ones_i(XYZ.length()));
00115 }
00117 void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
00118 int dim = XYZ.length();
00119 Npoints = prod (gridsize);
00120 eSmp.set_parameters (Npoints, false);
00121 Array<vec> &samples = eSmp._samples();
00122 eSmp._w() = ones (Npoints) / Npoints;
00123
00124 ivec ind = zeros_i (dim);
00125 vec smpi (dim);
00126 vec steps =zeros(dim);
00127
00128 for (int j = 0; j < dim; j++) {
00129 smpi (j) = XYZ (j) (0);
00130 it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
00131 steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
00132 }
00133
00134 for (int i = 0; i < Npoints; i++) {
00135
00136 samples(i) = smpi;
00137
00138 for (int j = 0;j < dim;j++) {
00139 if (ind (j) == gridsize (j) - 1) {
00140 ind (j) = 0;
00141 smpi(j) = XYZ(j)(0);
00142
00143 if (i<Npoints-1) {
00144 ind (j + 1) ++;
00145 smpi(j+1) += steps(j+1);
00146 break;
00147 }
00148
00149 } else {
00150 ind (j) ++;
00151 smpi(j) +=steps(j);
00152 break;
00153 }
00154 }
00155 }
00156 }
00158 void set_debug_file (const string fname) {
00159 if (DBG) delete dbg_file;
00160 dbg_file = new it_file (fname);
00161 if (dbg_file) DBG = true;
00162 }
00164 void set_method (MERGER_METHOD MTH, double beta0 = 0.0) {
00165 METHOD = MTH;
00166 beta = beta0;
00167 }
00169 void set_support (const epdf &overall, int N) {
00170 eSmp.set_statistics (&overall, N);
00171 Npoints = N;
00172 }
00173
00175 virtual ~merger_base() {
00176 for (int i = 0; i < Nsources; i++) {
00177 delete dls (i);
00178 delete zdls (i);
00179 }
00180 if (DBG) delete dbg_file;
00181 };
00183
00186
00188 virtual void merge () {
00189 validate();
00190
00191
00192 bool OK = true;
00193 for (int i = 0;i < mpdfs.length(); i++) {
00194 OK &= (rvzs (i)._dsize() == 0);
00195 OK &= (mpdfs (i)->_rvc()._dsize() == 0);
00196 }
00197
00198 if (OK) {
00199 mat lW = zeros (mpdfs.length(), eSmp._w().length());
00200
00201 vec emptyvec (0);
00202 for (int i = 0; i < mpdfs.length(); i++) {
00203 for (int j = 0; j < eSmp._w().length(); j++) {
00204 lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
00205 }
00206 }
00207
00208 vec w_nn=merge_points (lW);
00209 vec wtmp = exp (w_nn-max(w_nn));
00210
00211 eSmp._w() = wtmp / sum (wtmp);
00212 } else {
00213 it_error ("Sources are not compatible - use merger_mix");
00214 }
00215 };
00216
00217
00219 vec merge_points (mat &lW);
00220
00221
00224 vec mean() const {
00225 const Vec<double> &w = eSmp._w();
00226 const Array<vec> &S = eSmp._samples();
00227 vec tmp = zeros (dim);
00228 for (int i = 0; i < Npoints; i++) {
00229 tmp += w (i) * S (i);
00230 }
00231 return tmp;
00232 }
00233 mat covariance() const {
00234 const vec &w = eSmp._w();
00235 const Array<vec> &S = eSmp._samples();
00236
00237 vec mea = mean();
00238
00239 cout << sum (w) << "," << w*w << endl;
00240
00241 mat Tmp = zeros (dim, dim);
00242 for (int i = 0; i < Npoints; i++) {
00243 Tmp += w (i) * outer_product (S (i), S (i));
00244 }
00245 return Tmp -outer_product (mea, mea);
00246 }
00247 vec variance() const {
00248 const vec &w = eSmp._w();
00249 const Array<vec> &S = eSmp._samples();
00250
00251 vec tmp = zeros (dim);
00252 for (int i = 0; i < Nsources; i++) {
00253 tmp += w (i) * pow (S (i), 2);
00254 }
00255 return tmp -pow (mean(), 2);
00256 }
00258
00261
00263 eEmp& _Smp() {return eSmp;}
00264
00266 void from_setting (const Setting& set) {
00267
00268
00269 string meth_str;
00270 UI::get<string> (meth_str, set, "method");
00271 if (!strcmp (meth_str.c_str(), "arithmetic"))
00272 set_method (ARITHMETIC);
00273 else {
00274 if (!strcmp (meth_str.c_str(), "geometric"))
00275 set_method (GEOMETRIC);
00276 else if (!strcmp (meth_str.c_str(), "lognormal")) {
00277 set_method (LOGNORMAL);
00278 set.lookupValue( "beta",beta);
00279 }
00280 }
00281 if (set.exists("dbg_file")){
00282 string dbg_file;
00283 UI::get<string> (dbg_file, set, "dbg_file");
00284 set_debug_file(dbg_file);
00285 }
00286
00287 }
00288
00289 void validate() {
00290 it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
00291 it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
00292 it_assert (isnamed(),"mergers must be named");
00293 }
00295 };
00296 UIREGISTER(merger_base);
00297
00298 class merger_mix : public merger_base
00299 {
00300 protected:
00302 MixEF Mix;
00304 int Ncoms;
00306 double effss_coef;
00308 int stop_niter;
00309
00310 public:
00313 merger_mix () {};
00314 merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
00316 void set_sources (const Array<mpdf*> &S, bool own) {
00317 merger_base::set_sources (S,own);
00318 Nsources = S.length();
00319 }
00321 void set_parameters (int Ncoms0 = 10, double effss_coef0 = 0.5) {
00322 Ncoms = Ncoms0;
00323 effss_coef = effss_coef0;
00324 }
00326
00329
00331 void merge ();
00332
00334 vec sample () const { return Mix.posterior().sample();}
00336 double evallog (const vec &dt) const {
00337 vec dtf = ones (dt.length() + 1);
00338 dtf.set_subvector (0, dt);
00339 return Mix.logpred (dtf);
00340 }
00342
00346 MixEF& _Mix() {return Mix;}
00348 emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
00350 void from_setting(const Setting& set){
00351 merger_base::from_setting(set);
00352 set.lookupValue("ncoms",Ncoms);
00353 set.lookupValue("effss_coef",effss_coef);
00354 set.lookupValue("stop_niter",stop_niter);
00355 }
00356
00358
00359 };
00360 UIREGISTER(merger_mix);
00361
00362 }
00363
00364 #endif // MERGER_H