00001
00013 #ifndef BM_H
00014 #define BM_H
00015
00016
00017 #include "../itpp_ext.h"
00018 #include <map>
00019
00020 namespace bdm {
00021 using namespace itpp;
00022 using namespace std;
00023
//! Root of the BDM class hierarchy.
//! Its only member is a virtual destructor, so any derived object can be
//! destroyed correctly through a bdmroot pointer.
class bdmroot {
public:
	//! Virtual destructor; the body is intentionally empty.
	virtual ~bdmroot()
	{
	}
};
00030
00031 typedef std::map<string, int> RVmap; //!< String-to-int map; NOTE(review): presumably RV name -> global id, confirm in the .cpp.
00032 extern ivec RV_SIZES; //!< Global table of component sizes indexed by component id (read by RV::size()); defined elsewhere.
00033 extern Array<string> RV_NAMES; //!< Global table of component names indexed by component id (read by RV::name()); defined elsewhere.
00034
00036 class str {
00037 public:
00039 ivec ids;
00041 ivec times;
00043 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00044 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00045 };
00046 };
00047
00086 class RV :public bdmroot {
00087 protected:
00089 int dsize;
00091 int len;
00093 ivec ids;
00095 ivec times;
00096
00097 private:
00099 void init ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00100 int init ( const string &name, int size );
00101 public:
00104
00106 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times ) {init ( in_names,in_sizes,in_times );};
00108 RV ( Array<std::string> in_names, ivec in_sizes ) {init ( in_names,in_sizes,zeros_i ( in_names.length() ) );};
00110 RV ( Array<std::string> in_names ) {init ( in_names,ones_i ( in_names.length() ),zeros_i ( in_names.length() ) );}
00112 RV () :dsize ( 0 ),len ( 0 ),ids ( 0 ),times ( 0 ) {};
00114 RV ( string name, int sz, int tm=0 );
00116
00119
00121 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00122 int _dsize() const {return dsize;} ;
00124 int countsize() const;
00125 ivec cumsizes() const;
00126 int length() const {return len;} ;
00127 int id ( int at ) const{return ids ( at );};
00128 int size ( int at ) const {return RV_SIZES ( ids ( at ) );};
00129 int time ( int at ) const{return times ( at );};
00130 std::string name ( int at ) const {return RV_NAMES ( ids ( at ) );};
00131 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00133
00134
00135
00138
00140 ivec findself ( const RV &rv2 ) const;
00142 bool equal ( const RV &rv2 ) const;
00144 bool add ( const RV &rv2 );
00146 RV subt ( const RV &rv2 ) const;
00148 RV subselect ( const ivec &ind ) const;
00150 RV operator() ( const ivec &ind ) const {return subselect ( ind );};
00152 RV operator() ( int di1, int di2 ) const {
00153 ivec sz=cumsizes();
00154 int i1=0;
00155 while ( sz ( i1 ) <di1 ) i1++;
00156 int i2=i1;
00157 while ( sz ( i2 ) <di2 ) i2++;
00158 return subselect ( linspace ( i1,i2 ) );
00159 };
00161 void t ( int delta );
00163
00166
00168 str tostr() const;
00171 ivec dataind ( const RV &crv ) const;
00174 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00176 int mint () const {return min ( times );};
00178
00179 };
00180
00181
00183 RV concat ( const RV &rv1, const RV &rv2 ); //!< Concatenation of two RVs (defined elsewhere); NOTE(review): presumably rv1's components followed by rv2's, confirm in the .cpp.
00184
00186 extern RV RV0; //!< Global RV instance defined elsewhere; NOTE(review): presumably an empty RV used as a default, confirm.
00187
00189
00190 class fnc :public bdmroot {
00191 protected:
00193 int dimy;
00194 public:
00196 fnc ( ) {};
00198 virtual vec eval ( const vec &cond ) {
00199 return vec ( 0 );
00200 };
00201
00203 virtual void condition ( const vec &val ) {};
00204
00206 int dimension() const{return dimy;}
00207 };
00208
00209 class mpdf; // Forward declaration: epdf::condition() below returns mpdf*, which is defined after epdf.
00210
00212
00213 class epdf :public bdmroot {
00214 protected:
00216 int dim;
00218 RV rv;
00219
00220 public:
00232 epdf() :dim ( 0 ),rv ( ) {};
00233 epdf ( const epdf &e ) :dim ( e.dim ),rv ( e.rv ) {};
00234 epdf ( const RV &rv0 ) {set_rv ( rv0 );};
00235 void set_parameters ( int dim0 ) {dim=dim0;}
00237
00240
00242 virtual vec sample () const {it_error ( "not implemneted" );return vec ( 0 );};
00244 virtual mat sample_m ( int N ) const;
00246 virtual double evallog ( const vec &val ) const {it_error ( "not implemneted" );return 0.0;};
00248 virtual vec evallog_m ( const mat &Val ) const {
00249 vec x ( Val.cols() );
00250 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00251 return x;
00252 }
00254 virtual mpdf* condition ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00256 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00258 virtual vec mean() const {it_error ( "not implemneted" );return vec ( 0 );};
00260 virtual vec variance() const {it_error ( "not implemneted" );return vec ( 0 );};
00262 virtual void qbounds ( vec &lb, vec &ub, double percentage=0.95 ) const {
00263 vec mea=mean(); vec std=sqrt(variance());
00264 lb = mea-2*std; ub=mea+2*std;
00265 };
00267
00273
00275 void set_rv ( const RV &rv0 ) {rv = rv0; }
00277 bool isnamed() const {bool b= ( dim==rv._dsize() );return b;}
00279 const RV& _rv() const {it_assert_debug ( isnamed(),"" ); return rv;}
00281
00284
00286 int dimension() const {return dim;}
00288
00289 };
00290
00291
00293
00294
00295 class mpdf : public bdmroot {
00296 protected:
00298 int dimc;
00300 RV rvc;
00302 epdf* ep;
00303 public:
00306
00307 mpdf ( ) :dimc ( 0 ),rvc ( ) {};
00309 mpdf ( const mpdf &m ) :dimc ( m.dimc ),rvc ( m.rvc ) {};
00311
00314
00316 virtual vec samplecond ( const vec &cond ) {
00317 this->condition ( cond );
00318 vec temp= ep->sample();
00319 return temp;
00320 };
00322 virtual mat samplecond_m ( const vec &cond, int N ) {
00323 this->condition ( cond );
00324 mat temp ( ep->dimension(),N ); vec smp ( ep->dimension() );
00325 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );}
00326 return temp;
00327 };
00329 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00330
00332 virtual double evallogcond ( const vec &dt, const vec &cond ) {
00333 double tmp; this->condition ( cond );tmp = ep->evallog ( dt ); it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;
00334 };
00335
00337 virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00338
00341
00342 RV _rv() {return ep->_rv();}
00343 RV _rvc() {it_assert_debug ( isnamed(),"" ); return rvc;}
00344 int dimension() {return ep->dimension();}
00345 int dimensionc() {return dimc;}
00346 epdf& _epdf() {return *ep;}
00347 epdf* _e() {return ep;}
00349
00352 void set_rvc ( const RV &rvc0 ) {rvc=rvc0;}
00353 void set_rv ( const RV &rv0 ) {ep->set_rv ( rv0 );}
00354 bool isnamed() {return ( ep->isnamed() ) && ( dimc==rvc._dsize() );}
00356 };
00357
00383 class datalink {
00384 protected:
00386 int downsize;
00388 int upsize;
00390 ivec v2v_up;
00391 public:
00393 datalink () {};
00394 datalink ( const RV &rv, const RV &rv_up ) {set_connection ( rv,rv_up );};
00396 void set_connection ( const RV &rv, const RV &rv_up ) {
00397 downsize = rv._dsize();
00398 upsize = rv_up._dsize();
00399 v2v_up= ( rv.dataind ( rv_up ) );
00400
00401 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00402 }
00404 vec pushdown ( const vec &val_up ) {
00405 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00406 return get_vec ( val_up,v2v_up );
00407 }
00409 void pushup ( vec &val_up, const vec &val ) {
00410 it_assert_debug ( downsize==val.length(),"Wrong val" );
00411 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00412 set_subvector ( val_up, v2v_up, val );
00413 }
00414 };
00415
00417 class datalink_m2e: public datalink {
00418 protected:
00420 int condsize;
00422 ivec v2c_up;
00424 ivec v2c_lo;
00425
00426 public:
00428 datalink_m2e ( const RV &rv, const RV &rvc, const RV &rv_up ) :
00429 datalink ( rv,rv_up ), condsize ( rvc._dsize() ) {
00430
00431 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00432 }
00434 vec get_cond ( const vec &val_up ) {
00435 vec tmp ( condsize );
00436 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00437 return tmp;
00438 }
00439 void pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
00440 it_assert_debug ( downsize==val.length(),"Wrong val" );
00441 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00442 set_subvector ( val_up, v2v_up, val );
00443 set_subvector ( val_up, v2c_up, cond );
00444 }
00445 };
00448 class datalink_m2m: public datalink_m2e {
00449 protected:
00451 ivec c2c_up;
00453 ivec c2c_lo;
00454 public:
00456 datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00457 datalink_m2e ( rv, rvc, rv_up ) {
00458
00459 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00460 it_assert_debug ( c2c_lo.length() +v2c_lo.length() ==condsize, "cond is not fully given" );
00461 }
00463 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00464 vec tmp ( condsize );
00465 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00466 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00467 return tmp;
00468 }
00470
00471 };
00472
00478 class logger : public bdmroot {
00479 protected:
00481 Array<RV> entries;
00483 Array<string> names;
00484 public:
00486 logger ( ) : entries ( 0 ),names ( 0 ) {}
00487
00490 virtual int add ( const RV &rv, string name="" ) {
00491 int id;
00492 if ( rv._dsize() >0 ) {
00493 id=entries.length();
00494 names=concat ( names, name );
00495 entries.set_length ( id+1,true );
00496 entries ( id ) = rv;
00497 }
00498 else { id =-1;}
00499 return id;
00500 }
00501
00503 virtual void logit ( int id, const vec &v ) =0;
00504
00506 virtual void step() =0;
00507
00509 virtual void finalize() {};
00510
00512 virtual void init() {};
00513
00514 };
00515
00519 class mepdf : public mpdf {
00520 public:
00522 mepdf ( const epdf* em ) :mpdf ( ) {ep=const_cast<epdf*> ( em );};
00523 void condition ( const vec &cond ) {}
00524 };
00525
00528 class compositepdf {
00529 protected:
00531 int n;
00533 Array<mpdf*> mpdfs;
00534 public:
00535 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00537 RV getrv ( bool checkoverlap=false );
00539 void setrvc ( const RV &rv, RV &rvc );
00540 };
00541
00549 class DS : public bdmroot {
00550 protected:
00551 int dtsize;
00552 int utsize;
00554 RV Drv;
00556 RV Urv;
00558 int L_dt, L_ut;
00559 public:
00561 DS() :Drv ( ),Urv ( ) {};
00563 virtual void getdata ( vec &dt ) {it_error ( "abstract class" );};
00565 virtual void getdata ( vec &dt, const ivec &indeces ) {it_error ( "abstract class" );};
00567 virtual void write ( vec &ut ) {it_error ( "abstract class" );};
00569 virtual void write ( vec &ut, const ivec &indeces ) {it_error ( "abstract class" );};
00570
00572 virtual void step() =0;
00573
00575 virtual void log_add ( logger &L ) {
00576 it_assert_debug ( dtsize==Drv._dsize(),"" );
00577 it_assert_debug ( utsize==Urv._dsize(),"" );
00578
00579 L_dt=L.add ( Drv,"" );
00580 L_ut=L.add ( Urv,"" );
00581 }
00583 virtual void logit ( logger &L ) {
00584 vec tmp ( Drv._dsize() +Urv._dsize() );
00585 getdata ( tmp );
00586
00587 L.logit ( L_dt,tmp.left ( Drv._dsize() ) );
00588
00589 L.logit ( L_ut,tmp.mid ( Drv._dsize(), Urv._dsize() ) );
00590 }
00592 virtual RV _drv() const {return concat ( Drv,Urv );}
00594 const RV& _urv() const {return Urv;}
00595 };
00596
00618 class BM :public bdmroot {
00619 protected:
00621 RV drv;
00623 double ll;
00625 bool evalll;
00626 public:
00629
00630 BM () :ll ( 0 ),evalll ( true ), LIDs ( 3 ), opt_L_bounds ( false ) {};
00631 BM ( const BM &B ) : drv ( B.drv ), ll ( B.ll ), evalll ( B.evalll ) {}
00634 virtual BM* _copy_ () const {return NULL;};
00636
00639
00643 virtual void bayes ( const vec &dt ) = 0;
00645 virtual void bayesB ( const mat &Dt );
00648 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00650 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00651
00653 virtual epdf* epredictor ( ) const {it_error ( "Not implemented" );return NULL;};
00655 virtual mpdf* predictor ( ) const {it_error ( "Not implemented" );return NULL;};
00657
00662
00664 RV rvc;
00666 const RV& _rvc() const {return rvc;}
00667
00669 virtual void condition ( const vec &val ) {it_error ( "Not implemented!" );};
00670
00672
00673
00676
00677 const RV& _drv() const {return drv;}
00678 void set_drv ( const RV &rv ) {drv=rv;}
00679 void set_rv ( const RV &rv ) {const_cast<epdf&> ( posterior() ).set_rv ( rv );}
00680 double _ll() const {return ll;}
00681 void set_evalll ( bool evl0 ) {evalll=evl0;}
00682 virtual const epdf& posterior() const =0;
00683 virtual const epdf* _e() const =0;
00685
00688
00690 void set_options ( const string &opt ) {
00691 opt_L_bounds= ( opt.find ( "logbounds" ) !=string::npos );
00692 }
00694 ivec LIDs;
00695
00697 bool opt_L_bounds;
00699 void log_add ( logger *L, const string &name="" ) {
00700
00701 RV r;
00702 if ( posterior().isnamed() ) {r=posterior()._rv();}
00703 else{r=RV ( "est", posterior().dimension() );};
00704
00705
00706 LIDs ( 0 ) =L->add ( r,name );
00707 if ( opt_L_bounds ) {
00708 LIDs ( 1 ) =L->add ( r,name+"_lb" );
00709 LIDs ( 2 ) =L->add ( r,name+"_ub" );
00710 }
00711 }
00712 void logit ( logger *L ) {
00713 L->logit ( LIDs ( 0 ), posterior().mean() );
00714 if ( opt_L_bounds ) {
00715 vec ub,lb;
00716 posterior().qbounds(lb,ub);
00717 L->logit ( LIDs ( 1 ), lb );
00718 L->logit ( LIDs ( 2 ), ub );
00719 }
00720 }
00722 };
00723
00724 };
00725 #endif // BM_H