00001
00013 #ifndef MERGER_H
00014 #define MERGER_H
00015
00016
00017 #include "../estim/mixtures.h"
00018
00019 namespace bdm
00020 {
00021 using std::string;
00022
//! Selects how merger_base combines the densities of its sources.
enum MERGER_METHOD {
	ARITHMETIC = 1, //!< arithmetic merging (config string "arithmetic")
	GEOMETRIC  = 2, //!< geometric merging (config string "geometric")
	LOGNORMAL  = 3  //!< lognormal merging (config string "lognormal"; may also read "beta")
};
00025
00045 class merger_base : public compositepdf, public epdf
00046 {
00047 protected:
00049 Array<datalink_m2e*> dls;
00051 Array<RV> rvzs;
00053 Array<datalink_m2e*> zdls;
00055 int Npoints;
00057 int Nsources;
00058
00060 MERGER_METHOD METHOD;
00062 static const MERGER_METHOD DFLT_METHOD;
00063
00065 double beta;
00067 static const double DFLT_beta;
00068
00070 eEmp eSmp;
00071
00073 bool DBG;
00074
00076 it_file* dbg_file;
00077 public:
00080
00082 merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
00084 merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
00086 void set_sources (const Array<mpdf*> &Sources, bool own) {
00087 compositepdf::set_elements (Sources,own);
00088 Nsources=mpdfs.length();
00089
00090 dls.set_size (Sources.length());
00091 rvzs.set_size (Sources.length());
00092 zdls.set_size (Sources.length());
00093
00094 rv = getrv ( false);
00095 RV rvc; setrvc (rv, rvc);
00096
00097 rv.add (rvc);
00098
00099 dim = rv._dsize();
00100
00101
00102 RV xytmp;
00103 for (int i = 0;i < mpdfs.length();i++) {
00104
00105 dls (i) = new datalink_m2e;
00106 dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
00107
00108
00109 xytmp = mpdfs (i)->_rv();
00110 xytmp.add (mpdfs (i)->_rvc());
00111
00112 rvzs (i) = rv.subt (xytmp);
00113
00114 zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
00115 };
00116 }
00118 void set_support (const Array<vec> &XYZ, const int dimsize) {
00119 set_support(XYZ,dimsize*ones_i(XYZ.length()));
00120 }
00122 void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
00123 int dim = XYZ.length();
00124 Npoints = prod (gridsize);
00125 eSmp.set_parameters (Npoints, false);
00126 Array<vec> &samples = eSmp._samples();
00127 eSmp._w() = ones (Npoints) / Npoints;
00128
00129 ivec ind = zeros_i (dim);
00130 vec smpi (dim);
00131 vec steps =zeros(dim);
00132
00133 for (int j = 0; j < dim; j++) {
00134 smpi (j) = XYZ (j) (0);
00135 it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
00136 steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
00137 }
00138
00139 for (int i = 0; i < Npoints; i++) {
00140
00141 samples(i) = smpi;
00142
00143 for (int j = 0;j < dim;j++) {
00144 if (ind (j) == gridsize (j) - 1) {
00145 ind (j) = 0;
00146 smpi(j) = XYZ(j)(0);
00147
00148 if (i<Npoints-1) {
00149 ind (j + 1) ++;
00150 smpi(j+1) += steps(j+1);
00151 break;
00152 }
00153
00154 } else {
00155 ind (j) ++;
00156 smpi(j) +=steps(j);
00157 break;
00158 }
00159 }
00160 }
00161 }
00163 void set_debug_file (const string fname) {
00164 if (DBG) delete dbg_file;
00165 dbg_file = new it_file (fname);
00166 if (dbg_file) DBG = true;
00167 }
00169 void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
00170 METHOD = MTH;
00171 beta = beta0;
00172 }
00174 void set_support (const epdf &overall, int N) {
00175 eSmp.set_statistics (&overall, N);
00176 Npoints = N;
00177 }
00178
00180 virtual ~merger_base() {
00181 for (int i = 0; i < Nsources; i++) {
00182 delete dls (i);
00183 delete zdls (i);
00184 }
00185 if (DBG) delete dbg_file;
00186 };
00188
00191
00193 virtual void merge () {
00194 validate();
00195
00196
00197 bool OK = true;
00198 for (int i = 0;i < mpdfs.length(); i++) {
00199 OK &= (rvzs (i)._dsize() == 0);
00200 OK &= (mpdfs (i)->_rvc()._dsize() == 0);
00201 }
00202
00203 if (OK) {
00204 mat lW = zeros (mpdfs.length(), eSmp._w().length());
00205
00206 vec emptyvec (0);
00207 for (int i = 0; i < mpdfs.length(); i++) {
00208 for (int j = 0; j < eSmp._w().length(); j++) {
00209 lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
00210 }
00211 }
00212
00213 vec w_nn=merge_points (lW);
00214 vec wtmp = exp (w_nn-max(w_nn));
00215
00216 eSmp._w() = wtmp / sum (wtmp);
00217 } else {
00218 it_error ("Sources are not compatible - use merger_mix");
00219 }
00220 };
00221
00222
00224 vec merge_points (mat &lW);
00225
00226
00229 vec mean() const {
00230 const Vec<double> &w = eSmp._w();
00231 const Array<vec> &S = eSmp._samples();
00232 vec tmp = zeros (dim);
00233 for (int i = 0; i < Npoints; i++) {
00234 tmp += w (i) * S (i);
00235 }
00236 return tmp;
00237 }
00238 mat covariance() const {
00239 const vec &w = eSmp._w();
00240 const Array<vec> &S = eSmp._samples();
00241
00242 vec mea = mean();
00243
00244 cout << sum (w) << "," << w*w << endl;
00245
00246 mat Tmp = zeros (dim, dim);
00247 for (int i = 0; i < Npoints; i++) {
00248 Tmp += w (i) * outer_product (S (i), S (i));
00249 }
00250 return Tmp -outer_product (mea, mea);
00251 }
00252 vec variance() const {
00253 const vec &w = eSmp._w();
00254 const Array<vec> &S = eSmp._samples();
00255
00256 vec tmp = zeros (dim);
00257 for (int i = 0; i < Nsources; i++) {
00258 tmp += w (i) * pow (S (i), 2);
00259 }
00260 return tmp -pow (mean(), 2);
00261 }
00263
00266
00268 eEmp& _Smp() {return eSmp;}
00269
00271 void from_setting (const Setting& set) {
00272
00273
00274 string meth_str;
00275 UI::get<string> (meth_str, set, "method");
00276 if (!strcmp (meth_str.c_str(), "arithmetic"))
00277 set_method (ARITHMETIC);
00278 else {
00279 if (!strcmp (meth_str.c_str(), "geometric"))
00280 set_method (GEOMETRIC);
00281 else if (!strcmp (meth_str.c_str(), "lognormal")) {
00282 set_method (LOGNORMAL);
00283 set.lookupValue( "beta",beta);
00284 }
00285 }
00286 if (set.exists("dbg_file")){
00287 string dbg_file;
00288 UI::get<string> (dbg_file, set, "dbg_file");
00289 set_debug_file(dbg_file);
00290 }
00291
00292 }
00293
00294 void validate() {
00295 it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
00296 it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
00297 it_assert (isnamed(),"mergers must be named");
00298 }
00300 };
00301 UIREGISTER(merger_base);
00302
00303 class merger_mix : public merger_base
00304 {
00305 protected:
00307 MixEF Mix;
00309 int Ncoms;
00311 double effss_coef;
00313 int stop_niter;
00314
00316 static const int DFLT_Ncoms;
00318 static const double DFLT_effss_coef;
00319
00320 public:
00323 merger_mix () {};
00324 merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
00326 void set_sources (const Array<mpdf*> &S, bool own) {
00327 merger_base::set_sources (S,own);
00328 Nsources = S.length();
00329 }
00331 void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
00332 Ncoms = Ncoms0;
00333 effss_coef = effss_coef0;
00334 }
00336
00339
00341 void merge ();
00342
00344 vec sample () const { return Mix.posterior().sample();}
00346 double evallog (const vec &dt) const {
00347 vec dtf = ones (dt.length() + 1);
00348 dtf.set_subvector (0, dt);
00349 return Mix.logpred (dtf);
00350 }
00352
00356 MixEF& _Mix() {return Mix;}
00358 emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
00360 void from_setting(const Setting& set){
00361 merger_base::from_setting(set);
00362 set.lookupValue("ncoms",Ncoms);
00363 set.lookupValue("effss_coef",effss_coef);
00364 set.lookupValue("stop_niter",stop_niter);
00365 }
00366
00368
00369 };
00370 UIREGISTER(merger_mix);
00371
00372 }
00373
00374 #endif // MERGER_H