00001 
00013 #ifndef BM_H
00014 #define BM_H
00015 
00016 
00017 #include "../itpp_ext.h"
00018 #include <map>
00019 
00020 namespace bdm {
00021 using namespace itpp;
00022 using namespace std;
00023 
/*!
 * \brief Common root of the class hierarchy.
 *
 * Its only member is a virtual destructor, so that objects of derived classes
 * are destroyed correctly when deleted through a bdmroot pointer.
 */
class bdmroot {
public:
	//! Virtual destructor; intentionally empty
	virtual ~bdmroot()
	{
	}
};
00030 
00031 typedef std::map<string, int> RVmap;
00032 extern ivec RV_SIZES;
00033 extern Array<string> RV_NAMES;
00034 
00036 class str {
00037 public:
00039         ivec ids;
00041         ivec times;
00043         str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00044                 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00045         };
00046 };
00047 
00086 class RV :public bdmroot {
00087 protected:
00089         int dsize;
00091         int len;
00093         ivec ids;
00095         ivec times;
00096 
00097 private:
00099         void init ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00100         int init ( const  string &name, int size );
00101 public:
00104 
00106         RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times ) {init ( in_names,in_sizes,in_times );};
00108         RV ( Array<std::string> in_names, ivec in_sizes ) {init ( in_names,in_sizes,zeros_i ( in_names.length() ) );};
00110         RV ( Array<std::string> in_names ) {init ( in_names,ones_i ( in_names.length() ),zeros_i ( in_names.length() ) );}
00112         RV () :dsize ( 0 ),len ( 0 ),ids ( 0 ),times ( 0 ) {};
00114         RV ( string name, int sz, int tm=0 );
00116 
00119 
00121         friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00122         int _dsize() const {return dsize;} ;
00124         int countsize() const;
00125         ivec cumsizes() const;
00126         int length() const {return len;} ;
00127         int id ( int at ) const{return ids ( at );};
00128         int size ( int at ) const {return RV_SIZES ( ids ( at ) );};
00129         int time ( int at ) const{return times ( at );};
00130         std::string name ( int at ) const {return RV_NAMES ( ids ( at ) );};
00131         void set_time ( int at, int time0 ) {times ( at ) =time0;};
00133 
00134         
00135 
00138 
00140         ivec findself ( const RV &rv2 ) const;
00142         bool equal ( const RV &rv2 ) const;
00144         bool add ( const RV &rv2 );
00146         RV subt ( const RV &rv2 ) const;
00148         RV subselect ( const ivec &ind ) const;
00150         RV operator() ( const ivec &ind ) const {return subselect ( ind );};
00152         RV operator() ( int di1, int di2 ) const {
00153                 ivec sz=cumsizes();
00154                 int i1=0;
00155                 while ( sz ( i1 ) <di1 ) i1++;
00156                 int i2=i1;
00157                 while ( sz ( i2 ) <di2 ) i2++;
00158                 return subselect ( linspace ( i1,i2 ) );
00159         };
00161         void t ( int delta );
00163 
00166 
00168         str tostr() const;
00171         ivec dataind ( const RV &crv ) const;
00174         void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00176         int mint () const {return min ( times );};
00178 
00179 };
00180 
00181 
00183 RV concat ( const RV &rv1, const RV &rv2 );
00184 
00186 extern RV RV0;
00187 
00189 
00190 class fnc :public bdmroot {
00191 protected:
00193         int dimy;
00194 public:
00196         fnc ( ) {};
00198         virtual vec eval ( const vec &cond ) {
00199                 return vec ( 0 );
00200         };
00201 
00203         virtual void condition ( const vec &val ) {};
00204 
00206         int dimension() const{return dimy;}
00207 };
00208 
00209 class mpdf;
00210 
00212 
00213 class epdf :public bdmroot {
00214 protected:
00216         int dim;
00218         RV rv;
00219 
00220 public:
00232         epdf() :dim ( 0 ),rv ( ) {};
00233         epdf ( const epdf &e ) :dim ( e.dim ),rv ( e.rv ) {};
00234         epdf ( const RV &rv0 ) {set_rv ( rv0 );};
00235         void set_parameters ( int dim0 ) {dim=dim0;}
00237 
00240 
00242         virtual vec sample () const {it_error ( "not implemneted" );return vec ( 0 );};
00244         virtual mat sample_m ( int N ) const;
00246         virtual double evallog ( const vec &val ) const {it_error ( "not implemneted" );return 0.0;};
00248         virtual vec evallog_m ( const mat &Val ) const {
00249                 vec x ( Val.cols() );
00250                 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00251                 return x;
00252         }
00254         virtual mpdf* condition ( const RV &rv ) const  {it_warning ( "Not implemented" ); return NULL;}
00256         virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00258         virtual vec mean() const {it_error ( "not implemneted" );return vec ( 0 );};
00260         virtual vec variance() const {it_error ( "not implemneted" );return vec ( 0 );};
00262         virtual void qbounds ( vec &lb, vec &ub, double percentage=0.95 ) const {
00263                 vec mea=mean(); vec std=sqrt ( variance() );
00264                 lb = mea-2*std; ub=mea+2*std;
00265         };
00267 
00273 
00275         void set_rv ( const RV &rv0 ) {rv = rv0; }
00277         bool isnamed() const {bool b= ( dim==rv._dsize() );return b;}
00279         const RV& _rv() const {it_assert_debug ( isnamed(),"" ); return rv;}
00281 
00284 
00286         int dimension() const {return dim;}
00288 
00289 };
00290 
00291 
00293 
00294 
00295 class mpdf : public bdmroot {
00296 protected:
00298         int dimc;
00300         RV rvc;
00302         epdf* ep;
00303 public:
00306 
00307         mpdf ( ) :dimc ( 0 ),rvc ( ) {};
00309         mpdf ( const mpdf &m ) :dimc ( m.dimc ),rvc ( m.rvc ) {};
00311 
00314 
00316         virtual vec samplecond ( const vec &cond ) {
00317                 this->condition ( cond );
00318                 vec temp= ep->sample();
00319                 return temp;
00320         };
00322         virtual mat samplecond_m ( const vec &cond, int N ) {
00323                 this->condition ( cond );
00324                 mat temp ( ep->dimension(),N ); vec smp ( ep->dimension() );
00325                 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );}
00326                 return temp;
00327         };
00329         virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00330 
00332         virtual double evallogcond ( const vec &dt, const vec &cond ) {
00333                 double tmp; this->condition ( cond );tmp = ep->evallog ( dt );          it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;
00334         };
00335 
00337         virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00338 
00341 
00342         RV _rv() {return ep->_rv();}
00343         RV _rvc() {it_assert_debug ( isnamed(),"" ); return rvc;}
00344         int dimension() {return ep->dimension();}
00345         int dimensionc() {return dimc;}
00346         epdf& _epdf() {return *ep;}
00347         epdf* _e() {return ep;}
00349 
00352         void set_rvc ( const RV &rvc0 ) {rvc=rvc0;}
00353         void set_rv ( const RV &rv0 ) {ep->set_rv ( rv0 );}
00354         bool isnamed() {return ( ep->isnamed() ) && ( dimc==rvc._dsize() );}
00356 };
00357 
00383 class datalink {
00384 protected:
00386         int downsize;
00388         int upsize;
00390         ivec v2v_up;
00391 public:
00393         datalink () {};
00394         datalink ( const RV &rv, const RV &rv_up ) {set_connection ( rv,rv_up );};
00396         void set_connection ( const RV &rv, const RV &rv_up ) {
00397                 downsize = rv._dsize();
00398                 upsize = rv_up._dsize();
00399                 v2v_up= ( rv.dataind ( rv_up ) );
00400 
00401                 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00402         }
00404         void set_connection ( int ds, int us, const ivec &upind ) {
00405                 downsize = ds;
00406                 upsize = us;
00407                 v2v_up= upind;
00408 
00409                 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00410         }
00412         vec pushdown ( const vec &val_up ) {
00413                 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00414                 return get_vec ( val_up,v2v_up );
00415         }
00417         void pushup ( vec &val_up, const vec &val ) {
00418                 it_assert_debug ( downsize==val.length(),"Wrong val" );
00419                 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00420                 set_subvector ( val_up, v2v_up, val );
00421         }
00422 };
00423 
00425 class datalink_m2e: public datalink {
00426 protected:
00428         int condsize;
00430         ivec v2c_up;
00432         ivec v2c_lo;
00433 
00434 public:
00435         datalink_m2e() {};
00437         void set_connection ( const RV &rv,  const RV &rvc, const RV &rv_up ) {
00438                 datalink::set_connection ( rv,rv_up );
00439                 condsize=  rvc._dsize();
00440                 
00441                 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00442         }
00444         vec get_cond ( const vec &val_up ) {
00445                 vec tmp ( condsize );
00446                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00447                 return tmp;
00448         }
00449         void pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
00450                 it_assert_debug ( downsize==val.length(),"Wrong val" );
00451                 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00452                 set_subvector ( val_up, v2v_up, val );
00453                 set_subvector ( val_up, v2c_up, cond );
00454         }
00455 };
00458 class datalink_m2m: public datalink_m2e {
00459 protected:
00461         ivec c2c_up;
00463         ivec c2c_lo;
00464 public:
00466         datalink_m2m() {};
00467         void set_connection ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) {
00468                 datalink_m2e::set_connection ( rv, rvc, rv_up );
00469                 
00470                 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00471                 it_assert_debug ( c2c_lo.length() +v2c_lo.length() ==condsize, "cond is not fully given" );
00472         }
00474         vec get_cond ( const vec &val_up, const vec &cond_up ) {
00475                 vec tmp ( condsize );
00476                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00477                 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00478                 return tmp;
00479         }
00481 
00482 };
00483 
00489 class logger : public bdmroot {
00490 protected:
00492         Array<RV> entries;
00494         Array<string> names;
00495 public:
00497         logger ( ) : entries ( 0 ),names ( 0 ) {}
00498 
00501         virtual int add ( const RV &rv, string name="" ) {
00502                 int id;
00503                 if ( rv._dsize() >0 ) {
00504                         id=entries.length();
00505                         names=concat ( names, name ); 
00506                         entries.set_length ( id+1,true );
00507                         entries ( id ) = rv;
00508                 }
00509                 else { id =-1;}
00510                 return id; 
00511         }
00512 
00514         virtual void logit ( int id, const vec &v ) =0;
00515 
00517         virtual void step() =0;
00518 
00520         virtual void finalize() {};
00521 
00523         virtual void init() {};
00524 
00525 };
00526 
00530 class mepdf : public mpdf {
00531 public:
00533         mepdf ( epdf* em ) :mpdf ( ) {ep= em ;};
00534         mepdf (const epdf* em ) :mpdf ( ) {ep=const_cast<epdf*>( em );};
00535         void condition ( const vec &cond ) {}
00536 };
00537 
00540 class compositepdf {
00541 protected:
00543         int n;
00545         Array<mpdf*> mpdfs;
00546 public:
00547         compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00549         RV getrv ( bool checkoverlap=false );
00551         void setrvc ( const RV &rv, RV &rvc );
00552 };
00553 
00561 class DS : public bdmroot {
00562 protected:
00563         int dtsize;
00564         int utsize;
00566         RV Drv;
00568         RV Urv; 
00570         int L_dt, L_ut;
00571 public:
00573         DS() :Drv ( ),Urv ( ) {};
00575         virtual void getdata ( vec &dt ) {it_error ( "abstract class" );};
00577         virtual void getdata ( vec &dt, const ivec &indeces ) {it_error ( "abstract class" );};
00579         virtual void write ( vec &ut ) {it_error ( "abstract class" );};
00581         virtual void write ( vec &ut, const ivec &indeces ) {it_error ( "abstract class" );};
00582 
00584         virtual void step() =0;
00585 
00587         virtual void log_add ( logger &L ) {
00588                 it_assert_debug ( dtsize==Drv._dsize(),"" );
00589                 it_assert_debug ( utsize==Urv._dsize(),"" );
00590 
00591                 L_dt=L.add ( Drv,"" );
00592                 L_ut=L.add ( Urv,"" );
00593         }
00595         virtual void logit ( logger &L ) {
00596                 vec tmp ( Drv._dsize() +Urv._dsize() );
00597                 getdata ( tmp );
00598                 
00599                 L.logit ( L_dt,tmp.left ( Drv._dsize() ) );
00600                 
00601                 L.logit ( L_ut,tmp.mid ( Drv._dsize(), Urv._dsize() ) );
00602         }
00604         virtual RV _drv() const {return concat ( Drv,Urv );}
00606         const RV& _urv() const {return Urv;}
00607 };
00608 
00630 class BM :public bdmroot {
00631 protected:
00633         RV drv;
00635         double ll;
00637         bool evalll;
00638 public:
00641 
00642         BM () :ll ( 0 ),evalll ( true ), LIDs ( 3 ), opt_L_bounds ( false ) {};
00643         BM ( const BM &B ) :  drv ( B.drv ), ll ( B.ll ), evalll ( B.evalll ) {}
00646         virtual BM* _copy_ () const {return NULL;};
00648 
00651 
00655         virtual void bayes ( const vec &dt ) = 0;
00657         virtual void bayesB ( const mat &Dt );
00660         virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00662         vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00663 
00665         virtual epdf* epredictor ( ) const {it_error ( "Not implemented" );return NULL;};
00667         virtual mpdf* predictor ( ) const {it_error ( "Not implemented" );return NULL;};
00669 
00674 
00676         RV rvc;
00678         const RV& _rvc() const {return rvc;}
00679 
00681         virtual void condition ( const vec &val ) {it_error ( "Not implemented!" );};
00682 
00684 
00685 
00688 
00689         const RV& _drv() const {return drv;}
00690         void set_drv ( const RV &rv ) {drv=rv;}
00691         void set_rv ( const RV &rv ) {const_cast<epdf&> ( posterior() ).set_rv ( rv );}
00692         double _ll() const {return ll;}
00693         void set_evalll ( bool evl0 ) {evalll=evl0;}
00694         virtual const epdf& posterior() const =0;
00695         virtual const epdf* _e() const =0;
00697 
00700 
00702         void set_options ( const string &opt ) {
00703                 opt_L_bounds= ( opt.find ( "logbounds" ) !=string::npos );
00704         }
00706         ivec LIDs;
00707 
00709         bool opt_L_bounds;
00711         virtual void log_add ( logger &L, const string &name="" ) {
00712                 
00713                 RV r;
00714                 if ( posterior().isnamed() ) {r=posterior()._rv();}
00715                 else{r=RV ( "est", posterior().dimension() );};
00716 
00717                 
00718                 LIDs ( 0 ) =L.add ( r,name );
00719                 if ( opt_L_bounds ) {
00720                         LIDs ( 1 ) =L.add ( r,name+"_lb" );
00721                         LIDs ( 2 ) =L.add ( r,name+"_ub" );
00722                 }
00723         }
00724         virtual void logit ( logger &L ) {
00725                 L.logit ( LIDs ( 0 ), posterior().mean() );
00726                 if ( opt_L_bounds ) {
00727                         vec ub,lb;
00728                         posterior().qbounds ( lb,ub );
00729                         L.logit ( LIDs ( 1 ), lb );
00730                         L.logit ( LIDs ( 2 ), ub );
00731                 }
00732         }
00734 };
00735 
00736 }; 
00737 #endif // BM_H