00001
00013 #ifndef BDMBASE_H
00014 #define BDMBASE_H
00015
00016 #include <map>
00017
00018 #include "../itpp_ext.h"
00019 #include "../bdmroot.h"
00020 #include "../shared_ptr.h"
00021 #include "user_info.h"
00022
00023 using namespace libconfig;
00024 using namespace itpp;
00025 using namespace std;
00026
00027 namespace bdm
00028 {
00029
00030 typedef std::map<string, int> RVmap;
00031 extern ivec RV_SIZES;
00032 extern Array<string> RV_NAMES;
00033
00035 class str
00036 {
00037 public:
00039 ivec ids;
00041 ivec times;
00043 str(ivec ids0, ivec times0) : ids(ids0), times(times0) {
00044 it_assert_debug(times0.length() == ids0.length(), "Incompatible input");
00045 };
00046 };
00047
00086 class RV : public root
00087 {
00088 protected:
00090 int dsize;
00092 int len;
00094 ivec ids;
00096 ivec times;
00097
00098 private:
00100 void init(const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times);
00101 int init(const string &name, int size);
00102 public:
00105
00107 RV(const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times) { init(in_names, in_sizes, in_times); }
00108
00110 RV(const Array<std::string> &in_names, const ivec &in_sizes) { init(in_names, in_sizes, zeros_i(in_names.length())); }
00111
00113 RV(const Array<std::string> &in_names) { init(in_names, ones_i(in_names.length()), zeros_i(in_names.length())); }
00114
00116 RV() : dsize(0), len(0), ids(0), times(0) {}
00118 RV(string name, int sz, int tm = 0);
00120
00123
00125 friend std::ostream &operator<< (std::ostream &os, const RV &rv);
00126
00127 int _dsize() const { return dsize; }
00128
00130 int countsize() const;
00131 ivec cumsizes() const;
00132 int length() const {return len;}
00133 int id(int at) const {return ids(at);}
00134 int size(int at) const { return RV_SIZES(ids(at)); }
00135 int time(int at) const { return times(at); }
00136 std::string name(int at) const { return RV_NAMES(ids(at)); }
00137 void set_time(int at, int time0) { times(at) = time0; }
00139
00140
00141
00144
00146 ivec findself(const RV &rv2) const;
00148 bool equal(const RV &rv2) const;
00150 bool add(const RV &rv2);
00152 RV subt(const RV &rv2) const;
00154 RV subselect(const ivec &ind) const;
00155
00157 RV operator()(const ivec &ind) const { return subselect(ind); }
00158
00160 RV operator()(int di1, int di2) const;
00161
00163 void t(int delta);
00165
00168
00170 str tostr() const;
00173 ivec dataind(const RV &crv) const;
00176 void dataind(const RV &rv2, ivec &selfi, ivec &rv2i) const;
00178 int mint() const {return min(times);}
00180
00181
00196 void from_setting(const Setting &set);
00197
00198
00199
00201 static void clear_all();
00202 };
00203 UIREGISTER(RV);
00204
00206 RV concat(const RV &rv1, const RV &rv2);
00207
00209 extern RV RV0;
00210
00212
00213 class fnc : public root
00214 {
00215 protected:
00217 int dimy;
00218 public:
00220 fnc() {};
00222 virtual vec eval(const vec &cond) {
00223 return vec(0);
00224 };
00225
00227 virtual void condition(const vec &val) {};
00228
00230 int dimension() const {return dimy;}
00231 };
00232
00233 class mpdf;
00234
00236
00237 class epdf : public root
00238 {
00239 protected:
00241 int dim;
00243 RV rv;
00244
00245 public:
00257 epdf() : dim(0), rv() {};
00258 epdf(const epdf &e) : dim(e.dim), rv(e.rv) {};
00259 epdf(const RV &rv0):dim(rv0._dsize()) {set_rv(rv0);};
00260 void set_parameters(int dim0) {dim = dim0;}
00262
00265
00267 virtual vec sample() const {it_error("not implemneted"); return vec(0);};
00269 virtual mat sample_m(int N) const;
00272 virtual double evallog(const vec &val) const {it_error("not implemneted"); return 0.0;};
00274 virtual vec evallog_m(const mat &Val) const {
00275 vec x(Val.cols());
00276 for (int i = 0; i < Val.cols(); i++) {x(i) = evallog(Val.get_col(i)) ;}
00277 return x;
00278 }
00280 virtual vec evallog_m(const Array<vec> &Avec) const {
00281 vec x(Avec.size());
00282 for (int i = 0; i < Avec.size(); i++) {x(i) = evallog(Avec(i)) ;}
00283 return x;
00284 }
00286 virtual mpdf* condition(const RV &rv) const {it_warning("Not implemented"); return NULL;}
00287
00289 virtual epdf* marginal(const RV &rv) const {it_warning("Not implemented"); return NULL;}
00290
00292 virtual vec mean() const {it_error("not implemneted"); return vec(0);};
00293
00295 virtual vec variance() const {it_error("not implemneted"); return vec(0);};
00297 virtual void qbounds(vec &lb, vec &ub, double percentage = 0.95) const {
00298 vec mea = mean();
00299 vec std = sqrt(variance());
00300 lb = mea - 2 * std;
00301 ub = mea + 2 * std;
00302 };
00304
00310
00312 void set_rv(const RV &rv0) {rv = rv0; }
00314 bool isnamed() const {bool b = (dim == rv._dsize()); return b;}
00316 const RV& _rv() const {it_assert_debug(isnamed(), ""); return rv;}
00318
00321
00323 int dimension() const {return dim;}
00331 void from_setting(const Setting &set){
00332 RV* r = UI::build<RV>(set,"rv");
00333 if (r){
00334 set_rv(*r);
00335 delete r;
00336 }
00337 }
00338
00339 };
00340
00341
00343
00344
00345 class mpdf : public root
00346 {
00347 protected:
00349 int dimc;
00351 RV rvc;
00352
00353 private:
00355 shared_ptr<epdf> shep;
00356
00357 public:
00360
00361 mpdf():dimc(0), rvc() { }
00362
00363 mpdf(const mpdf &m):dimc(m.dimc), rvc(m.rvc), shep(m.shep) { }
00365
00368
00370 virtual vec samplecond(const vec &cond);
00371
00373 virtual mat samplecond_m(const vec &cond, int N);
00374
00376 virtual void condition(const vec &cond) {it_error("Not implemented");};
00377
00379 virtual double evallogcond(const vec &dt, const vec &cond);
00380
00382 virtual vec evallogcond_m(const mat &Dt, const vec &cond);
00383
00385 virtual vec evallogcond_m(const Array<vec> &Dt, const vec &cond);
00386
00389
00390 RV _rv() { return shep->_rv(); }
00391 RV _rvc() { it_assert_debug(isnamed(), ""); return rvc; }
00392 int dimension() { return shep->dimension(); }
00393 int dimensionc() { return dimc; }
00394
00395 epdf *e() { return shep.get(); }
00396
00397 void set_ep(shared_ptr<epdf> ep) { shep = ep; }
00398
00407 void from_setting(const Setting &set);
00409
00412 void set_rvc(const RV &rvc0) { rvc = rvc0; }
00413 void set_rv(const RV &rv0) { shep->set_rv(rv0); }
00414 bool isnamed() { return (shep->isnamed()) && (dimc == rvc._dsize()); }
00416 };
00417
00443 class datalink
00444 {
00445 protected:
00447 int downsize;
00448
00450 int upsize;
00451
00453 ivec v2v_up;
00454
00455 public:
00457 datalink():downsize(0), upsize(0) { }
00458 datalink(const RV &rv, const RV &rv_up) { set_connection(rv, rv_up); }
00459
00461 void set_connection(const RV &rv, const RV &rv_up) {
00462 downsize = rv._dsize();
00463 upsize = rv_up._dsize();
00464 v2v_up = rv.dataind(rv_up);
00465
00466 it_assert_debug(v2v_up.length() == downsize, "rv is not fully in rv_up");
00467 }
00468
00470 void set_connection(int ds, int us, const ivec &upind) {
00471 downsize = ds;
00472 upsize = us;
00473 v2v_up = upind;
00474
00475 it_assert_debug(v2v_up.length() == downsize, "rv is not fully in rv_up");
00476 }
00477
00479 vec pushdown(const vec &val_up) {
00480 it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00481 return get_vec(val_up, v2v_up);
00482 }
00483
00485 void pushup(vec &val_up, const vec &val) {
00486 it_assert_debug(downsize == val.length(), "Wrong val");
00487 it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00488 set_subvector(val_up, v2v_up, val);
00489 }
00490 };
00491
00493 class datalink_m2e: public datalink
00494 {
00495 protected:
00497 int condsize;
00498
00500 ivec v2c_up;
00501
00503 ivec v2c_lo;
00504
00505 public:
00507 datalink_m2e():condsize(0) { }
00508
00509 void set_connection(const RV &rv, const RV &rvc, const RV &rv_up) {
00510 datalink::set_connection(rv, rv_up);
00511 condsize = rvc._dsize();
00512
00513 rvc.dataind(rv_up, v2c_lo, v2c_up);
00514 }
00515
00517 vec get_cond(const vec &val_up) {
00518 vec tmp(condsize);
00519 set_subvector(tmp, v2c_lo, val_up(v2c_up));
00520 return tmp;
00521 }
00522
00523 void pushup_cond(vec &val_up, const vec &val, const vec &cond) {
00524 it_assert_debug(downsize == val.length(), "Wrong val");
00525 it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00526 set_subvector(val_up, v2v_up, val);
00527 set_subvector(val_up, v2c_up, cond);
00528 }
00529 };
00530
00533 class datalink_m2m: public datalink_m2e
00534 {
00535 protected:
00537 ivec c2c_up;
00539 ivec c2c_lo;
00540
00541 public:
00543 datalink_m2m() {};
00544 void set_connection(const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up) {
00545 datalink_m2e::set_connection(rv, rvc, rv_up);
00546
00547 rvc.dataind(rvc_up, c2c_lo, c2c_up);
00548 it_assert_debug(c2c_lo.length() + v2c_lo.length() == condsize, "cond is not fully given");
00549 }
00550
00552 vec get_cond(const vec &val_up, const vec &cond_up) {
00553 vec tmp(condsize);
00554 set_subvector(tmp, v2c_lo, val_up(v2c_up));
00555 set_subvector(tmp, c2c_lo, cond_up(c2c_up));
00556 return tmp;
00557 }
00559
00560 };
00561
00567 class logger : public root
00568 {
00569 protected:
00571 Array<RV> entries;
00573 Array<string> names;
00574 public:
00576 logger() : entries(0), names(0) {}
00577
00580 virtual int add(const RV &rv, string prefix = "") {
00581 int id;
00582 if (rv._dsize() > 0) {
00583 id = entries.length();
00584 names = concat(names, prefix);
00585 entries.set_length(id + 1, true);
00586 entries(id) = rv;
00587 }
00588 else { id = -1;}
00589 return id;
00590 }
00591
00593 virtual void logit(int id, const vec &v) = 0;
00595 virtual void logit(int id, const double &d) = 0;
00596
00598 virtual void step() = 0;
00599
00601 virtual void finalize() {};
00602
00604 virtual void init() {};
00605
00606 };
00607
00611 class mepdf : public mpdf {
00612 public:
00614 mepdf() { }
00615
00616 mepdf(shared_ptr<epdf> em) {
00617 set_ep(em);
00618 dimc = 0;
00619 }
00620
00622 void condition(const vec &cond);
00623
00631 void from_setting(const Setting &set);
00632 };
00633 UIREGISTER(mepdf);
00634
00638 class compositepdf {
00639 protected:
00641 Array<mpdf*> mpdfs;
00642 bool owning_mpdfs;
00643 public:
00644 compositepdf():mpdfs(0){};
00645 compositepdf(Array<mpdf*> A0, bool own=false){set_elements(A0,own);};
00646 void set_elements(Array<mpdf*> A0, bool own=false) {mpdfs=A0;owning_mpdfs=own;};
00648 RV getrv(bool checkoverlap = false);
00650 void setrvc(const RV &rv, RV &rvc);
00651 ~compositepdf(){if (owning_mpdfs) for(int i=0;i<mpdfs.length();i++){delete mpdfs(i);}};
00652 };
00653
00661 class DS : public root
00662 {
00663 protected:
00664 int dtsize;
00665 int utsize;
00667 RV Drv;
00669 RV Urv;
00671 int L_dt, L_ut;
00672 public:
00674 DS() : Drv(), Urv() {};
00676 virtual void getdata(vec &dt) {it_error("abstract class");};
00678 virtual void getdata(vec &dt, const ivec &indeces) {it_error("abstract class");};
00680 virtual void write(vec &ut) {it_error("abstract class");};
00682 virtual void write(vec &ut, const ivec &indeces) {it_error("abstract class");};
00683
00685 virtual void step() = 0;
00686
00688 virtual void log_add(logger &L) {
00689 it_assert_debug(dtsize == Drv._dsize(), "");
00690 it_assert_debug(utsize == Urv._dsize(), "");
00691
00692 L_dt = L.add(Drv, "");
00693 L_ut = L.add(Urv, "");
00694 }
00696 virtual void logit(logger &L) {
00697 vec tmp(Drv._dsize() + Urv._dsize());
00698 getdata(tmp);
00699
00700 L.logit(L_dt, tmp.left(Drv._dsize()));
00701
00702 L.logit(L_ut, tmp.mid(Drv._dsize(), Urv._dsize()));
00703 }
00705 virtual RV _drv() const {return concat(Drv, Urv);}
00707 const RV& _urv() const {return Urv;}
00709 virtual void set_drv(const RV &drv, const RV &urv) { Drv = drv; Urv = urv;}
00710 };
00711
00733 class BM : public root
00734 {
00735 protected:
00737 RV drv;
00739 double ll;
00741 bool evalll;
00742 public:
00745
00746 BM() : ll(0), evalll(true), LIDs(4), LFlags(4) {
00747 LIDs = -1;
00748 LFlags = 0;
00749 LFlags(0) = 1;
00750 };
00751 BM(const BM &B) : drv(B.drv), ll(B.ll), evalll(B.evalll) {}
00754 virtual BM* _copy_() const {return NULL;};
00756
00759
00763 virtual void bayes(const vec &dt) = 0;
00765 virtual void bayesB(const mat &Dt);
00768 virtual double logpred(const vec &dt) const {it_error("Not implemented"); return 0.0;}
00770 vec logpred_m(const mat &dt) const {vec tmp(dt.cols()); for (int i = 0; i < dt.cols(); i++) {tmp(i) = logpred(dt.get_col(i));} return tmp;}
00771
00773 virtual epdf* epredictor() const {it_error("Not implemented"); return NULL;};
00775 virtual mpdf* predictor() const {it_error("Not implemented"); return NULL;};
00777
00782
00784 RV rvc;
00786 const RV& _rvc() const {return rvc;}
00787
00789 virtual void condition(const vec &val) {it_error("Not implemented!");};
00790
00792
00793
00796
00797 const RV& _drv() const {return drv;}
00798 void set_drv(const RV &rv) {drv = rv;}
00799 void set_rv(const RV &rv) {const_cast<epdf&>(posterior()).set_rv(rv);}
00800 double _ll() const {return ll;}
00801 void set_evalll(bool evl0) {evalll = evl0;}
00802 virtual const epdf& posterior() const = 0;
00803 virtual const epdf* _e() const = 0;
00805
00808
00810 virtual void set_options(const string &opt) {
00811 LFlags(0) = 1;
00812 if (opt.find("logbounds") != string::npos) {LFlags(1) = 1; LFlags(2) = 1;}
00813 if (opt.find("logll") != string::npos) {LFlags(3) = 1;}
00814 }
00816 ivec LIDs;
00817
00819 ivec LFlags;
00821 virtual void log_add(logger &L, const string &name = "") {
00822
00823 RV r;
00824 if (posterior().isnamed()) {r = posterior()._rv();}
00825 else {r = RV("est", posterior().dimension());};
00826
00827
00828 if (LFlags(0)) LIDs(0) = L.add(r, name + "mean_");
00829 if (LFlags(1)) LIDs(1) = L.add(r, name + "lb_");
00830 if (LFlags(2)) LIDs(2) = L.add(r, name + "ub_");
00831 if (LFlags(3)) LIDs(3) = L.add(RV("ll", 1), name);
00832 }
00833 virtual void logit(logger &L) {
00834 L.logit(LIDs(0), posterior().mean());
00835 if (LFlags(1) || LFlags(2)) {
00836 vec ub, lb;
00837 posterior().qbounds(lb, ub);
00838 L.logit(LIDs(1), lb);
00839 L.logit(LIDs(2), ub);
00840 }
00841 if (LFlags(3)) L.logit(LIDs(3), ll);
00842 }
00844 };
00845
00846
00847 };
00848 #endif // BDMBASE_H