00001
00013 #ifndef BDMBASE_H
00014 #define BDMBASE_H
00015
00016 #include <map>
00017
00018 #include "../itpp_ext.h"
00019 #include "../bdmroot.h"
00020 #include "user_info.h"
00021
00022
00023 using namespace libconfig;
00024 using namespace itpp;
00025 using namespace std;
00026
00027 namespace bdm
00028 {
00029
//! Map from a string key to an int. NOTE(review): presumably variable name -> registry id; confirm in the implementation file.
00030 typedef std::map<string, int> RVmap;
//! Global table of component sizes, defined in another translation unit; RV::size() indexes it by component id.
00031 extern ivec RV_SIZES;
//! Global table of component names, defined in another translation unit; RV::name() indexes it by component id.
00032 extern Array<string> RV_NAMES;
00033
00035 class str
00036 {
00037 public:
00039 ivec ids;
00041 ivec times;
00043 str(ivec ids0, ivec times0) : ids(ids0), times(times0) {
00044 it_assert_debug(times0.length() == ids0.length(), "Incompatible input");
00045 };
00046 };
00047
00086 class RV : public root
00087 {
00088 protected:
00090 int dsize;
00092 int len;
00094 ivec ids;
00096 ivec times;
00097
00098 private:
00100 void init(Array<std::string> in_names, ivec in_sizes, ivec in_times);
00101 int init(const string &name, int size);
00102 public:
00105
00107 RV(Array<std::string> in_names, ivec in_sizes, ivec in_times) {init(in_names, in_sizes, in_times);};
00109 RV(Array<std::string> in_names, ivec in_sizes) {init(in_names, in_sizes, zeros_i(in_names.length()));};
00111 RV(Array<std::string> in_names) {init(in_names, ones_i(in_names.length()), zeros_i(in_names.length()));}
00113 RV() : dsize(0), len(0), ids(0), times(0) {};
00115 RV(string name, int sz, int tm = 0);
00117
00120
00122 friend std::ostream &operator<< (std::ostream &os, const RV &rv);
00123 int _dsize() const {return dsize;} ;
00125 int countsize() const;
00126 ivec cumsizes() const;
00127 int length() const {return len;} ;
00128 int id(int at) const {return ids(at);};
00129 int size(int at) const {return RV_SIZES(ids(at));};
00130 int time(int at) const {return times(at);};
00131 std::string name(int at) const {return RV_NAMES(ids(at));};
00132 void set_time(int at, int time0) {times(at) = time0;};
00134
00135
00136
00139
00141 ivec findself(const RV &rv2) const;
00143 bool equal(const RV &rv2) const;
00145 bool add(const RV &rv2);
00147 RV subt(const RV &rv2) const;
00149 RV subselect(const ivec &ind) const;
00151 RV operator()(const ivec &ind) const {return subselect(ind);};
00153 RV operator()(int di1, int di2) const {
00154 ivec sz = cumsizes();
00155 int i1 = 0;
00156 while (sz(i1) < di1) i1++;
00157 int i2 = i1;
00158 while (sz(i2) < di2) i2++;
00159 return subselect(linspace(i1, i2));
00160 };
00162 void t(int delta);
00164
00167
00169 str tostr() const;
00172 ivec dataind(const RV &crv) const;
00175 void dataind(const RV &rv2, ivec &selfi, ivec &rv2i) const;
00177 int mint() const {return min(times);};
00179
00180
00195 void from_setting(const Setting &set);
00196
00197
00198 };
00199 UIREGISTER(RV);
00200
//! Free function combining two RVs into a new RV; implemented out of line (DS::_drv() uses it to join its two RVs).
00202 RV concat(const RV &rv1, const RV &rv2);
00203
//! RV instance defined in another translation unit. NOTE(review): the name suggests an empty RV; confirm at its definition.
00205 extern RV RV0;
00206
00208
00209 class fnc : public root
00210 {
00211 protected:
00213 int dimy;
00214 public:
00216 fnc() {};
00218 virtual vec eval(const vec &cond) {
00219 return vec(0);
00220 };
00221
00223 virtual void condition(const vec &val) {};
00224
00226 int dimension() const {return dimy;}
00227 };
00228
//! Forward declaration: epdf::condition() returns an mpdf*, and mpdf is only defined after epdf.
00229 class mpdf;
00230
00232
00233 class epdf : public root
00234 {
00235 protected:
00237 int dim;
00239 RV rv;
00240
00241 public:
00253 epdf() : dim(0), rv() {};
00254 epdf(const epdf &e) : dim(e.dim), rv(e.rv) {};
00255 epdf(const RV &rv0) {set_rv(rv0);};
00256 void set_parameters(int dim0) {dim = dim0;}
00258
00261
00263 virtual vec sample() const {it_error("not implemneted"); return vec(0);};
00265 virtual mat sample_m(int N) const;
00267 virtual double evallog(const vec &val) const {it_error("not implemneted"); return 0.0;};
00269 virtual vec evallog_m(const mat &Val) const {
00270 vec x(Val.cols());
00271 for (int i = 0; i < Val.cols(); i++) {x(i) = evallog(Val.get_col(i)) ;}
00272 return x;
00273 }
00275 virtual vec evallog_m(const Array<vec> &Avec) const {
00276 vec x(Avec.size());
00277 for (int i = 0; i < Avec.size(); i++) {x(i) = evallog(Avec(i)) ;}
00278 return x;
00279 }
00281 virtual mpdf* condition(const RV &rv) const {it_warning("Not implemented"); return NULL;}
00282
00284 virtual epdf* marginal(const RV &rv) const {it_warning("Not implemented"); return NULL;}
00285
00287 virtual vec mean() const {it_error("not implemneted"); return vec(0);};
00288
00290 virtual vec variance() const {it_error("not implemneted"); return vec(0);};
00292 virtual void qbounds(vec &lb, vec &ub, double percentage = 0.95) const {
00293 vec mea = mean();
00294 vec std = sqrt(variance());
00295 lb = mea - 2 * std;
00296 ub = mea + 2 * std;
00297 };
00299
00305
00307 void set_rv(const RV &rv0) {rv = rv0; }
00309 bool isnamed() const {bool b = (dim == rv._dsize()); return b;}
00311 const RV& _rv() const {it_assert_debug(isnamed(), ""); return rv;}
00313
00316
00318 int dimension() const {return dim;}
00326 void from_setting(const Setting &set){
00327 if (set.exists("rv")){
00328 RV* r = UI::build<RV>(set,"rv");
00329 set_rv(*r);
00330 delete r;
00331 }
00332 }
00333
00334 };
00335
00336
00338
00339
00340 class mpdf : public root
00341 {
00342 protected:
00344 int dimc;
00346 RV rvc;
00348 epdf* ep;
00349 public:
00352
00353 mpdf() : dimc(0), rvc() {};
00355 mpdf(const mpdf &m) : dimc(m.dimc), rvc(m.rvc) {};
00357
00360
00362 virtual vec samplecond(const vec &cond) {
00363 this->condition(cond);
00364 vec temp = ep->sample();
00365 return temp;
00366 };
00368 virtual mat samplecond_m(const vec &cond, int N) {
00369 this->condition(cond);
00370 mat temp(ep->dimension(), N);
00371 vec smp(ep->dimension());
00372 for (int i = 0; i < N; i++) {smp = ep->sample() ; temp.set_col(i, smp);}
00373 return temp;
00374 };
00376 virtual void condition(const vec &cond) {it_error("Not implemented");};
00377
00379 virtual double evallogcond(const vec &dt, const vec &cond) {
00380 double tmp;
00381 this->condition(cond);
00382 tmp = ep->evallog(dt);
00383 it_assert_debug(std::isfinite(tmp), "Infinite value");
00384 return tmp;
00385 };
00386
00388 virtual vec evallogcond_m(const mat &Dt, const vec &cond) {this->condition(cond); return ep->evallog_m(Dt);};
00390 virtual vec evallogcond_m(const Array<vec> &Dt, const vec &cond) {this->condition(cond); return ep->evallog_m(Dt);};
00391
00394
00395 RV _rv() {return ep->_rv();}
00396 RV _rvc() {it_assert_debug(isnamed(), ""); return rvc;}
00397 int dimension() {return ep->dimension();}
00398 int dimensionc() {return dimc;}
00399 epdf& _epdf() {return *ep;}
00400 epdf* _e() {return ep;}
00409 void from_setting(const Setting &set){
00410 if (set.exists("rv")){
00411 RV* r = UI::build<RV>(set,"rv");
00412 set_rv(*r);
00413 delete r;
00414 }
00415 if (set.exists("rvc")){
00416 RV* r = UI::build<RV>(set,"rvc");
00417 set_rvc(*r);
00418 delete r;
00419 }
00420 }
00422
00425 void set_rvc(const RV &rvc0) {rvc = rvc0;}
00426 void set_rv(const RV &rv0) {ep->set_rv(rv0);}
00427 bool isnamed() {return (ep->isnamed()) && (dimc == rvc._dsize());}
00429 };
00430
00456 class datalink
00457 {
00458 protected:
00460 int downsize;
00462 int upsize;
00464 ivec v2v_up;
00465 public:
00467 datalink() {};
00468 datalink(const RV &rv, const RV &rv_up) {set_connection(rv, rv_up);};
00470 void set_connection(const RV &rv, const RV &rv_up) {
00471 downsize = rv._dsize();
00472 upsize = rv_up._dsize();
00473 v2v_up = (rv.dataind(rv_up));
00474
00475 it_assert_debug(v2v_up.length() == downsize, "rv is not fully in rv_up");
00476 }
00478 void set_connection(int ds, int us, const ivec &upind) {
00479 downsize = ds;
00480 upsize = us;
00481 v2v_up = upind;
00482
00483 it_assert_debug(v2v_up.length() == downsize, "rv is not fully in rv_up");
00484 }
00486 vec pushdown(const vec &val_up) {
00487 it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00488 return get_vec(val_up, v2v_up);
00489 }
00491 void pushup(vec &val_up, const vec &val) {
00492 it_assert_debug(downsize == val.length(), "Wrong val");
00493 it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00494 set_subvector(val_up, v2v_up, val);
00495 }
00496 };
00497
00499 class datalink_m2e: public datalink
00500 {
00501 protected:
00503 int condsize;
00505 ivec v2c_up;
00507 ivec v2c_lo;
00508
00509 public:
00510 datalink_m2e() {};
00512 void set_connection(const RV &rv, const RV &rvc, const RV &rv_up) {
00513 datalink::set_connection(rv, rv_up);
00514 condsize = rvc._dsize();
00515
00516 rvc.dataind(rv_up, v2c_lo, v2c_up);
00517 }
00519 vec get_cond(const vec &val_up) {
00520 vec tmp(condsize);
00521 set_subvector(tmp, v2c_lo, val_up(v2c_up));
00522 return tmp;
00523 }
00524 void pushup_cond(vec &val_up, const vec &val, const vec &cond) {
00525 it_assert_debug(downsize == val.length(), "Wrong val");
00526 it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00527 set_subvector(val_up, v2v_up, val);
00528 set_subvector(val_up, v2c_up, cond);
00529 }
00530 };
00533 class datalink_m2m: public datalink_m2e
00534 {
00535 protected:
00537 ivec c2c_up;
00539 ivec c2c_lo;
00540 public:
00542 datalink_m2m() {};
00543 void set_connection(const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up) {
00544 datalink_m2e::set_connection(rv, rvc, rv_up);
00545
00546 rvc.dataind(rvc_up, c2c_lo, c2c_up);
00547 it_assert_debug(c2c_lo.length() + v2c_lo.length() == condsize, "cond is not fully given");
00548 }
00550 vec get_cond(const vec &val_up, const vec &cond_up) {
00551 vec tmp(condsize);
00552 set_subvector(tmp, v2c_lo, val_up(v2c_up));
00553 set_subvector(tmp, c2c_lo, cond_up(c2c_up));
00554 return tmp;
00555 }
00557
00558 };
00559
00565 class logger : public root
00566 {
00567 protected:
00569 Array<RV> entries;
00571 Array<string> names;
00572 public:
00574 logger() : entries(0), names(0) {}
00575
00578 virtual int add(const RV &rv, string prefix = "") {
00579 int id;
00580 if (rv._dsize() > 0) {
00581 id = entries.length();
00582 names = concat(names, prefix);
00583 entries.set_length(id + 1, true);
00584 entries(id) = rv;
00585 }
00586 else { id = -1;}
00587 return id;
00588 }
00589
00591 virtual void logit(int id, const vec &v) = 0;
00593 virtual void logit(int id, const double &d) = 0;
00594
00596 virtual void step() = 0;
00597
00599 virtual void finalize() {};
00600
00602 virtual void init() {};
00603
00604 };
00605
00609 class mepdf : public mpdf {
00610 bool owning_ep;
00611 public:
00613 mepdf(){};
00614 mepdf ( epdf* em, bool owning_ep0=false ) :mpdf ( ) {ep= em ;owning_ep=owning_ep0;dimc=0;};
00615 mepdf (const epdf* em ) :mpdf ( ) {ep=const_cast<epdf*>( em );};
00616 void condition ( const vec &cond ) {}
00617 ~mepdf(){if (owning_ep) delete ep;}
00625 void from_setting(const Setting &set){
00626 epdf* e = UI::build<epdf>(set,"epdf");
00627 ep= e;
00628 owning_ep=true;
00629 }
00630 };
00631 UIREGISTER(mepdf);
00632
00636 class compositepdf {
00637 protected:
00639 Array<mpdf*> mpdfs;
00640 bool owning_mpdfs;
00641 public:
00642 compositepdf():mpdfs(0){};
00643 compositepdf(Array<mpdf*> A0, bool own=false){set_elements(A0,own);};
00644 void set_elements(Array<mpdf*> A0, bool own=false) {mpdfs=A0;owning_mpdfs=own;};
00646 RV getrv(bool checkoverlap = false);
00648 void setrvc(const RV &rv, RV &rvc);
00649 ~compositepdf(){if (owning_mpdfs) for(int i=0;i<mpdfs.length();i++){delete mpdfs(i);}};
00650 };
00651
00659 class DS : public root
00660 {
00661 protected:
00662 int dtsize;
00663 int utsize;
00665 RV Drv;
00667 RV Urv;
00669 int L_dt, L_ut;
00670 public:
00672 DS() : Drv(), Urv() {};
00674 virtual void getdata(vec &dt) {it_error("abstract class");};
00676 virtual void getdata(vec &dt, const ivec &indeces) {it_error("abstract class");};
00678 virtual void write(vec &ut) {it_error("abstract class");};
00680 virtual void write(vec &ut, const ivec &indeces) {it_error("abstract class");};
00681
00683 virtual void step() = 0;
00684
00686 virtual void log_add(logger &L) {
00687 it_assert_debug(dtsize == Drv._dsize(), "");
00688 it_assert_debug(utsize == Urv._dsize(), "");
00689
00690 L_dt = L.add(Drv, "");
00691 L_ut = L.add(Urv, "");
00692 }
00694 virtual void logit(logger &L) {
00695 vec tmp(Drv._dsize() + Urv._dsize());
00696 getdata(tmp);
00697
00698 L.logit(L_dt, tmp.left(Drv._dsize()));
00699
00700 L.logit(L_ut, tmp.mid(Drv._dsize(), Urv._dsize()));
00701 }
00703 virtual RV _drv() const {return concat(Drv, Urv);}
00705 const RV& _urv() const {return Urv;}
00707 virtual void set_drv(const RV &drv, const RV &urv) { Drv = drv; Urv = urv;}
00708 };
00709
00731 class BM : public root
00732 {
00733 protected:
00735 RV drv;
00737 double ll;
00739 bool evalll;
00740 public:
00743
00744 BM() : ll(0), evalll(true), LIDs(4), LFlags(4) {
00745 LIDs = -1;
00746 LFlags = 0;
00747 LFlags(0) = 1;
00748 };
00749 BM(const BM &B) : drv(B.drv), ll(B.ll), evalll(B.evalll) {}
00752 virtual BM* _copy_() const {return NULL;};
00754
00757
00761 virtual void bayes(const vec &dt) = 0;
00763 virtual void bayesB(const mat &Dt);
00766 virtual double logpred(const vec &dt) const {it_error("Not implemented"); return 0.0;}
00768 vec logpred_m(const mat &dt) const {vec tmp(dt.cols()); for (int i = 0; i < dt.cols(); i++) {tmp(i) = logpred(dt.get_col(i));} return tmp;}
00769
00771 virtual epdf* epredictor() const {it_error("Not implemented"); return NULL;};
00773 virtual mpdf* predictor() const {it_error("Not implemented"); return NULL;};
00775
00780
00782 RV rvc;
00784 const RV& _rvc() const {return rvc;}
00785
00787 virtual void condition(const vec &val) {it_error("Not implemented!");};
00788
00790
00791
00794
00795 const RV& _drv() const {return drv;}
00796 void set_drv(const RV &rv) {drv = rv;}
00797 void set_rv(const RV &rv) {const_cast<epdf&>(posterior()).set_rv(rv);}
00798 double _ll() const {return ll;}
00799 void set_evalll(bool evl0) {evalll = evl0;}
00800 virtual const epdf& posterior() const = 0;
00801 virtual const epdf* _e() const = 0;
00803
00806
00808 virtual void set_options(const string &opt) {
00809 LFlags(0) = 1;
00810 if (opt.find("logbounds") != string::npos) {LFlags(1) = 1; LFlags(2) = 1;}
00811 if (opt.find("logll") != string::npos) {LFlags(3) = 1;}
00812 }
00814 ivec LIDs;
00815
00817 ivec LFlags;
00819 virtual void log_add(logger &L, const string &name = "") {
00820
00821 RV r;
00822 if (posterior().isnamed()) {r = posterior()._rv();}
00823 else {r = RV("est", posterior().dimension());};
00824
00825
00826 if (LFlags(0)) LIDs(0) = L.add(r, name + "mean_");
00827 if (LFlags(1)) LIDs(1) = L.add(r, name + "lb_");
00828 if (LFlags(2)) LIDs(2) = L.add(r, name + "ub_");
00829 if (LFlags(3)) LIDs(3) = L.add(RV("ll", 1), name);
00830 }
00831 virtual void logit(logger &L) {
00832 L.logit(LIDs(0), posterior().mean());
00833 if (LFlags(1) || LFlags(2)) {
00834 vec ub, lb;
00835 posterior().qbounds(lb, ub);
00836 L.logit(LIDs(1), lb);
00837 L.logit(LIDs(2), ub);
00838 }
00839 if (LFlags(3)) L.logit(LIDs(3), ll);
00840 }
00842 };
00843
00844
00845 };
00846 #endif // BDMBASE_H