00001
00019 #ifndef BM_H
00020 #define BM_H
00021
00022
00023 #include "../itpp_ext.h"
00024 #include <map>
00025
00026 namespace bdm {
00027 using namespace itpp;
00028 using namespace std;
00029
/*!
 * \brief Common root of the bdm class hierarchy.
 *
 * Its only member is a virtual destructor, so objects of any derived class
 * can be safely destroyed through a pointer to this base.
 */
class bdmroot
{
public:
	//! Virtual destructor; intentionally empty.
	virtual ~bdmroot()
	{
	}
};
00036
00037 typedef std::map<string, int> RVmap;
00038 extern ivec RV_SIZES;
00039 extern Array<string> RV_NAMES;
00040
00042 class str {
00043 public:
00045 ivec ids;
00047 ivec times;
00049 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00050 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00051 };
00052 };
00053
00092 class RV :public bdmroot {
00093 protected:
00095 int dsize;
00097 int len;
00099 ivec ids;
00101 ivec times;
00102
00103 private:
00105 void init ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00106 int init ( const string &name, int size );
00107 public:
00110
00112 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times ) {init ( in_names,in_sizes,in_times );};
00114 RV ( Array<std::string> in_names, ivec in_sizes ) {init ( in_names,in_sizes,zeros_i ( in_names.length() ) );};
00116 RV ( Array<std::string> in_names ) {init ( in_names,ones_i ( in_names.length() ),zeros_i ( in_names.length() ) );}
00118 RV () :dsize ( 0 ),len ( 0 ),ids ( 0 ),times ( 0 ) {};
00120 RV ( string name, int sz, int tm=0 );
00122
00125
00127 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00128 int _dsize() const {return dsize;} ;
00130 int countsize() const;
00131 int length() const {return len;} ;
00132 int id ( int at ) const{return ids ( at );};
00133 int size ( int at ) const {return RV_SIZES ( at );};
00134 int time ( int at ) const{return times ( at );};
00135 std::string name ( int at ) const {return RV_NAMES ( at );};
00136 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00138
00139
00140
00143
00145 ivec findself ( const RV &rv2 ) const;
00147 bool equal ( const RV &rv2 ) const;
00149 bool add ( const RV &rv2 );
00151 RV subt ( const RV &rv2 ) const;
00153 RV subselect ( const ivec &ind ) const;
00155 RV operator() ( const ivec &ind ) const {return subselect ( ind );};
00157 void t ( int delta );
00159
00162
00164 str tostr() const;
00167 ivec dataind ( const RV &crv ) const;
00170 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00172 int mint () const {return min ( times );};
00174
00175 };
00176
00177
00179 RV concat ( const RV &rv1, const RV &rv2 );
00180
00182 extern RV RV0;
00183
00185
00186 class fnc :public bdmroot {
00187 protected:
00189 int dimy;
00190 public:
00192 fnc ( ) {};
00194 virtual vec eval ( const vec &cond ) {
00195 return vec ( 0 );
00196 };
00197
00199 virtual void condition ( const vec &val ) {};
00200
00202 int _dimy() const{return dimy;}
00203 };
00204
00205 class mpdf;
00206
00208
00209 class epdf :public bdmroot {
00210 protected:
00212 int dim;
00214 RV rv;
00215
00216 public:
00225 epdf() :dim(0),rv ( ) {};
00226 epdf(const epdf &e) :dim(e.dim),rv (e.rv) {};
00228
00231
00233 virtual vec sample () const =0;
00235 virtual mat sample_m ( int N ) const;
00237 virtual double evallog ( const vec &val ) const =0;
00239 virtual vec evallog_m ( const mat &Val ) const {
00240 vec x ( Val.cols() );
00241 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00242 return x;
00243 }
00245 virtual mpdf* condition ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00247 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00249 virtual vec mean() const =0;
00251 virtual vec variance() const = 0;
00253
00259
00261 void set_rv ( const RV &rv0 ) {rv = rv0; it_assert_debug(isnamed(),""); };
00263 bool isnamed() const {return ( dim==rv._dsize() );}
00265 const RV& _rv() const {it_assert_debug ( isnamed(),"" ); return rv;}
00267
00270
00272 bool dimension() const {return dim;}
00274
00275 };
00276
00277
00279
00280
00281 class mpdf : public bdmroot {
00282 protected:
00284 int dimc;
00286 RV rvc;
00288 epdf* ep;
00289 public:
00292
00293 mpdf ( ) :dimc(0),rvc ( ) {};
00295 mpdf (const mpdf &m ) :dimc(m.dimc),rvc (m.rvc ) {};
00297
00300
00302 virtual vec samplecond ( const vec &cond ) {
00303 this->condition ( cond );
00304 vec temp= ep->sample();
00305 return temp;
00306 };
00308 virtual mat samplecond_m ( const vec &cond, int N ) {
00309 this->condition ( cond );
00310 mat temp ( ep->dimension(),N ); vec smp ( ep->dimension() );
00311 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );}
00312 return temp;
00313 };
00315 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00316
00318 virtual double evallogcond ( const vec &dt, const vec &cond ) {
00319 double tmp; this->condition ( cond );tmp = ep->evallog ( dt ); it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;
00320 };
00321
00323 virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00324
00327
00328 RV _rv() {return ep->_rv();}
00329 RV _rvc() {it_assert_debug ( isnamed(),"" ); return rvc;}
00330 int dimension() {return ep->dimension();}
00331 int dimensionc() {return dimc;}
00332 epdf& _epdf() {return *ep;}
00333 epdf* _e() {return ep;}
00335
00338 void set_rvc ( const RV &rvc0 ) {rvc=rvc0;}
00339 void set_rv ( const RV &rv0 ) {ep->set_rv(rv0);}
00340 bool isnamed() {return (ep->isnamed())&&(dimc=rvc._dsize());}
00342 };
00343
00369 class datalink {
00370 protected:
00372 int downsize;
00374 int upsize;
00376 ivec v2v_up;
00377 public:
00379 datalink ( const RV &rv, const RV &rv_up ) :
00380 downsize ( rv._dsize() ), upsize ( rv_up._dsize() ), v2v_up ( rv.dataind ( rv_up ) ) {
00381 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00382 }
00384 vec pushdown ( const vec &val_up ) {
00385 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00386 return get_vec ( val_up,v2v_up );
00387 }
00389 void pushup ( vec &val_up, const vec &val ) {
00390 it_assert_debug ( downsize==val.length(),"Wrong val" );
00391 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00392 set_subvector ( val_up, v2v_up, val );
00393 }
00394 };
00395
00397 class datalink_m2e: public datalink {
00398 protected:
00400 int condsize;
00402 ivec v2c_up;
00404 ivec v2c_lo;
00405
00406 public:
00408 datalink_m2e ( const RV &rv, const RV &rvc, const RV &rv_up ) :
00409 datalink ( rv,rv_up ), condsize ( rvc._dsize() ) {
00410
00411 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00412 }
00414 vec get_cond ( const vec &val_up ) {
00415 vec tmp ( condsize );
00416 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00417 return tmp;
00418 }
00419 void pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
00420 it_assert_debug ( downsize==val.length(),"Wrong val" );
00421 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00422 set_subvector ( val_up, v2v_up, val );
00423 set_subvector ( val_up, v2c_up, cond );
00424 }
00425 };
00428 class datalink_m2m: public datalink_m2e {
00429 protected:
00431 ivec c2c_up;
00433 ivec c2c_lo;
00434 public:
00436 datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00437 datalink_m2e ( rv, rvc, rv_up ) {
00438
00439 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00440 it_assert_debug ( c2c_lo.length() +v2c_lo.length() ==condsize, "cond is not fully given" );
00441 }
00443 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00444 vec tmp ( condsize );
00445 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00446 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00447 return tmp;
00448 }
00450
00451 };
00452
00458 class logger : public bdmroot {
00459 protected:
00461 Array<RV> entries;
00463 Array<string> names;
00464 public:
00466 logger ( ) : entries ( 0 ),names ( 0 ) {}
00467
00469 virtual int add ( const RV &rv, string name="" ) {
00470 int id=entries.length();
00471 names=concat ( names, name );
00472 entries.set_length ( id+1,true );
00473 entries ( id ) = rv;
00474 return id;
00475 }
00476
00478 virtual void logit ( int id, const vec &v ) =0;
00479
00481 virtual void step() =0;
00482
00484 virtual void finalize() {};
00485
00487 virtual void init() {};
00488
00489 };
00490
00494 class mepdf : public mpdf {
00495 public:
00497 mepdf ( const epdf* em ) :mpdf ( ) {ep=const_cast<epdf*> ( em );};
00498 void condition ( const vec &cond ) {}
00499 };
00500
00503 class compositepdf {
00504 protected:
00506 int n;
00508 Array<mpdf*> mpdfs;
00509 public:
00510 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00512 RV getrv ( bool checkoverlap=false );
00514 void setrvc ( const RV &rv, RV &rvc );
00515 };
00516
00524 class DS : public bdmroot {
00525 protected:
00526 int dtsize;
00527 int utsize;
00529 RV Drv;
00531 RV Urv;
00533 int L_dt, L_ut;
00534 public:
00536 DS() :Drv ( ),Urv ( ) {};
00538 virtual void getdata ( vec &dt ) {it_error ( "abstract class" );};
00540 virtual void getdata ( vec &dt, const ivec &indeces ) {it_error ( "abstract class" );};
00542 virtual void write ( vec &ut ) {it_error ( "abstract class" );};
00544 virtual void write ( vec &ut, const ivec &indeces ) {it_error ( "abstract class" );};
00545
00547 virtual void step() =0;
00548
00550 virtual void log_add ( logger &L ) {
00551 it_assert_debug ( dtsize==Drv._dsize(),"" );
00552 it_assert_debug ( utsize==Urv._dsize(),"" );
00553
00554 L_dt=L.add ( Drv,"" );
00555 L_ut=L.add ( Urv,"" );
00556 }
00558 virtual void logit ( logger &L ) {
00559 vec tmp ( Drv._dsize() +Urv._dsize() );
00560 getdata ( tmp );
00561
00562 L.logit ( L_dt,tmp.left ( Drv._dsize() ) );
00563
00564 L.logit ( L_ut,tmp.mid ( Drv._dsize(), Urv._dsize() ) );
00565 }
00567 virtual RV _drv() const {return concat ( Drv,Urv );}
00569 const RV& _urv() const {return Urv;}
00570 };
00571
00576 class BM :public bdmroot {
00577 protected:
00579 RV drv;
00581 double ll;
00583 bool evalll;
00584 public:
00587
00588 BM () :ll (0),evalll ( false) {};
00589 BM ( const BM &B ) : drv(B.drv), ll ( B.ll ), evalll ( B.evalll ) {}
00592 virtual BM* _copy_ () {return NULL;};
00594
00597
00601 virtual void bayes ( const vec &dt ) = 0;
00603 virtual void bayesB ( const mat &Dt );
00606 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00608 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00609
00611 virtual epdf* epredictor ( ) const {it_error ( "Not implemented" );return NULL;};
00613 virtual mpdf* predictor ( ) const {it_error ( "Not implemented" );return NULL;};
00615
00618
00619 const RV& _drv() const {return drv;}
00620 void set_drv ( const RV &rv ) {drv=rv;}
00621 double _ll() const {return ll;}
00622 void set_evalll ( bool evl0 ) {evalll=evl0;}
00623 virtual const epdf& _epdf() const =0;
00624 virtual const epdf* _e() const =0;
00626
00627 };
00628
00638 class BMcond :public bdmroot {
00639 protected:
00641 RV rvc;
00642 public:
00644 virtual void condition ( const vec &val ) =0;
00646 BMcond ( ) :rvc ( ) {};
00648 virtual ~BMcond() {};
00650 const RV& _rvc() const {return rvc;}
00651 };
00652
00653 };
00655 #endif // BM_H