00001
00017 #ifndef BM_H
00018 #define BM_H
00019
00020 #include <itpp/itbase.h>
00021 #include "../itpp_ext.h"
00022
00023
00024 using namespace itpp;
00025
00027 class str {
00028 public:
00030 ivec ids;
00032 ivec times;
00034 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00035 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00036 };
00037 };
00038
00045 class RV {
00046 protected:
00048 int tsize;
00050 int len;
00052 ivec ids;
00054 ivec sizes;
00056 ivec times;
00058 Array<std::string> names;
00059
00060 private:
00062 void init ( ivec in_ids, Array<std::string> in_names, ivec in_sizes, ivec in_times );
00063 public:
00065 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00067 RV ( Array<std::string> in_names, ivec in_sizes );
00069 RV ( Array<std::string> in_names );
00071 RV ();
00072
00074 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00075
00077 int count() const {return tsize;} ;
00079 int length() const {return len;} ;
00080
00081
00082
00084 ivec findself ( const RV &rv2 ) const;
00086 bool equal ( const RV &rv2 ) const;
00088 bool add ( const RV &rv2 );
00090 RV subt ( const RV &rv2 ) const;
00092 RV subselect ( const ivec &ind ) const;
00094 RV operator() ( const ivec &ind ) const;
00096 void t ( int delta );
00098 str tostr() const;
00101 ivec dataind ( const RV &crv ) const;
00104 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00105
00107 Array<std::string>& _names() {return names;};
00108
00110 int id ( int at ) {return ids ( at );};
00112 int size ( int at ) {return sizes ( at );};
00114 int time ( int at ) {return times ( at );};
00116 std::string name ( int at ) {return names ( at );};
00117
00119 void set_id ( int at, int id0 ) {ids ( at ) =id0;};
00121 void set_size ( int at, int size0 ) {sizes ( at ) =size0; tsize=sum ( sizes );};
00123 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00124
00126 void newids();
00127 };
00128
00130 RV concat ( const RV &rv1, const RV &rv2 );
00131
00133 extern RV RV0;
00134
00136
00137 class fnc {
00138 protected:
00140 int dimy;
00141 public:
00143 fnc ( int dy ) :dimy ( dy ) {};
00145 virtual vec eval ( const vec &cond ) {
00146 return vec ( 0 );
00147 };
00148
00150 virtual void condition(const vec &val){};
00151
00153 int _dimy() const{return dimy;}
00154
00156 virtual ~fnc() {};
00157 };
00158
00159 class mpdf;
00160
00162
00163 class epdf {
00164 protected:
00166 RV rv;
00167 public:
00169 epdf() :rv ( ) {};
00170
00172 epdf ( const RV &rv0 ) :rv ( rv0 ) {};
00173
00174
00175
00176
00178 virtual vec sample () const =0;
00180 virtual mat sample_m ( int N ) const;
00181
00183 virtual double evallog ( const vec &val ) const =0;
00184
00186 virtual vec evallog_m ( const mat &Val ) const {
00187 vec x ( Val.cols() );
00188 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00189 return x;
00190 }
00192 virtual mpdf* condition ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00194 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00195
00197 virtual vec mean() const =0;
00198
00200 virtual ~epdf() {};
00202 const RV& _rv() const {return rv;}
00204 void _renewrv ( const RV &in_rv ) {rv=in_rv;}
00206 };
00207
00208
00210
00211
00212 class mpdf {
00213 protected:
00215 RV rv;
00217 RV rvc;
00219 epdf* ep;
00220 public:
00221
00223
00225 virtual vec samplecond ( const vec &cond, double &ll ) {
00226 this->condition ( cond );
00227 vec temp= ep->sample();
00228 ll=ep->evallog ( temp );return temp;
00229 };
00231 virtual mat samplecond_m ( const vec &cond, vec &ll, int N ) {
00232 this->condition ( cond );
00233 mat temp ( rv.count(),N ); vec smp ( rv.count() );
00234 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );ll ( i ) =ep->evallog ( smp );}
00235 return temp;
00236 };
00238 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00239
00241 virtual double evallogcond ( const vec &dt, const vec &cond ) {double tmp; this->condition ( cond );tmp = ep->evallog ( dt ); it_assert_debug(std::isfinite(tmp),"Infinite value"); return tmp;
00242 };
00243
00245 virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00246
00248 virtual ~mpdf() {};
00249
00251 mpdf ( const RV &rv0, const RV &rvc0 ) :rv ( rv0 ),rvc ( rvc0 ) {};
00253 RV _rvc() const {return rvc;}
00255 RV _rv() const {return rv;}
00257 epdf& _epdf() {return *ep;}
00258 };
00259
00262 class datalink_e2e {
00263 protected:
00265 int valsize;
00267 int valupsize;
00269 ivec v2v_up;
00270 public:
00272 datalink_e2e ( const RV &rv, const RV &rv_up ) :
00273 valsize ( rv.count() ), valupsize ( rv_up.count() ), v2v_up ( rv.dataind ( rv_up ) ) {
00274 it_assert_debug ( v2v_up.length() ==valsize,"rv is not fully in rv_up" );
00275 }
00277 vec get_val ( const vec &val_up ) {it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" ); return get_vec ( val_up,v2v_up );}
00279 void fill_val ( vec &val_up, const vec &val ) {
00280 it_assert_debug ( valsize==val.length(),"Wrong val" );
00281 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00282 set_subvector ( val_up, v2v_up, val );
00283 }
00284 };
00285
00287 class datalink_m2e: public datalink_e2e {
00288 protected:
00290 int condsize;
00292 ivec v2c_up;
00294 ivec v2c_lo;
00295
00296 public:
00298 datalink_m2e ( const RV &rv, const RV &rvc, const RV &rv_up ) :
00299 datalink_e2e ( rv,rv_up ), condsize ( rvc.count() ) {
00300
00301 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00302 }
00304 vec get_cond ( const vec &val_up ) {
00305 vec tmp ( condsize );
00306 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00307 return tmp;
00308 }
00309 void fill_val_cond ( vec &val_up, const vec &val, const vec &cond ) {
00310 it_assert_debug ( valsize==val.length(),"Wrong val" );
00311 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00312 set_subvector ( val_up, v2v_up, val );
00313 set_subvector ( val_up, v2c_up, cond );
00314 }
00315 };
00318 class datalink_m2m: public datalink_m2e {
00319 protected:
00321 ivec c2c_up;
00323 ivec c2c_lo;
00324 public:
00326 datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00327 datalink_m2e ( rv, rvc, rv_up) {
00328
00329 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00330 it_assert_debug(c2c_lo.length()+v2c_lo.length()==condsize, "cond is not fully given");
00331 }
00333 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00334 vec tmp ( condsize );
00335 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00336 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00337 return tmp;
00338 }
00340
00341 };
00342
00346 class mepdf : public mpdf {
00347 public:
00349 mepdf (const epdf* em ) :mpdf ( em->_rv(),RV() ) {ep=const_cast<epdf*>(em);};
00350 void condition ( const vec &cond ) {}
00351 };
00352
00355 class compositepdf {
00356 protected:
00358 int n;
00360 Array<mpdf*> mpdfs;
00361 public:
00362 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00364 RV getrv ( bool checkoverlap=false );
00366 void setrvc ( const RV &rv, RV &rvc );
00367 };
00368
00376 class DS {
00377 protected:
00379 RV Drv;
00381 RV Urv;
00382 public:
00384 void getdata ( vec &dt );
00386 void getdata ( vec &dt, ivec &indeces );
00388 void write ( vec &ut );
00390 void write ( vec &ut, ivec &indeces );
00396 void linkrvs ( RV &drv, RV &urv );
00397
00399 void step();
00400
00401 };
00402
00407 class BM {
00408 protected:
00410 RV rv;
00412 double ll;
00414 bool evalll;
00415 public:
00416
00418 BM ( const RV &rv0, double ll0=0,bool evalll0=true ) :rv ( rv0 ), ll ( ll0 ),evalll ( evalll0 ) {
00419 };
00421 BM ( const BM &B ) : rv ( B.rv ), ll ( B.ll ), evalll ( B.evalll ) {}
00422
00426 virtual void bayes ( const vec &dt ) = 0;
00428 virtual void bayesB ( const mat &Dt );
00430 virtual const epdf& _epdf() const =0;
00431
00433 virtual const epdf* _e() const =0;
00434
00437 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00439 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00440
00442 virtual epdf* predictor ( const RV &rv ) const {it_error ( "Not implemented" );return NULL;};
00443
00445 virtual ~BM() {};
00447 const RV& _rv() const {return rv;}
00449 double _ll() const {return ll;}
00451 void set_evalll ( bool evl0 ) {evalll=evl0;}
00452
00455 virtual BM* _copy_ ( bool changerv=false ) {it_error ( "function _copy_ not implemented for this BM" ); return NULL;};
00456 };
00457
00467 class BMcond {
00468 protected:
00470 RV rvc;
00471 public:
00473 virtual void condition ( const vec &val ) =0;
00475 BMcond ( RV &rv0 ) :rvc ( rv0 ) {};
00477 virtual ~BMcond() {};
00479 const RV& _rvc() const {return rvc;}
00480 };
00481
00483 #endif // BM_H