00001 
00013 #ifndef BDMBASE_H
00014 #define BDMBASE_H
00015 
00016 #include <map>
00017 
00018 #include "../itpp_ext.h"
00019 #include "../bdmroot.h"
00020 #include "../shared_ptr.h"
00021 #include "user_info.h"
00022 
00023 using namespace libconfig;
00024 using namespace itpp;
00025 using namespace std;
00026 
00027 namespace bdm
00028 {
00029 
00030 typedef std::map<string, int> RVmap;
00031 extern ivec RV_SIZES;
00032 extern Array<string> RV_NAMES;
00033 
00035 class str
00036 {
00037 public:
00039   ivec ids;
00041   ivec times;
00043   str(ivec ids0, ivec times0) : ids(ids0), times(times0) {
00044     it_assert_debug(times0.length() == ids0.length(), "Incompatible input");
00045   };
00046 };
00047 
00086 class RV : public root
00087 {
00088 protected:
00090   int dsize;
00092   int len;
00094   ivec ids;
00096   ivec times;
00097 
00098 private:
00100   void init(const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times);
00101   int init(const string &name, int size);
00102 public:
00105 
00107   RV(const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times) { init(in_names, in_sizes, in_times); }
00108 
00110   RV(const Array<std::string> &in_names, const ivec &in_sizes) { init(in_names, in_sizes, zeros_i(in_names.length())); }
00111 
00113   RV(const Array<std::string> &in_names) { init(in_names, ones_i(in_names.length()), zeros_i(in_names.length())); }
00114 
00116   RV() : dsize(0), len(0), ids(0), times(0) {}
00118   RV(string name, int sz, int tm = 0);
00120 
00123 
00125   friend std::ostream &operator<< (std::ostream &os, const RV &rv);
00126 
00127   int _dsize() const { return dsize; }
00128 
00130   int countsize() const;
00131   ivec cumsizes() const;
00132   int length() const {return len;}
00133   int id(int at) const {return ids(at);}
00134   int size(int at) const { return RV_SIZES(ids(at)); }
00135   int time(int at) const { return times(at); }
00136   std::string name(int at) const { return RV_NAMES(ids(at)); }
00137   void set_time(int at, int time0) { times(at) = time0; }
00139 
00140   
00141 
00144 
00146   ivec findself(const RV &rv2) const;
00148   bool equal(const RV &rv2) const;
00150   bool add(const RV &rv2);
00152   RV subt(const RV &rv2) const;
00154   RV subselect(const ivec &ind) const;
00155 
00157   RV operator()(const ivec &ind) const { return subselect(ind); }
00158 
00160   RV operator()(int di1, int di2) const;
00161 
00163   void t(int delta);
00165 
00168 
00170   str tostr() const;
00173   ivec dataind(const RV &crv) const;
00176   void dataind(const RV &rv2, ivec &selfi, ivec &rv2i) const;
00178   int mint() const {return min(times);}
00180 
00181   
00196   void from_setting(const Setting &set);
00197 
00198   
00199 
00201   static void clear_all();
00202 };
00203 UIREGISTER(RV);
00204 
00206 RV concat(const RV &rv1, const RV &rv2);
00207 
00209 extern RV RV0;
00210 
00212 
00213 class fnc : public root
00214 {
00215 protected:
00217   int dimy;
00218 public:
00220   fnc() {};
00222   virtual vec eval(const vec &cond) {
00223     return vec(0);
00224   };
00225 
00227   virtual void condition(const vec &val) {};
00228 
00230   int dimension() const {return dimy;}
00231 };
00232 
00233 class mpdf;
00234 
00236 
00237 class epdf : public root
00238 {
00239 protected:
00241   int dim;
00243   RV rv;
00244 
00245 public:
00257   epdf() : dim(0), rv() {};
00258   epdf(const epdf &e) : dim(e.dim), rv(e.rv) {};
00259   epdf(const RV &rv0):dim(rv0._dsize()) {set_rv(rv0);};
00260   void set_parameters(int dim0) {dim = dim0;}
00262 
00265 
00267   virtual vec sample() const {it_error("not implemneted"); return vec(0);};
00269   virtual mat sample_m(int N) const;
00272   virtual double evallog(const vec &val) const {it_error("not implemneted"); return 0.0;};
00274   virtual vec evallog_m(const mat &Val) const {
00275           vec x(Val.cols());
00276           for (int i = 0; i < Val.cols(); i++) {x(i) = evallog(Val.get_col(i)) ;}
00277           return x;
00278   }
00280   virtual vec evallog_m(const Array<vec> &Avec) const {
00281           vec x(Avec.size());
00282           for (int i = 0; i < Avec.size(); i++) {x(i) = evallog(Avec(i)) ;}
00283           return x;
00284   }
00286   virtual mpdf* condition(const RV &rv) const  {it_warning("Not implemented"); return NULL;}
00287 
00289   virtual epdf* marginal(const RV &rv) const {it_warning("Not implemented"); return NULL;}
00290 
00292   virtual vec mean() const {it_error("not implemneted"); return vec(0);};
00293 
00295   virtual vec variance() const {it_error("not implemneted"); return vec(0);};
00297   virtual void qbounds(vec &lb, vec &ub, double percentage = 0.95) const {
00298     vec mea = mean();
00299     vec std = sqrt(variance());
00300     lb = mea - 2 * std;
00301     ub = mea + 2 * std;
00302   };
00304 
00310 
00312   void set_rv(const RV &rv0) {rv = rv0; }   
00314   bool isnamed() const {bool b = (dim == rv._dsize()); return b;}
00316   const RV& _rv() const {it_assert_debug(isnamed(), ""); return rv;}
00318 
00321 
00323   int dimension() const {return dim;}
00331   void from_setting(const Setting &set){
00332           RV* r = UI::build<RV>(set,"rv");
00333           if (r){
00334                   set_rv(*r); 
00335                   delete r;
00336           }
00337   }
00338 
00339 };
00340 
00341 
00343 
00344 
00345 class mpdf : public root
00346 {
00347 protected:
00349   int dimc;
00351   RV rvc;
00352 
00353 private:
00355   shared_ptr<epdf> shep;
00356 
00357 public:
00360 
00361     mpdf():dimc(0), rvc() { }
00362 
00363     mpdf(const mpdf &m):dimc(m.dimc), rvc(m.rvc), shep(m.shep) { }
00365 
00368 
00370     virtual vec samplecond(const vec &cond);
00371 
00373     virtual mat samplecond_m(const vec &cond, int N);
00374 
00376     virtual void condition(const vec &cond) {it_error("Not implemented");};
00377 
00379     virtual double evallogcond(const vec &dt, const vec &cond);
00380 
00382     virtual vec evallogcond_m(const mat &Dt, const vec &cond);
00383 
00385     virtual vec evallogcond_m(const Array<vec> &Dt, const vec &cond);
00386 
00389 
00390     RV _rv() { return shep->_rv(); }
00391     RV _rvc() { it_assert_debug(isnamed(), ""); return rvc; }
00392     int dimension() { return shep->dimension(); }
00393     int dimensionc() { return dimc; }
00394 
00395     epdf *e() { return shep.get(); }
00396 
00397     void set_ep(shared_ptr<epdf> ep) { shep = ep; }
00398 
00407     void from_setting(const Setting &set);
00409 
00412     void set_rvc(const RV &rvc0) { rvc = rvc0; }
00413     void set_rv(const RV &rv0) { shep->set_rv(rv0); }
00414     bool isnamed() { return (shep->isnamed()) && (dimc == rvc._dsize()); }
00416 };
00417 
00443 class datalink
00444 {
00445 protected:
00447   int downsize;
00448 
00450   int upsize;
00451 
00453   ivec v2v_up;
00454 
00455 public:
00457   datalink():downsize(0), upsize(0) { }
00458   datalink(const RV &rv, const RV &rv_up) { set_connection(rv, rv_up); }
00459 
00461   void set_connection(const RV &rv, const RV &rv_up) {
00462     downsize = rv._dsize();
00463     upsize = rv_up._dsize();
00464     v2v_up = rv.dataind(rv_up);
00465 
00466     it_assert_debug(v2v_up.length() == downsize, "rv is not fully in rv_up");
00467   }
00468 
00470   void set_connection(int ds, int us, const ivec &upind) {
00471     downsize = ds;
00472     upsize = us;
00473     v2v_up = upind;
00474 
00475     it_assert_debug(v2v_up.length() == downsize, "rv is not fully in rv_up");
00476   }
00477 
00479   vec pushdown(const vec &val_up) {
00480     it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00481     return get_vec(val_up, v2v_up);
00482   }
00483 
00485   void pushup(vec &val_up, const vec &val) {
00486     it_assert_debug(downsize == val.length(), "Wrong val");
00487     it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00488     set_subvector(val_up, v2v_up, val);
00489   }
00490 };
00491 
00493 class datalink_m2e: public datalink
00494 {
00495 protected:
00497   int condsize;
00498 
00500   ivec v2c_up;
00501 
00503   ivec v2c_lo;
00504 
00505 public:
00507   datalink_m2e():condsize(0) { }
00508 
00509   void set_connection(const RV &rv, const RV &rvc, const RV &rv_up) {
00510     datalink::set_connection(rv, rv_up);
00511     condsize = rvc._dsize();
00512     
00513     rvc.dataind(rv_up, v2c_lo, v2c_up);
00514   }
00515 
00517   vec get_cond(const vec &val_up) {
00518     vec tmp(condsize);
00519     set_subvector(tmp, v2c_lo, val_up(v2c_up));
00520     return tmp;
00521   }
00522 
00523   void pushup_cond(vec &val_up, const vec &val, const vec &cond) {
00524     it_assert_debug(downsize == val.length(), "Wrong val");
00525     it_assert_debug(upsize == val_up.length(), "Wrong val_up");
00526     set_subvector(val_up, v2v_up, val);
00527     set_subvector(val_up, v2c_up, cond);
00528   }
00529 };
00530 
00533 class datalink_m2m: public datalink_m2e
00534 {
00535 protected:
00537   ivec c2c_up;
00539   ivec c2c_lo;
00540 
00541 public:
00543   datalink_m2m() {};
00544   void set_connection(const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up) {
00545     datalink_m2e::set_connection(rv, rvc, rv_up);
00546     
00547     rvc.dataind(rvc_up, c2c_lo, c2c_up);
00548     it_assert_debug(c2c_lo.length() + v2c_lo.length() == condsize, "cond is not fully given");
00549   }
00550 
00552   vec get_cond(const vec &val_up, const vec &cond_up) {
00553     vec tmp(condsize);
00554     set_subvector(tmp, v2c_lo, val_up(v2c_up));
00555     set_subvector(tmp, c2c_lo, cond_up(c2c_up));
00556     return tmp;
00557   }
00559 
00560 };
00561 
00567 class logger : public root
00568 {
00569 protected:
00571   Array<RV> entries;
00573   Array<string> names;
00574 public:
00576   logger() : entries(0), names(0) {}
00577 
00580   virtual int add(const RV &rv, string prefix = "") {
00581     int id;
00582     if (rv._dsize() > 0) {
00583       id = entries.length();
00584       names = concat(names, prefix); 
00585       entries.set_length(id + 1, true);
00586       entries(id) = rv;
00587     }
00588     else { id = -1;}
00589     return id; 
00590   }
00591 
00593   virtual void logit(int id, const vec &v) = 0;
00595   virtual void logit(int id, const double &d) = 0;
00596 
00598   virtual void step() = 0;
00599 
00601   virtual void finalize() {};
00602 
00604   virtual void init() {};
00605 
00606 };
00607 
00611 class mepdf : public mpdf {
00612 public:
00614     mepdf() { }
00615 
00616     mepdf(shared_ptr<epdf> em) {
00617         set_ep(em);
00618         dimc = 0;
00619     }
00620 
00622     void condition(const vec &cond);
00623 
00631     void from_setting(const Setting &set);
00632 };
00633 UIREGISTER(mepdf);
00634 
00638 class compositepdf {
00639 protected:
00641   Array<mpdf*> mpdfs;
00642   bool owning_mpdfs;
00643 public:
00644         compositepdf():mpdfs(0){};
00645         compositepdf(Array<mpdf*> A0, bool own=false){set_elements(A0,own);};
00646         void set_elements(Array<mpdf*> A0, bool own=false) {mpdfs=A0;owning_mpdfs=own;};
00648   RV getrv(bool checkoverlap = false);
00650   void setrvc(const RV &rv, RV &rvc);
00651   ~compositepdf(){if (owning_mpdfs) for(int i=0;i<mpdfs.length();i++){delete mpdfs(i);}};
00652 };
00653 
00661 class DS : public root
00662 {
00663 protected:
00664   int dtsize;
00665   int utsize;
00667   RV Drv;
00669   RV Urv; 
00671   int L_dt, L_ut;
00672 public:
00674   DS() : Drv(), Urv() {};
00676   virtual void getdata(vec &dt) {it_error("abstract class");};
00678   virtual void getdata(vec &dt, const ivec &indeces) {it_error("abstract class");};
00680   virtual void write(vec &ut) {it_error("abstract class");};
00682   virtual void write(vec &ut, const ivec &indeces) {it_error("abstract class");};
00683 
00685   virtual void step() = 0;
00686 
00688   virtual void log_add(logger &L) {
00689     it_assert_debug(dtsize == Drv._dsize(), "");
00690     it_assert_debug(utsize == Urv._dsize(), "");
00691 
00692     L_dt = L.add(Drv, "");
00693     L_ut = L.add(Urv, "");
00694   }
00696   virtual void logit(logger &L) {
00697     vec tmp(Drv._dsize() + Urv._dsize());
00698     getdata(tmp);
00699     
00700     L.logit(L_dt, tmp.left(Drv._dsize()));
00701     
00702     L.logit(L_ut, tmp.mid(Drv._dsize(), Urv._dsize()));
00703   }
00705   virtual RV _drv() const {return concat(Drv, Urv);}
00707   const RV& _urv() const {return Urv;}
00709   virtual void set_drv(const  RV &drv, const RV &urv) { Drv = drv; Urv = urv;}
00710 };
00711 
00733 class BM : public root
00734 {
00735 protected:
00737   RV drv;
00739   double ll;
00741   bool evalll;
00742 public:
00745 
00746   BM() : ll(0), evalll(true), LIDs(4), LFlags(4) {
00747     LIDs = -1;
00748     LFlags = 0;
00749     LFlags(0) = 1;
00750   };
00751   BM(const BM &B) :  drv(B.drv), ll(B.ll), evalll(B.evalll) {}
00754   virtual BM* _copy_() const {return NULL;};
00756 
00759 
00763   virtual void bayes(const vec &dt) = 0;
00765   virtual void bayesB(const mat &Dt);
00768   virtual double logpred(const vec &dt) const {it_error("Not implemented"); return 0.0;}
00770   vec logpred_m(const mat &dt) const {vec tmp(dt.cols()); for (int i = 0; i < dt.cols(); i++) {tmp(i) = logpred(dt.get_col(i));} return tmp;}
00771 
00773   virtual epdf* epredictor() const {it_error("Not implemented"); return NULL;};
00775   virtual mpdf* predictor() const {it_error("Not implemented"); return NULL;};
00777 
00782 
00784   RV rvc;
00786   const RV& _rvc() const {return rvc;}
00787 
00789   virtual void condition(const vec &val) {it_error("Not implemented!");};
00790 
00792 
00793 
00796 
00797   const RV& _drv() const {return drv;}
00798   void set_drv(const RV &rv) {drv = rv;}
00799   void set_rv(const RV &rv) {const_cast<epdf&>(posterior()).set_rv(rv);}
00800   double _ll() const {return ll;}
00801   void set_evalll(bool evl0) {evalll = evl0;}
00802   virtual const epdf& posterior() const = 0;
00803   virtual const epdf* _e() const = 0;
00805 
00808 
00810   virtual void set_options(const string &opt) {
00811     LFlags(0) = 1;
00812     if (opt.find("logbounds") != string::npos) {LFlags(1) = 1; LFlags(2) = 1;}
00813     if (opt.find("logll") != string::npos) {LFlags(3) = 1;}
00814   }
00816   ivec LIDs;
00817 
00819   ivec LFlags;
00821   virtual void log_add(logger &L, const string &name = "") {
00822     
00823     RV r;
00824     if (posterior().isnamed()) {r = posterior()._rv();}
00825     else {r = RV("est", posterior().dimension());};
00826 
00827     
00828     if (LFlags(0)) LIDs(0) = L.add(r, name + "mean_");
00829     if (LFlags(1)) LIDs(1) = L.add(r, name + "lb_");
00830     if (LFlags(2)) LIDs(2) = L.add(r, name + "ub_");
00831     if (LFlags(3)) LIDs(3) = L.add(RV("ll", 1), name);    
00832   }
00833   virtual void logit(logger &L) {
00834     L.logit(LIDs(0), posterior().mean());
00835     if (LFlags(1) || LFlags(2)) {  
00836       vec ub, lb;
00837       posterior().qbounds(lb, ub);
00838       L.logit(LIDs(1), lb);
00839       L.logit(LIDs(2), ub);
00840     }
00841     if (LFlags(3)) L.logit(LIDs(3), ll);
00842   }
00844 };
00845 
00846 
00847 }; 
00848 #endif // BDMBASE_H