00001
00013 #ifndef BM_H
00014 #define BM_H
00015
00016 #include <itpp/itbase.h>
00017 #include "../itpp_ext.h"
00018
00019
00020 using namespace itpp;
00021
00023 class str {
00024 public:
00026 ivec ids;
00028 ivec times;
00030 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00031 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00032 };
00033 };
00034
00041 class RV {
00042 protected:
00044 int tsize;
00046 int len;
00048 ivec ids;
00050 ivec sizes;
00052 ivec times;
00054 Array<std::string> names;
00055
00056 private:
00058 void init ( ivec in_ids, Array<std::string> in_names, ivec in_sizes, ivec in_times );
00059 public:
00061 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00063 RV ( Array<std::string> in_names, ivec in_sizes );
00065 RV ( Array<std::string> in_names );
00067 RV ();
00068
00070 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00071
00073 int count() const {return tsize;} ;
00075 int length() const {return len;} ;
00076
00077
00078
00080 ivec findself ( const RV &rv2 ) const;
00082 bool equal ( const RV &rv2 ) const;
00084 bool add ( const RV &rv2 );
00086 RV subt ( const RV &rv2 ) const;
00088 RV subselect ( const ivec &ind ) const;
00090 RV operator() ( const ivec &ind ) const;
00092 void t ( int delta );
00094 str tostr() const;
00097 ivec dataind ( const RV &crv ) const;
00100 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00101
00103 Array<std::string>& _names() {return names;};
00104
00106 int id ( int at ) {return ids ( at );};
00108 int size ( int at ) {return sizes ( at );};
00110 int time ( int at ) {return times ( at );};
00112 std::string name ( int at ) {return names ( at );};
00113
00115 void set_id ( int at, int id0 ) {ids ( at ) =id0;};
00117 void set_size ( int at, int size0 ) {sizes ( at ) =size0; tsize=sum ( sizes );};
00119 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00120
00122 void newids();
00123 };
00124
00126 RV concat ( const RV &rv1, const RV &rv2 );
00127
00128
00130
00131 class fnc {
00132 protected:
00134 int dimy;
00135 public:
00137 fnc ( int dy ) :dimy ( dy ) {};
00139 virtual vec eval ( const vec &cond ) {
00140 return vec ( 0 );
00141 };
00142
00144 int _dimy() const{return dimy;}
00145
00147 virtual ~fnc() {};
00148 };
00149
00150 class mpdf;
00151
00153
00154 class epdf {
00155 protected:
00157 RV rv;
00158 public:
00160 epdf() :rv ( ) {};
00161
00163 epdf ( const RV &rv0 ) :rv ( rv0 ) {};
00164
00165
00166
00167
00169 virtual vec sample () const =0;
00171 virtual mat sample_m ( int N ) const;
00172
00174 virtual double evalpdflog ( const vec &val ) const =0;
00175
00177 virtual vec evalpdflog_m ( const mat &Val ) const {
00178 vec x ( Val.cols() );
00179 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evalpdflog ( Val.get_col ( i ) ) ;}
00180 return x;
00181 }
00183 virtual mpdf* condition ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00185 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00186
00188 virtual vec mean() const =0;
00189
00191 virtual ~epdf() {};
00193 const RV& _rv() const {return rv;}
00195 void _renewrv ( const RV &in_rv ) {rv=in_rv;}
00197 };
00198
00199
00201
00202
00203 class mpdf {
00204 protected:
00206 RV rv;
00208 RV rvc;
00210 epdf* ep;
00211 public:
00212
00214
00216 virtual vec samplecond ( const vec &cond, double &ll ) {
00217 this->condition ( cond );
00218 vec temp= ep->sample();
00219 ll=ep->evalpdflog ( temp );return temp;
00220 };
00222 virtual mat samplecond ( const vec &cond, vec &ll, int N ) {
00223 this->condition ( cond );
00224 mat temp ( rv.count(),N ); vec smp ( rv.count() );
00225 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );ll ( i ) =ep->evalpdflog ( smp );}
00226 return temp;
00227 };
00229 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00230
00232 virtual double evalcond ( const vec &dt, const vec &cond ) {this->condition ( cond );return exp(ep->evalpdflog ( dt ));};
00233
00234 virtual vec evalcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return exp(ep->evalpdflog_m ( Dt ));};
00235
00237 virtual ~mpdf() {};
00238
00240 mpdf ( const RV &rv0, const RV &rvc0 ) :rv ( rv0 ),rvc ( rvc0 ) {};
00242 RV _rvc() const {return rvc;}
00244 RV _rv() const {return rv;}
00246 epdf& _epdf() {return *ep;}
00247 };
00248
00251 class datalink_e2e {
00252 protected:
00254 int valsize;
00256 int valupsize;
00258 ivec v2v_up;
00259 public:
00261 datalink_e2e ( const RV &rv, const RV &rv_up ) :
00262 valsize ( rv.count() ), valupsize ( rv_up.count() ), v2v_up ( rv.dataind ( rv_up ) ) {
00263 it_assert_debug ( v2v_up.length() ==valsize,"rv is not fully in rv_up" );
00264 }
00266 vec get_val ( const vec &val_up ) {it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" ); return get_vec ( val_up,v2v_up );}
00268 void fill_val ( vec &val_up, const vec &val ) {
00269 it_assert_debug ( valsize==val.length(),"Wrong val" );
00270 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00271 set_subvector ( val_up, v2v_up, val );
00272 }
00273 };
00274
00276 class datalink_m2e: public datalink_e2e {
00277 protected:
00279 int condsize;
00281 ivec v2c_up;
00283 ivec v2c_lo;
00284
00285 public:
00287 datalink_m2e ( const RV &rv, const RV &rvc, const RV &rv_up ) :
00288 datalink_e2e ( rv,rv_up ), condsize ( rvc.count() ) {
00289
00290 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00291 }
00293 vec get_cond ( const vec &val_up ) {
00294 vec tmp ( condsize );
00295 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00296 return tmp;
00297 }
00298 void fill_val_cond ( vec &val_up, const vec &val, const vec &cond ) {
00299 it_assert_debug ( valsize==val.length(),"Wrong val" );
00300 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00301 set_subvector ( val_up, v2v_up, val );
00302 set_subvector ( val_up, v2c_up, cond );
00303 }
00304 };
00307 class datalink_m2m: public datalink_m2e {
00308 protected:
00310 ivec c2c_up;
00312 ivec c2c_lo;
00313 public:
00315 datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00316 datalink_m2e ( rv, rvc, rv_up) {
00317
00318 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00319 it_assert_debug(c2c_lo.length()+v2c_lo.length()==condsize, "cond is not fully given");
00320 }
00322 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00323 vec tmp ( condsize );
00324 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00325 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00326 return tmp;
00327 }
00329
00330 };
00331
00335 class mepdf : public mpdf {
00336 public:
00338 mepdf (const epdf* em ) :mpdf ( em->_rv(),RV() ) {ep=const_cast<epdf*>(em);};
00339 void condition ( const vec &cond ) {}
00340 };
00341
00344 class compositepdf {
00345 protected:
00347 int n;
00349 Array<mpdf*> mpdfs;
00350 public:
00351 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00353 RV getrv ( bool checkoverlap=false );
00355 void setrvc ( const RV &rv, RV &rvc );
00356 };
00357
00365 class DS {
00366 protected:
00368 RV Drv;
00370 RV Urv;
00371 public:
00373 void getdata ( vec &dt );
00375 void getdata ( vec &dt, ivec &indeces );
00377 void write ( vec &ut );
00379 void write ( vec &ut, ivec &indeces );
00385 void linkrvs ( RV &drv, RV &urv );
00386
00388 void step();
00389
00390 };
00391
00396 class BM {
00397 protected:
00399 RV rv;
00401 double ll;
00403 bool evalll;
00404 public:
00405
00407 BM ( const RV &rv0, double ll0=0,bool evalll0=true ) :rv ( rv0 ), ll ( ll0 ),evalll ( evalll0 ) {
00408 };
00410 BM ( const BM &B ) : rv ( B.rv ), ll ( B.ll ), evalll ( B.evalll ) {}
00411
00415 virtual void bayes ( const vec &dt ) = 0;
00417 virtual void bayesB ( const mat &Dt );
00419 virtual const epdf& _epdf() const =0;
00420
00422 virtual const epdf* _e() const =0;
00423
00426 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00428 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00429
00431 virtual epdf* predictor ( const RV &rv ) const {it_error ( "Not implemented" );return NULL;};
00432
00434 virtual ~BM() {};
00436 const RV& _rv() const {return rv;}
00438 double _ll() const {return ll;}
00440 void set_evalll ( bool evl0 ) {evalll=evl0;}
00441
00444 virtual BM* _copy_ ( bool changerv=false ) {it_error ( "function _copy_ not implemented for this BM" ); return NULL;};
00445 };
00446
00456 class BMcond {
00457 protected:
00459 RV rvc;
00460 public:
00462 virtual void condition ( const vec &val ) =0;
00464 BMcond ( RV &rv0 ) :rvc ( rv0 ) {};
00466 virtual ~BMcond() {};
00468 const RV& _rvc() const {return rvc;}
00469 };
00470
00471 #endif // BM_H