00001
00017 #ifndef BM_H
00018 #define BM_H
00019
00020
00021 #include "../itpp_ext.h"
00022
00023
00024 namespace bdm {
00025 using namespace itpp;
00026 using std::string;
00027
//! Root class of all BDM objects; gives every object a common polymorphic base.
class bdmroot {
public:
	//! Virtual destructor: objects deleted through a bdmroot pointer must run
	//! the destructor of their most-derived class (otherwise UB).
	virtual ~bdmroot() {}
private:
	//! Print hook; the base implementation does nothing.
	virtual void print() {}
};
00032
00034 class str {
00035 public:
00037 ivec ids;
00039 ivec times;
00041 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00042 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00043 };
00044 };
00045
00052 class RV :public bdmroot {
00053 protected:
00055 int tsize;
00057 int len;
00059 ivec ids;
00061 ivec sizes;
00063 ivec times;
00065 Array<std::string> names;
00066
00067 private:
00069 void init ( ivec in_ids, Array<std::string> in_names, ivec in_sizes, ivec in_times );
00070 public:
00072 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00074 RV ( Array<std::string> in_names, ivec in_sizes );
00076 RV ( Array<std::string> in_names );
00078 RV ();
00079
00081 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00082
00084 int count() const {return tsize;} ;
00086 int length() const {return len;} ;
00087
00088
00089
00091 ivec findself ( const RV &rv2 ) const;
00093 bool equal ( const RV &rv2 ) const;
00095 bool add ( const RV &rv2 );
00097 RV subt ( const RV &rv2 ) const;
00099 RV subselect ( const ivec &ind ) const;
00101 RV operator() ( const ivec &ind ) const;
00103 void t ( int delta );
00105 str tostr() const;
00108 ivec dataind ( const RV &crv ) const;
00111 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00113 int mint () {return min ( times );};
00114
00116 Array<std::string>& _names() {return names;};
00117
00119 int id ( int at ) {return ids ( at );};
00121 int size ( int at ) {return sizes ( at );};
00123 int time ( int at ) {return times ( at );};
00125 std::string name ( int at ) {return names ( at );};
00126
00128 void set_id ( int at, int id0 ) {ids ( at ) =id0;};
00130 void set_size ( int at, int size0 ) {sizes ( at ) =size0; tsize=sum ( sizes );};
00132 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00133
00135 void newids();
00136 };
00137
00139 RV concat ( const RV &rv1, const RV &rv2 );
00140
00142 extern RV RV0;
00143
00145
00146 class fnc :public bdmroot {
00147 protected:
00149 int dimy;
00150 public:
00152 fnc ( int dy ) :dimy ( dy ) {};
00154 virtual vec eval ( const vec &cond ) {
00155 return vec ( 0 );
00156 };
00157
00159 virtual void condition ( const vec &val ) {};
00160
00162 int _dimy() const{return dimy;}
00163
00165 virtual ~fnc() {};
00166 };
00167
00168 class mpdf;
00169
00171
00172 class epdf :public bdmroot {
00173 protected:
00175 RV rv;
00176 public:
00178 epdf() :rv ( ) {};
00179
00181 epdf ( const RV &rv0 ) :rv ( rv0 ) {};
00182
00183
00184
00185
00187 virtual vec sample () const =0;
00189 virtual mat sample_m ( int N ) const;
00190
00192 virtual double evallog ( const vec &val ) const =0;
00193
00195 virtual vec evallog_m ( const mat &Val ) const {
00196 vec x ( Val.cols() );
00197 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00198 return x;
00199 }
00201 virtual mpdf* condition ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00203 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00204
00206 virtual vec mean() const =0;
00207
00209 virtual vec variance() const = 0;
00210
00212 virtual ~epdf() {};
00214 const RV& _rv() const {return rv;}
00216 void _renewrv ( const RV &in_rv ) {rv=in_rv;}
00218 };
00219
00220
00222
00223
00224 class mpdf : public bdmroot {
00225 protected:
00227 RV rv;
00229 RV rvc;
00231 epdf* ep;
00232 public:
00233
00235 virtual vec samplecond ( const vec &cond, double &ll ) {
00236 this->condition ( cond );
00237 vec temp= ep->sample();
00238 ll=ep->evallog ( temp );return temp;
00239 };
00241 virtual mat samplecond_m ( const vec &cond, vec &ll, int N ) {
00242 this->condition ( cond );
00243 mat temp ( rv.count(),N ); vec smp ( rv.count() );
00244 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );ll ( i ) =ep->evallog ( smp );}
00245 return temp;
00246 };
00248 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00249
00251 virtual double evallogcond ( const vec &dt, const vec &cond ) {
00252 double tmp; this->condition ( cond );tmp = ep->evallog ( dt ); it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;
00253 };
00254
00256 virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00257
00259 virtual ~mpdf() {};
00260
00262 mpdf ( const RV &rv0, const RV &rvc0 ) :rv ( rv0 ),rvc ( rvc0 ) {};
00264 RV _rvc() const {return rvc;}
00266 RV _rv() const {return rv;}
00268 epdf& _epdf() {return *ep;}
00270 epdf* _e() {return ep;}
00271 };
00272
00304 class logger : public bdmroot {
00305 protected:
00307 Array<RV> entries;
00309 Array<string> names;
00310 public:
00312 logger ( ) : entries ( 0 ),names ( 0 ) {}
00313
00315 virtual int add ( const RV &rv, string name="" ) {
00316 int id=entries.length();
00317 names=concat ( names, name );
00318 entries.set_length ( id+1,true );
00319 entries ( id ) = rv;
00320 return id;
00321 }
00322
00324 virtual void logit ( int id, const vec &v ) =0;
00325
00327 virtual void step() =0;
00328
00330 virtual void finalize() {};
00331
00333 virtual void init() {};
00334
00336 virtual ~logger() {};
00337 };
00338
00339 class datalink_e2e {
00340 protected:
00342 int valsize;
00344 int valupsize;
00346 ivec v2v_up;
00347 public:
00349 datalink_e2e ( const RV &rv, const RV &rv_up ) :
00350 valsize ( rv.count() ), valupsize ( rv_up.count() ), v2v_up ( rv.dataind ( rv_up ) ) {
00351 it_assert_debug ( v2v_up.length() ==valsize,"rv is not fully in rv_up" );
00352 }
00354 vec get_val ( const vec &val_up ) {it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" ); return get_vec ( val_up,v2v_up );}
00356 void fill_val ( vec &val_up, const vec &val ) {
00357 it_assert_debug ( valsize==val.length(),"Wrong val" );
00358 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00359 set_subvector ( val_up, v2v_up, val );
00360 }
00361 };
00362
00364 class datalink_m2e: public datalink_e2e {
00365 protected:
00367 int condsize;
00369 ivec v2c_up;
00371 ivec v2c_lo;
00372
00373 public:
00375 datalink_m2e ( const RV &rv, const RV &rvc, const RV &rv_up ) :
00376 datalink_e2e ( rv,rv_up ), condsize ( rvc.count() ) {
00377
00378 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00379 }
00381 vec get_cond ( const vec &val_up ) {
00382 vec tmp ( condsize );
00383 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00384 return tmp;
00385 }
00386 void fill_val_cond ( vec &val_up, const vec &val, const vec &cond ) {
00387 it_assert_debug ( valsize==val.length(),"Wrong val" );
00388 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00389 set_subvector ( val_up, v2v_up, val );
00390 set_subvector ( val_up, v2c_up, cond );
00391 }
00392 };
00395 class datalink_m2m: public datalink_m2e {
00396 protected:
00398 ivec c2c_up;
00400 ivec c2c_lo;
00401 public:
00403 datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00404 datalink_m2e ( rv, rvc, rv_up ) {
00405
00406 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00407 it_assert_debug ( c2c_lo.length() +v2c_lo.length() ==condsize, "cond is not fully given" );
00408 }
00410 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00411 vec tmp ( condsize );
00412 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00413 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00414 return tmp;
00415 }
00417
00418 };
00419
00423 class mepdf : public mpdf {
00424 public:
00426 mepdf ( const epdf* em ) :mpdf ( em->_rv(),RV() ) {ep=const_cast<epdf*> ( em );};
00427 void condition ( const vec &cond ) {}
00428 };
00429
00432 class compositepdf {
00433 protected:
00435 int n;
00437 Array<mpdf*> mpdfs;
00438 public:
00439 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00441 RV getrv ( bool checkoverlap=false );
00443 void setrvc ( const RV &rv, RV &rvc );
00444 };
00445
00453 class DS : public bdmroot {
00454 protected:
00456 RV Drv;
00458 RV Urv;
00460 int L_dt, L_ut;
00461 public:
00462 DS() :Drv ( RV0 ),Urv ( RV0 ) {};
00463 DS ( const RV &Drv0, const RV &Urv0 ) :Drv ( Drv0 ),Urv ( Urv0 ) {};
00465 virtual void getdata ( vec &dt ) {it_error ( "abstract class" );};
00467 virtual void getdata ( vec &dt, const ivec &indeces ) {it_error ( "abstract class" );};
00469 virtual void write ( vec &ut ) {it_error ( "abstract class" );};
00471 virtual void write ( vec &ut, const ivec &indeces ) {it_error ( "abstract class" );};
00472
00474 virtual void step() =0;
00475
00477 virtual void log_add ( logger &L ) {
00478 L_dt=L.add ( Drv,"" );
00479 L_ut=L.add ( Urv,"" );
00480 }
00482 virtual void logit ( logger &L ) {
00483 vec tmp(Drv.count()+Urv.count());
00484 getdata(tmp);
00485
00486 L.logit ( L_dt,tmp.left ( Drv.count() ) );
00487
00488 L.logit ( L_ut,tmp.mid ( Drv.count(), Urv.count() ) );
00489 }
00490 };
00491
00496 class BM :public bdmroot {
00497 protected:
00499 RV rv;
00501 double ll;
00503 bool evalll;
00504 public:
00505
00507 BM ( const RV &rv0, double ll0=0,bool evalll0=true ) :rv ( rv0 ), ll ( ll0 ),evalll ( evalll0 ) {
00508 };
00510 BM ( const BM &B ) : rv ( B.rv ), ll ( B.ll ), evalll ( B.evalll ) {}
00511
00515 virtual void bayes ( const vec &dt ) = 0;
00517 virtual void bayesB ( const mat &Dt );
00519 virtual const epdf& _epdf() const =0;
00520
00522 virtual const epdf* _e() const =0;
00523
00526 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00528 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00529
00531 virtual epdf* predictor ( const RV &rv ) const {it_error ( "Not implemented" );return NULL;};
00532
00534 virtual ~BM() {};
00536 const RV& _rv() const {return rv;}
00538 double _ll() const {return ll;}
00540 void set_evalll ( bool evl0 ) {evalll=evl0;}
00541
00544 virtual BM* _copy_ ( bool changerv=false ) {it_error ( "function _copy_ not implemented for this BM" ); return NULL;};
00545 };
00546
00556 class BMcond :public bdmroot {
00557 protected:
00559 RV rvc;
00560 public:
00562 virtual void condition ( const vec &val ) =0;
00564 BMcond ( RV &rv0 ) :rvc ( rv0 ) {};
00566 virtual ~BMcond() {};
00568 const RV& _rvc() const {return rvc;}
00569 };
00570
00571 };
00573 #endif // BM_H