00001 
00017 #ifndef BM_H
00018 #define BM_H
00019 
00020 
00021 #include "../itpp_ext.h"
00022 
00023 
00024 namespace bdm {
00025         using namespace itpp;
00026         using std::string;
00027 
00029         class bdmroot {
00030                 virtual void print() {}
00031         };
00032 
00034         class str {
00035         public:
00037                 ivec ids;
00039                 ivec times;
00041                 str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00042                         it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00043                 };
00044         };
00045 
00052         class RV :public bdmroot {
00053         protected:
00055                 int tsize;
00057                 int len;
00059                 ivec ids;
00061                 ivec sizes;
00063                 ivec times;
00065                 Array<std::string> names;
00066 
00067         private:
00069                 void init ( ivec in_ids, Array<std::string> in_names, ivec in_sizes, ivec in_times );
00070         public:
00072                 RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00074                 RV ( Array<std::string> in_names, ivec in_sizes );
00076                 RV ( Array<std::string> in_names );
00078                 RV ();
00080                 RV (string name, int id, int sz=1, int tm=0);
00081 
00083                 friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00084 
00086                 int count() const {return tsize;} ;
00088                 int length() const {return len;} ;
00089 
00090                 
00091 
00093                 ivec findself ( const RV &rv2 ) const;
00095                 bool equal ( const RV &rv2 ) const;
00097                 bool add ( const RV &rv2 );
00099                 RV subt ( const RV &rv2 ) const;
00101                 RV subselect ( const ivec &ind ) const;
00103                 RV operator() ( const ivec &ind ) const;
00105                 void t ( int delta );
00107                 str tostr() const;
00110                 ivec dataind ( const RV &crv ) const;
00113                 void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00115                 int mint () const {return min ( times );};
00116 
00118                 Array<std::string>& _names() {return names;};
00119 
00121                 int id ( int at ) {return ids ( at );};
00123                 int size ( int at ) {return sizes ( at );};
00125                 int time ( int at ) {return times ( at );};
00127                 std::string name ( int at ) {return names ( at );};
00128 
00130                 void set_id ( int at, int id0 ) {ids ( at ) =id0;};
00132                 void set_size ( int at, int size0 ) {sizes ( at ) =size0; tsize=sum ( sizes );};
00134                 void set_time ( int at, int time0 ) {times ( at ) =time0;};
00135 
00137                 void newids();
00138         };
00139 
00141         RV concat ( const RV &rv1, const RV &rv2 );
00142 
00144         extern RV RV0;
00145 
00147 
00148         class fnc :public bdmroot {
00149         protected:
00151                 int dimy;
00152         public:
00154                 fnc ( int dy ) :dimy ( dy ) {};
00156                 virtual vec eval ( const vec &cond ) {
00157                         return vec ( 0 );
00158                 };
00159 
00161                 virtual void condition ( const vec &val ) {};
00162 
00164                 int _dimy() const{return dimy;}
00165 
00167                 virtual ~fnc() {};
00168         };
00169 
00170         class mpdf;
00171 
00173 
00174         class epdf :public bdmroot {
00175         protected:
00177                 RV rv;
00178         public:
00180                 epdf() :rv ( ) {};
00181 
00183                 epdf ( const RV &rv0 ) :rv ( rv0 ) {};
00184 
00185 
00186 
00187 
00189                 virtual vec sample () const =0;
00191                 virtual mat sample_m ( int N ) const;
00192 
00194                 virtual double evallog ( const vec &val ) const =0;
00195 
00197                 virtual vec evallog_m ( const mat &Val ) const {
00198                         vec x ( Val.cols() );
00199                         for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00200                         return x;
00201                 }
00203                 virtual mpdf* condition ( const RV &rv ) const  {it_warning ( "Not implemented" ); return NULL;}
00205                 virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00206 
00208                 virtual vec mean() const =0;
00209 
00211                 virtual vec variance() const = 0;
00212 
00214                 virtual ~epdf() {};
00216                 const RV& _rv() const {return rv;}
00218                 void _renewrv ( const RV &in_rv ) {rv=in_rv;}
00220         };
00221 
00222 
00224 
00225 
00226         class mpdf : public bdmroot {
00227         protected:
00229                 RV rv;
00231                 RV rvc;
00233                 epdf* ep;
00234         public:
00235 
00237                 virtual vec samplecond ( const vec &cond) {
00238                         this->condition ( cond );
00239                         vec temp= ep->sample();
00240                         return temp;
00241                 };
00243                 virtual mat samplecond_m ( const vec &cond, vec &ll, int N ) {
00244                         this->condition ( cond );
00245                         mat temp ( rv.count(),N ); vec smp ( rv.count() );
00246                         for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );ll ( i ) =ep->evallog ( smp );}
00247                         return temp;
00248                 };
00250                 virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00251 
00253                 virtual double evallogcond ( const vec &dt, const vec &cond ) {
00254                         double tmp; this->condition ( cond );tmp = ep->evallog ( dt );          it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;
00255                 };
00256 
00258                 virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00259 
00261                 virtual ~mpdf() {};
00262 
00264                 mpdf ( const RV &rv0, const RV &rvc0 ) :rv ( rv0 ),rvc ( rvc0 ) {};
00266                 RV _rvc() const {return rvc;}
00268                 RV _rv() const {return rv;}
00270                 epdf& _epdf() {return *ep;}
00272                 epdf* _e() {return ep;}
00273         };
00274 
00300         class datalink_e2e {
00301         protected:
00303                 int valsize;
00305                 int valupsize;
00307                 ivec v2v_up;
00308         public:
00310                 datalink_e2e ( const RV &rv, const RV &rv_up ) :
00311                                 valsize ( rv.count() ), valupsize ( rv_up.count() ), v2v_up ( rv.dataind ( rv_up ) )  {
00312                         it_assert_debug ( v2v_up.length() ==valsize,"rv is not fully in rv_up" );
00313                 }
00315                 vec get_val ( const vec &val_up ) {it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" ); return get_vec ( val_up,v2v_up );}
00317                 void fill_val ( vec &val_up, const vec &val ) {
00318                         it_assert_debug ( valsize==val.length(),"Wrong val" );
00319                         it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00320                         set_subvector ( val_up, v2v_up, val );
00321                 }
00322         };
00323 
00325         class datalink_m2e: public datalink_e2e {
00326         protected:
00328                 int condsize;
00330                 ivec v2c_up;
00332                 ivec v2c_lo;
00333 
00334         public:
00336                 datalink_m2e ( const RV &rv,  const RV &rvc, const RV &rv_up ) :
00337                                 datalink_e2e ( rv,rv_up ), condsize ( rvc.count() ) {
00338                         
00339                         rvc.dataind ( rv_up, v2c_lo, v2c_up );
00340                 }
00342                 vec get_cond ( const vec &val_up ) {
00343                         vec tmp ( condsize );
00344                         set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00345                         return tmp;
00346                 }
00347                 void fill_val_cond ( vec &val_up, const vec &val, const vec &cond ) {
00348                         it_assert_debug ( valsize==val.length(),"Wrong val" );
00349                         it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00350                         set_subvector ( val_up, v2v_up, val );
00351                         set_subvector ( val_up, v2c_up, cond );
00352                 }
00353         };
00356         class datalink_m2m: public datalink_m2e {
00357         protected:
00359                 ivec c2c_up;
00361                 ivec c2c_lo;
00362         public:
00364                 datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00365                                 datalink_m2e ( rv, rvc, rv_up ) {
00366                         
00367                         rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00368                         it_assert_debug ( c2c_lo.length() +v2c_lo.length() ==condsize, "cond is not fully given" );
00369                 }
00371                 vec get_cond ( const vec &val_up, const vec &cond_up ) {
00372                         vec tmp ( condsize );
00373                         set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00374                         set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00375                         return tmp;
00376                 }
00378 
00379         };
00380 
00386         class logger : public bdmroot {
00387                 protected:
00389                         Array<RV> entries;
00391                         Array<string> names;
00392                 public:
00394                         logger ( ) : entries ( 0 ),names ( 0 ) {}
00395 
00397                         virtual int add ( const RV &rv, string name="" ) {
00398                                 int id=entries.length();
00399                                 names=concat ( names, name ); 
00400                                 entries.set_length ( id+1,true );
00401                                 entries ( id ) = rv;
00402                                 return id; 
00403                         }
00404 
00406                         virtual void logit ( int id, const vec &v ) =0;
00407 
00409                         virtual void step() =0;
00410 
00412                         virtual void finalize() {};
00413 
00415                         virtual void init() {};
00416 
00418                         virtual ~logger() {};
00419         };
00420 
00424         class mepdf : public mpdf {
00425         public:
00427                 mepdf ( const epdf* em ) :mpdf ( em->_rv(),RV() ) {ep=const_cast<epdf*> ( em );};
00428                 void condition ( const vec &cond ) {}
00429         };
00430 
00433         class compositepdf {
00434         protected:
00436                 int n;
00438                 Array<mpdf*> mpdfs;
00439         public:
00440                 compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00442                 RV getrv ( bool checkoverlap=false );
00444                 void setrvc ( const RV &rv, RV &rvc );
00445         };
00446 
00454         class DS : public bdmroot {
00455         protected:
00457                 RV Drv;
00459                 RV Urv; 
00461                 int L_dt, L_ut;
00462         public:
00463                 DS() :Drv ( RV0 ),Urv ( RV0 ) {};
00464                 DS ( const RV &Drv0, const RV &Urv0 ) :Drv ( Drv0 ),Urv ( Urv0 ) {};
00466                 virtual void getdata ( vec &dt ) {it_error ( "abstract class" );};
00468                 virtual void getdata ( vec &dt, const ivec &indeces ) {it_error ( "abstract class" );};
00470                 virtual void write ( vec &ut ) {it_error ( "abstract class" );};
00472                 virtual void write ( vec &ut, const ivec &indeces ) {it_error ( "abstract class" );};
00473 
00475                 virtual void step() =0;
00476 
00478                 virtual void log_add ( logger &L ) {
00479                         L_dt=L.add ( Drv,"" );
00480                         L_ut=L.add ( Urv,"" );
00481                 }
00483                 virtual void logit ( logger &L ) {
00484                         vec tmp(Drv.count()+Urv.count());
00485                         getdata(tmp);
00486                         
00487                         L.logit ( L_dt,tmp.left ( Drv.count() ) );
00488                         
00489                         L.logit ( L_ut,tmp.mid ( Drv.count(), Urv.count() ) );
00490                 }
00492                 virtual RV _drv() const {return concat(Drv,Urv);}
00494                 const RV& _urv() const {return Urv;}
00495         };
00496 
00501         class BM :public bdmroot {
00502         protected:
00504                 RV rv;
00506                 RV drv;
00508                 double ll;
00510                 bool evalll;
00511         public:
00512 
00514                 BM ( const RV &rv0, double ll0=0,bool evalll0=true ) :rv ( rv0 ), ll ( ll0 ),evalll ( evalll0 ) {
00515                 };
00517                 BM ( const BM &B ) : rv ( B.rv ), ll ( B.ll ), evalll ( B.evalll ) {}
00518 
00522                 virtual void bayes ( const vec &dt ) = 0;
00524                 virtual void bayesB ( const mat &Dt );
00526                 virtual const epdf& _epdf() const =0;
00527 
00529                 virtual const epdf* _e() const =0;
00530 
00533                 virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00535                 vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00536 
00538                 virtual epdf* predictor ( const RV &rv ) const {it_error ( "Not implemented" );return NULL;};
00539 
00541                 virtual ~BM() {};
00543                 const RV& _rv() const {return rv;}
00545                 const RV& _drv() const {return drv;}
00547                 void set_drv(const RV &rv){drv=rv;}
00549                 double _ll() const {return ll;}
00551                 void set_evalll ( bool evl0 ) {evalll=evl0;}
00552 
00555                 virtual BM* _copy_ ( bool changerv=false ) {it_error ( "function _copy_ not implemented for this BM" ); return NULL;};
00556         };
00557 
00567         class BMcond :public bdmroot {
00568         protected:
00570                 RV rvc;
00571         public:
00573                 virtual void condition ( const vec &val ) =0;
00575                 BMcond ( RV &rv0 ) :rvc ( rv0 ) {};
00577                 virtual ~BMcond() {};
00579                 const RV& _rvc() const {return rvc;}
00580         };
00581 
00582 }; 
00584 #endif // BM_H