00001
00017 #ifndef BM_H
00018 #define BM_H
00019
00020
00021 #include "../itpp_ext.h"
00022
00023
00024 namespace bdm{
00025 using namespace itpp;
00026
//! Root class of the BDM class hierarchy.
class bdmroot {
public:
    //! Virtual destructor, so that objects of derived classes can be safely
    //! deleted through a bdmroot pointer (otherwise that would be undefined behaviour).
    virtual ~bdmroot() {}
    //! Printing hook for derived classes; the default implementation does nothing.
    //! Public so that it can actually be invoked through the base class.
    virtual void print() {}
};
00031
00033 class str {
00034 public:
00036 ivec ids;
00038 ivec times;
00040 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00041 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00042 };
00043 };
00044
00051 class RV :public bdmroot{
00052 protected:
00054 int tsize;
00056 int len;
00058 ivec ids;
00060 ivec sizes;
00062 ivec times;
00064 Array<std::string> names;
00065
00066 private:
00068 void init ( ivec in_ids, Array<std::string> in_names, ivec in_sizes, ivec in_times );
00069 public:
00071 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00073 RV ( Array<std::string> in_names, ivec in_sizes );
00075 RV ( Array<std::string> in_names );
00077 RV ();
00078
00080 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00081
00083 int count() const {return tsize;} ;
00085 int length() const {return len;} ;
00086
00087
00088
00090 ivec findself ( const RV &rv2 ) const;
00092 bool equal ( const RV &rv2 ) const;
00094 bool add ( const RV &rv2 );
00096 RV subt ( const RV &rv2 ) const;
00098 RV subselect ( const ivec &ind ) const;
00100 RV operator() ( const ivec &ind ) const;
00102 void t ( int delta );
00104 str tostr() const;
00107 ivec dataind ( const RV &crv ) const;
00110 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00112 int mint () {return min(times);};
00113
00115 Array<std::string>& _names() {return names;};
00116
00118 int id ( int at ) {return ids ( at );};
00120 int size ( int at ) {return sizes ( at );};
00122 int time ( int at ) {return times ( at );};
00124 std::string name ( int at ) {return names ( at );};
00125
00127 void set_id ( int at, int id0 ) {ids ( at ) =id0;};
00129 void set_size ( int at, int size0 ) {sizes ( at ) =size0; tsize=sum ( sizes );};
00131 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00132
00134 void newids();
00135 };
00136
00138 RV concat ( const RV &rv1, const RV &rv2 );
00139
00141 extern RV RV0;
00142
00144
00145 class fnc :public bdmroot{
00146 protected:
00148 int dimy;
00149 public:
00151 fnc ( int dy ) :dimy ( dy ) {};
00153 virtual vec eval ( const vec &cond ) {
00154 return vec ( 0 );
00155 };
00156
00158 virtual void condition(const vec &val){};
00159
00161 int _dimy() const{return dimy;}
00162
00164 virtual ~fnc() {};
00165 };
00166
00167 class mpdf;
00168
00170
00171 class epdf :public bdmroot {
00172 protected:
00174 RV rv;
00175 public:
00177 epdf() :rv ( ) {};
00178
00180 epdf ( const RV &rv0 ) :rv ( rv0 ) {};
00181
00182
00183
00184
00186 virtual vec sample () const =0;
00188 virtual mat sample_m ( int N ) const;
00189
00191 virtual double evallog ( const vec &val ) const =0;
00192
00194 virtual vec evallog_m ( const mat &Val ) const {
00195 vec x ( Val.cols() );
00196 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00197 return x;
00198 }
00200 virtual mpdf* condition ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00202 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00203
00205 virtual vec mean() const =0;
00206
00208 virtual vec variance() const = 0;
00209
00211 virtual ~epdf() {};
00213 const RV& _rv() const {return rv;}
00215 void _renewrv ( const RV &in_rv ) {rv=in_rv;}
00217 };
00218
00219
00221
00222
00223 class mpdf : public bdmroot{
00224 protected:
00226 RV rv;
00228 RV rvc;
00230 epdf* ep;
00231 public:
00232
00234 virtual vec samplecond ( const vec &cond, double &ll ) {
00235 this->condition ( cond );
00236 vec temp= ep->sample();
00237 ll=ep->evallog ( temp );return temp;
00238 };
00240 virtual mat samplecond_m ( const vec &cond, vec &ll, int N ) {
00241 this->condition ( cond );
00242 mat temp ( rv.count(),N ); vec smp ( rv.count() );
00243 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );ll ( i ) =ep->evallog ( smp );}
00244 return temp;
00245 };
00247 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00248
00250 virtual double evallogcond ( const vec &dt, const vec &cond ) {double tmp; this->condition ( cond );tmp = ep->evallog ( dt ); it_assert_debug(std::isfinite(tmp),"Infinite value"); return tmp;
00251 };
00252
00254 virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00255
00257 virtual ~mpdf() {};
00258
00260 mpdf ( const RV &rv0, const RV &rvc0 ) :rv ( rv0 ),rvc ( rvc0 ) {};
00262 RV _rvc() const {return rvc;}
00264 RV _rv() const {return rv;}
00266 epdf& _epdf() {return *ep;}
00268 epdf* _e() {return ep;}
00269 };
00270
00297 class datalink_e2e {
00298 protected:
00300 int valsize;
00302 int valupsize;
00304 ivec v2v_up;
00305 public:
00307 datalink_e2e ( const RV &rv, const RV &rv_up ) :
00308 valsize ( rv.count() ), valupsize ( rv_up.count() ), v2v_up ( rv.dataind ( rv_up ) ) {
00309 it_assert_debug ( v2v_up.length() ==valsize,"rv is not fully in rv_up" );
00310 }
00312 vec get_val ( const vec &val_up ) {it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" ); return get_vec ( val_up,v2v_up );}
00314 void fill_val ( vec &val_up, const vec &val ) {
00315 it_assert_debug ( valsize==val.length(),"Wrong val" );
00316 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00317 set_subvector ( val_up, v2v_up, val );
00318 }
00319 };
00320
00322 class datalink_m2e: public datalink_e2e {
00323 protected:
00325 int condsize;
00327 ivec v2c_up;
00329 ivec v2c_lo;
00330
00331 public:
00333 datalink_m2e ( const RV &rv, const RV &rvc, const RV &rv_up ) :
00334 datalink_e2e ( rv,rv_up ), condsize ( rvc.count() ) {
00335
00336 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00337 }
00339 vec get_cond ( const vec &val_up ) {
00340 vec tmp ( condsize );
00341 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00342 return tmp;
00343 }
00344 void fill_val_cond ( vec &val_up, const vec &val, const vec &cond ) {
00345 it_assert_debug ( valsize==val.length(),"Wrong val" );
00346 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00347 set_subvector ( val_up, v2v_up, val );
00348 set_subvector ( val_up, v2c_up, cond );
00349 }
00350 };
00353 class datalink_m2m: public datalink_m2e {
00354 protected:
00356 ivec c2c_up;
00358 ivec c2c_lo;
00359 public:
00361 datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00362 datalink_m2e ( rv, rvc, rv_up) {
00363
00364 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00365 it_assert_debug(c2c_lo.length()+v2c_lo.length()==condsize, "cond is not fully given");
00366 }
00368 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00369 vec tmp ( condsize );
00370 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00371 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00372 return tmp;
00373 }
00375
00376 };
00377
00381 class mepdf : public mpdf {
00382 public:
00384 mepdf (const epdf* em ) :mpdf ( em->_rv(),RV() ) {ep=const_cast<epdf*>(em);};
00385 void condition ( const vec &cond ) {}
00386 };
00387
00390 class compositepdf {
00391 protected:
00393 int n;
00395 Array<mpdf*> mpdfs;
00396 public:
00397 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00399 RV getrv ( bool checkoverlap=false );
00401 void setrvc ( const RV &rv, RV &rvc );
00402 };
00403
00411 class DS : public bdmroot{
00412 protected:
00414 RV Drv;
00416 RV Urv;
00417 public:
00418 DS():Drv(RV0),Urv(RV0) {};
00419 DS(const RV &Drv0, const RV &Urv0):Drv(Drv0),Urv(Urv0) {};
00421 virtual void getdata ( vec &dt ){it_error("abstract class");};
00423 virtual void getdata ( vec &dt, const ivec &indeces ){it_error("abstract class");};
00425 virtual void write ( vec &ut ){it_error("abstract class");};
00427 virtual void write ( vec &ut, const ivec &indeces ){it_error("abstract class");};
00428
00430 virtual void step()=0;
00431
00432 };
00433
00438 class BM :public bdmroot{
00439 protected:
00441 RV rv;
00443 double ll;
00445 bool evalll;
00446 public:
00447
00449 BM ( const RV &rv0, double ll0=0,bool evalll0=true ) :rv ( rv0 ), ll ( ll0 ),evalll ( evalll0 ) {
00450 };
00452 BM ( const BM &B ) : rv ( B.rv ), ll ( B.ll ), evalll ( B.evalll ) {}
00453
00457 virtual void bayes ( const vec &dt ) = 0;
00459 virtual void bayesB ( const mat &Dt );
00461 virtual const epdf& _epdf() const =0;
00462
00464 virtual const epdf* _e() const =0;
00465
00468 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00470 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00471
00473 virtual epdf* predictor ( const RV &rv ) const {it_error ( "Not implemented" );return NULL;};
00474
00476 virtual ~BM() {};
00478 const RV& _rv() const {return rv;}
00480 double _ll() const {return ll;}
00482 void set_evalll ( bool evl0 ) {evalll=evl0;}
00483
00486 virtual BM* _copy_ ( bool changerv=false ) {it_error ( "function _copy_ not implemented for this BM" ); return NULL;};
00487 };
00488
00498 class BMcond :public bdmroot{
00499 protected:
00501 RV rvc;
00502 public:
00504 virtual void condition ( const vec &val ) =0;
00506 BMcond ( RV &rv0 ) :rvc ( rv0 ) {};
00508 virtual ~BMcond() {};
00510 const RV& _rvc() const {return rvc;}
00511 };
00512
00513 };
00515 #endif // BM_H