00001
00013 #ifndef BM_H
00014 #define BM_H
00015
00016 #include <map>
00017
00018 #include "../itpp_ext.h"
00019 #include "../libconfig/libconfig.h++"
00020
00021
00022 using namespace libconfig;
00023
00024 using namespace itpp;
00025 using namespace std;
00026
00027 namespace bdm {
00028
00030
00031 class bdmroot {
00032 public:
00034 virtual ~bdmroot()
00035 {
00036 }
00037
00039 virtual string ToString()
00040 {
00041 return "";
00042 }
00043
00045 virtual void from_setting( const Setting &root )
00046 {
00047 }
00048
00050 virtual void to_setting( Setting &root ) const
00051 {
00052 }
00053 };
00054
00055 typedef std::map<string, int> RVmap;
00056 extern ivec RV_SIZES;
00057 extern Array<string> RV_NAMES;
00058
00060 class str {
00061 public:
00063 ivec ids;
00065 ivec times;
00067 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00068 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00069 };
00070 };
00071
00110 class RV :public bdmroot {
00111 protected:
00113 int dsize;
00115 int len;
00117 ivec ids;
00119 ivec times;
00120
00121 private:
00123 void init ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00124 int init ( const string &name, int size );
00125 public:
00128
00130 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times ) {init ( in_names,in_sizes,in_times );};
00132 RV ( Array<std::string> in_names, ivec in_sizes ) {init ( in_names,in_sizes,zeros_i ( in_names.length() ) );};
00134 RV ( Array<std::string> in_names ) {init ( in_names,ones_i ( in_names.length() ),zeros_i ( in_names.length() ) );}
00136 RV () :dsize ( 0 ),len ( 0 ),ids ( 0 ),times ( 0 ) {};
00138 RV ( string name, int sz, int tm=0 );
00140
00143
00145 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00146 int _dsize() const {return dsize;} ;
00148 int countsize() const;
00149 ivec cumsizes() const;
00150 int length() const {return len;} ;
00151 int id ( int at ) const{return ids ( at );};
00152 int size ( int at ) const {return RV_SIZES ( ids ( at ) );};
00153 int time ( int at ) const{return times ( at );};
00154 std::string name ( int at ) const {return RV_NAMES ( ids ( at ) );};
00155 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00157
00158
00159
00162
00164 ivec findself ( const RV &rv2 ) const;
00166 bool equal ( const RV &rv2 ) const;
00168 bool add ( const RV &rv2 );
00170 RV subt ( const RV &rv2 ) const;
00172 RV subselect ( const ivec &ind ) const;
00174 RV operator() ( const ivec &ind ) const {return subselect ( ind );};
00176 RV operator() ( int di1, int di2 ) const {
00177 ivec sz=cumsizes();
00178 int i1=0;
00179 while ( sz ( i1 ) <di1 ) i1++;
00180 int i2=i1;
00181 while ( sz ( i2 ) <di2 ) i2++;
00182 return subselect ( linspace ( i1,i2 ) );
00183 };
00185 void t ( int delta );
00187
00190
00192 str tostr() const;
00195 ivec dataind ( const RV &crv ) const;
00198 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00200 int mint () const {return min ( times );};
00202
00203 };
00204
00205
00207 RV concat ( const RV &rv1, const RV &rv2 );
00208
00210 extern RV RV0;
00211
00213
00214 class fnc :public bdmroot {
00215 protected:
00217 int dimy;
00218 public:
00220 fnc ( ) {};
00222 virtual vec eval ( const vec &cond ) {
00223 return vec ( 0 );
00224 };
00225
00227 virtual void condition ( const vec &val ) {};
00228
00230 int dimension() const{return dimy;}
00231 };
00232
00233 class mpdf;
00234
00236
00237 class epdf :public bdmroot {
00238 protected:
00240 int dim;
00242 RV rv;
00243
00244 public:
00256 epdf() :dim ( 0 ),rv ( ) {};
00257 epdf ( const epdf &e ) :dim ( e.dim ),rv ( e.rv ) {};
00258 epdf ( const RV &rv0 ) {set_rv ( rv0 );};
00259 void set_parameters ( int dim0 ) {dim=dim0;}
00261
00264
00266 virtual vec sample () const {it_error ( "not implemneted" );return vec ( 0 );};
00268 virtual mat sample_m ( int N ) const;
00270 virtual double evallog ( const vec &val ) const {it_error ( "not implemneted" );return 0.0;};
00272 virtual vec evallog_m ( const mat &Val ) const {
00273 vec x ( Val.cols() );
00274 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00275 return x;
00276 }
00278 virtual mpdf* condition ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00280 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00282 virtual vec mean() const {it_error ( "not implemneted" );return vec ( 0 );};
00284 virtual vec variance() const {it_error ( "not implemneted" );return vec ( 0 );};
00286 virtual void qbounds ( vec &lb, vec &ub, double percentage=0.95 ) const {
00287 vec mea=mean(); vec std=sqrt ( variance() );
00288 lb = mea-2*std; ub=mea+2*std;
00289 };
00291
00297
00299 void set_rv ( const RV &rv0 ) {rv = rv0; }
00301 bool isnamed() const {bool b= ( dim==rv._dsize() );return b;}
00303 const RV& _rv() const {it_assert_debug ( isnamed(),"" ); return rv;}
00305
00308
00310 int dimension() const {return dim;}
00312
00313 };
00314
00315
00317
00318
00319 class mpdf : public bdmroot {
00320 protected:
00322 int dimc;
00324 RV rvc;
00326 epdf* ep;
00327 public:
00330
00331 mpdf ( ) :dimc ( 0 ),rvc ( ) {};
00333 mpdf ( const mpdf &m ) :dimc ( m.dimc ),rvc ( m.rvc ) {};
00335
00338
00340 virtual vec samplecond ( const vec &cond ) {
00341 this->condition ( cond );
00342 vec temp= ep->sample();
00343 return temp;
00344 };
00346 virtual mat samplecond_m ( const vec &cond, int N ) {
00347 this->condition ( cond );
00348 mat temp ( ep->dimension(),N ); vec smp ( ep->dimension() );
00349 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );}
00350 return temp;
00351 };
00353 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00354
00356 virtual double evallogcond ( const vec &dt, const vec &cond ) {
00357 double tmp; this->condition ( cond );tmp = ep->evallog ( dt ); it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;
00358 };
00359
00361 virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00362
00365
00366 RV _rv() {return ep->_rv();}
00367 RV _rvc() {it_assert_debug ( isnamed(),"" ); return rvc;}
00368 int dimension() {return ep->dimension();}
00369 int dimensionc() {return dimc;}
00370 epdf& _epdf() {return *ep;}
00371 epdf* _e() {return ep;}
00373
00376 void set_rvc ( const RV &rvc0 ) {rvc=rvc0;}
00377 void set_rv ( const RV &rv0 ) {ep->set_rv ( rv0 );}
00378 bool isnamed() {return ( ep->isnamed() ) && ( dimc==rvc._dsize() );}
00380 };
00381
00407 class datalink {
00408 protected:
00410 int downsize;
00412 int upsize;
00414 ivec v2v_up;
00415 public:
00417 datalink () {};
00418 datalink ( const RV &rv, const RV &rv_up ) {set_connection ( rv,rv_up );};
00420 void set_connection ( const RV &rv, const RV &rv_up ) {
00421 downsize = rv._dsize();
00422 upsize = rv_up._dsize();
00423 v2v_up= ( rv.dataind ( rv_up ) );
00424
00425 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00426 }
00428 void set_connection ( int ds, int us, const ivec &upind ) {
00429 downsize = ds;
00430 upsize = us;
00431 v2v_up= upind;
00432
00433 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00434 }
00436 vec pushdown ( const vec &val_up ) {
00437 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00438 return get_vec ( val_up,v2v_up );
00439 }
00441 void pushup ( vec &val_up, const vec &val ) {
00442 it_assert_debug ( downsize==val.length(),"Wrong val" );
00443 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00444 set_subvector ( val_up, v2v_up, val );
00445 }
00446 };
00447
00449 class datalink_m2e: public datalink {
00450 protected:
00452 int condsize;
00454 ivec v2c_up;
00456 ivec v2c_lo;
00457
00458 public:
00459 datalink_m2e() {};
00461 void set_connection ( const RV &rv, const RV &rvc, const RV &rv_up ) {
00462 datalink::set_connection ( rv,rv_up );
00463 condsize= rvc._dsize();
00464
00465 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00466 }
00468 vec get_cond ( const vec &val_up ) {
00469 vec tmp ( condsize );
00470 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00471 return tmp;
00472 }
00473 void pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
00474 it_assert_debug ( downsize==val.length(),"Wrong val" );
00475 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00476 set_subvector ( val_up, v2v_up, val );
00477 set_subvector ( val_up, v2c_up, cond );
00478 }
00479 };
00482 class datalink_m2m: public datalink_m2e {
00483 protected:
00485 ivec c2c_up;
00487 ivec c2c_lo;
00488 public:
00490 datalink_m2m() {};
00491 void set_connection ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) {
00492 datalink_m2e::set_connection ( rv, rvc, rv_up );
00493
00494 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00495 it_assert_debug ( c2c_lo.length() +v2c_lo.length() ==condsize, "cond is not fully given" );
00496 }
00498 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00499 vec tmp ( condsize );
00500 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00501 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00502 return tmp;
00503 }
00505
00506 };
00507
00513 class logger : public bdmroot {
00514 protected:
00516 Array<RV> entries;
00518 Array<string> names;
00519 public:
00521 logger ( ) : entries ( 0 ),names ( 0 ) {}
00522
00525 virtual int add ( const RV &rv, string prefix="" ) {
00526 int id;
00527 if ( rv._dsize() >0 ) {
00528 id=entries.length();
00529 names=concat ( names, prefix);
00530 entries.set_length ( id+1,true );
00531 entries ( id ) = rv;
00532 }
00533 else { id =-1;}
00534 return id;
00535 }
00536
00538 virtual void logit ( int id, const vec &v ) =0;
00540 virtual void logit ( int id, const double &d ) =0;
00541
00543 virtual void step() =0;
00544
00546 virtual void finalize() {};
00547
00549 virtual void init() {};
00550
00551 };
00552
00556 class mepdf : public mpdf {
00557 public:
00559 mepdf ( epdf* em ) :mpdf ( ) {ep= em ;};
00560 mepdf (const epdf* em ) :mpdf ( ) {ep=const_cast<epdf*>( em );};
00561 void condition ( const vec &cond ) {}
00562 };
00563
00566 class compositepdf {
00567 protected:
00569 int n;
00571 Array<mpdf*> mpdfs;
00572 public:
00573 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00575 RV getrv ( bool checkoverlap=false );
00577 void setrvc ( const RV &rv, RV &rvc );
00578 };
00579
00587 class DS : public bdmroot {
00588 protected:
00589 int dtsize;
00590 int utsize;
00592 RV Drv;
00594 RV Urv;
00596 int L_dt, L_ut;
00597 public:
00599 DS() :Drv ( ),Urv ( ) {};
00601 virtual void getdata ( vec &dt ) {it_error ( "abstract class" );};
00603 virtual void getdata ( vec &dt, const ivec &indeces ) {it_error ( "abstract class" );};
00605 virtual void write ( vec &ut ) {it_error ( "abstract class" );};
00607 virtual void write ( vec &ut, const ivec &indeces ) {it_error ( "abstract class" );};
00608
00610 virtual void step() =0;
00611
00613 virtual void log_add ( logger &L ) {
00614 it_assert_debug ( dtsize==Drv._dsize(),"" );
00615 it_assert_debug ( utsize==Urv._dsize(),"" );
00616
00617 L_dt=L.add ( Drv,"" );
00618 L_ut=L.add ( Urv,"" );
00619 }
00621 virtual void logit ( logger &L ) {
00622 vec tmp ( Drv._dsize() +Urv._dsize() );
00623 getdata ( tmp );
00624
00625 L.logit ( L_dt,tmp.left ( Drv._dsize() ) );
00626
00627 L.logit ( L_ut,tmp.mid ( Drv._dsize(), Urv._dsize() ) );
00628 }
00630 virtual RV _drv() const {return concat ( Drv,Urv );}
00632 const RV& _urv() const {return Urv;}
00634 virtual void set_drv (const RV &drv, const RV &urv) { Drv=drv;Urv=urv;}
00635 };
00636
00658 class BM :public bdmroot {
00659 protected:
00661 RV drv;
00663 double ll;
00665 bool evalll;
00666 public:
00669
00670 BM () :ll ( 0 ),evalll ( true ), LIDs ( 4 ), LFlags(4) {
00671 LIDs=-1; LFlags=0; LFlags(0)=1;};
00672 BM ( const BM &B ) : drv ( B.drv ), ll ( B.ll ), evalll ( B.evalll ) {}
00675 virtual BM* _copy_ () const {return NULL;};
00677
00680
00684 virtual void bayes ( const vec &dt ) = 0;
00686 virtual void bayesB ( const mat &Dt );
00689 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00691 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00692
00694 virtual epdf* epredictor ( ) const {it_error ( "Not implemented" );return NULL;};
00696 virtual mpdf* predictor ( ) const {it_error ( "Not implemented" );return NULL;};
00698
00703
00705 RV rvc;
00707 const RV& _rvc() const {return rvc;}
00708
00710 virtual void condition ( const vec &val ) {it_error ( "Not implemented!" );};
00711
00713
00714
00717
00718 const RV& _drv() const {return drv;}
00719 void set_drv ( const RV &rv ) {drv=rv;}
00720 void set_rv ( const RV &rv ) {const_cast<epdf&> ( posterior() ).set_rv ( rv );}
00721 double _ll() const {return ll;}
00722 void set_evalll ( bool evl0 ) {evalll=evl0;}
00723 virtual const epdf& posterior() const =0;
00724 virtual const epdf* _e() const =0;
00726
00729
00731 virtual void set_options ( const string &opt ) {
00732 LFlags(0)=1;
00733 if ( opt.find ( "logbounds" ) !=string::npos ) {LFlags(1)=1; LFlags(2)=1;}
00734 if ( opt.find ( "logll" ) !=string::npos ) {LFlags(3)=1;}
00735 }
00737 ivec LIDs;
00738
00740 ivec LFlags;
00742 virtual void log_add ( logger &L, const string &name="" ) {
00743
00744 RV r;
00745 if ( posterior().isnamed() ) {r=posterior()._rv();}
00746 else{r=RV ( "est", posterior().dimension() );};
00747
00748
00749 if (LFlags(0)) LIDs ( 0 ) =L.add ( r,name+"mean_" );
00750 if (LFlags(1)) LIDs ( 1 ) =L.add ( r,name+"lb_" );
00751 if (LFlags(2)) LIDs ( 2 ) =L.add ( r,name+"ub_" );
00752 if (LFlags(3)) LIDs ( 3 ) =L.add ( RV("ll",1),name );
00753 }
00754 virtual void logit ( logger &L ) {
00755 L.logit ( LIDs ( 0 ), posterior().mean() );
00756 if ( LFlags(1) || LFlags(2)) {
00757 vec ub,lb;
00758 posterior().qbounds ( lb,ub );
00759 L.logit ( LIDs ( 1 ), lb );
00760 L.logit ( LIDs ( 2 ), ub );
00761 }
00762 if (LFlags(3)) L.logit ( LIDs ( 3 ), ll );
00763 }
00765 };
00766
00767
00768 };
00769 #endif // BM_H