00001 
00013 #ifndef BM_H
00014 #define BM_H
00015 
00016 #include <map>
00017 
00018 #include "../itpp_ext.h"
00019 #include "../bdmroot.h"
00020 #include "../user_info.h"
00021 
00022 
00023 using namespace libconfig;
00024 using namespace itpp;
00025 using namespace std;
00026 
00027 namespace bdm {
00028 
//! Map from a string key to an int (presumably RV name -> index into RV_NAMES / RV_SIZES; TODO confirm in the implementation file)
00029 typedef std::map<string, int> RVmap;
//! Global table of sizes, defined elsewhere; RV::size() indexes it with the ids stored in an RV
00030 extern ivec RV_SIZES;
//! Global table of names, defined elsewhere; RV::name() indexes it with the ids stored in an RV
00031 extern Array<string> RV_NAMES;
00032 
00034 class str {
00035 public:
00037         ivec ids;
00039         ivec times;
00041         str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00042                 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00043         };
00044 };
00045 
00084 class RV :public bdmroot {
00085 protected:
00087         int dsize;
00089         int len;
00091         ivec ids;
00093         ivec times;
00094 
00095 private:
00097         void init ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00098         int init ( const  string &name, int size );
00099 public:
00102 
00104         RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times ) {init ( in_names,in_sizes,in_times );};
00106         RV ( Array<std::string> in_names, ivec in_sizes ) {init ( in_names,in_sizes,zeros_i ( in_names.length() ) );};
00108         RV ( Array<std::string> in_names ) {init ( in_names,ones_i ( in_names.length() ),zeros_i ( in_names.length() ) );}
00110         RV () :dsize ( 0 ),len ( 0 ),ids ( 0 ),times ( 0 ) {};
00112         RV ( string name, int sz, int tm=0 );
00114 
00117 
00119         friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00120         int _dsize() const {return dsize;} ;
00122         int countsize() const;
00123         ivec cumsizes() const;
00124         int length() const {return len;} ;
00125         int id ( int at ) const{return ids ( at );};
00126         int size ( int at ) const {return RV_SIZES ( ids ( at ) );};
00127         int time ( int at ) const{return times ( at );};
00128         std::string name ( int at ) const {return RV_NAMES ( ids ( at ) );};
00129         void set_time ( int at, int time0 ) {times ( at ) =time0;};
00131 
00132         
00133 
00136 
00138         ivec findself ( const RV &rv2 ) const;
00140         bool equal ( const RV &rv2 ) const;
00142         bool add ( const RV &rv2 );
00144         RV subt ( const RV &rv2 ) const;
00146         RV subselect ( const ivec &ind ) const;
00148         RV operator() ( const ivec &ind ) const {return subselect ( ind );};
00150         RV operator() ( int di1, int di2 ) const {
00151                 ivec sz=cumsizes();
00152                 int i1=0;
00153                 while ( sz ( i1 ) <di1 ) i1++;
00154                 int i2=i1;
00155                 while ( sz ( i2 ) <di2 ) i2++;
00156                 return subselect ( linspace ( i1,i2 ) );
00157         };
00159         void t ( int delta );
00161 
00164 
00166         str tostr() const;
00169         ivec dataind ( const RV &crv ) const;
00172         void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00174         int mint () const {return min ( times );};
00176 
00177         
00192         void from_setting( const Setting &root );
00193 
00194         
00195 };
//! Invoke the UIREGISTER macro for RV (macro defined elsewhere, presumably in user_info.h)
00196 UIREGISTER(RV);
00197 
//! Combine two RVs into one; defined out of line (presumably rv1 followed by rv2; TODO confirm)
00199 RV concat ( const RV &rv1, const RV &rv2 );
00200 
//! Global RV instance defined elsewhere
00202 extern RV RV0;
00203 
00205 
00206 class fnc :public bdmroot {
00207 protected:
00209         int dimy;
00210 public:
00212         fnc ( ) {};
00214         virtual vec eval ( const vec &cond ) {
00215                 return vec ( 0 );
00216         };
00217 
00219         virtual void condition ( const vec &val ) {};
00220 
00222         int dimension() const{return dimy;}
00223 };
00224 
00225 class mpdf;
00226 
00228 
00229 class epdf : public bdmroot {
00230 protected:
00232         int dim;
00234         RV rv;
00235 
00236 public:
00248         epdf() :dim ( 0 ),rv ( ) {};
00249         epdf ( const epdf &e ) :dim ( e.dim ),rv ( e.rv ) {};
00250         epdf ( const RV &rv0 ) {set_rv ( rv0 );};
00251         void set_parameters ( int dim0 ) {dim=dim0;}
00253 
00256 
00258         virtual vec sample () const {it_error ( "not implemneted" );return vec ( 0 );};
00260         virtual mat sample_m ( int N ) const;
00262         virtual double evallog ( const vec &val ) const {it_error ( "not implemneted" );return 0.0;};
00264         virtual vec evallog_m ( const mat &Val ) const {
00265                 vec x ( Val.cols() );
00266                 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00267                 return x;
00268         }
00270         virtual mpdf* condition ( const RV &rv ) const  {it_warning ( "Not implemented" ); return NULL;}
00272         virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00274         virtual vec mean() const {it_error ( "not implemneted" );return vec ( 0 );};
00276         virtual vec variance() const {it_error ( "not implemneted" );return vec ( 0 );};
00278         virtual void qbounds ( vec &lb, vec &ub, double percentage=0.95 ) const {
00279                 vec mea=mean(); vec std=sqrt ( variance() );
00280                 lb = mea-2*std; ub=mea+2*std;
00281         };
00283 
00289 
00291         void set_rv ( const RV &rv0 ) {rv = rv0; }
00293         bool isnamed() const {bool b= ( dim==rv._dsize() );return b;}
00295         const RV& _rv() const {it_assert_debug ( isnamed(),"" ); return rv;}
00297 
00300 
00302         int dimension() const {return dim;}
00304 
00305 };
00306 
00307 
00309 
00310 
00311 class mpdf : public bdmroot {
00312 protected:
00314         int dimc;
00316         RV rvc;
00318         epdf* ep;
00319 public:
00322 
00323         mpdf ( ) :dimc ( 0 ),rvc ( ) {};
00325         mpdf ( const mpdf &m ) :dimc ( m.dimc ),rvc ( m.rvc ) {};
00327 
00330 
00332         virtual vec samplecond ( const vec &cond ) {
00333                 this->condition ( cond );
00334                 vec temp= ep->sample();
00335                 return temp;
00336         };
00338         virtual mat samplecond_m ( const vec &cond, int N ) {
00339                 this->condition ( cond );
00340                 mat temp ( ep->dimension(),N ); vec smp ( ep->dimension() );
00341                 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );}
00342                 return temp;
00343         };
00345         virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00346 
00348         virtual double evallogcond ( const vec &dt, const vec &cond ) {
00349                 double tmp; this->condition ( cond );tmp = ep->evallog ( dt );          it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;
00350         };
00351 
00353         virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00354 
00357 
00358         RV _rv() {return ep->_rv();}
00359         RV _rvc() {it_assert_debug ( isnamed(),"" ); return rvc;}
00360         int dimension() {return ep->dimension();}
00361         int dimensionc() {return dimc;}
00362         epdf& _epdf() {return *ep;}
00363         epdf* _e() {return ep;}
00365 
00368         void set_rvc ( const RV &rvc0 ) {rvc=rvc0;}
00369         void set_rv ( const RV &rv0 ) {ep->set_rv ( rv0 );}
00370         bool isnamed() {return ( ep->isnamed() ) && ( dimc==rvc._dsize() );}
00372 };
00373 
00399 class datalink {
00400 protected:
00402         int downsize;
00404         int upsize;
00406         ivec v2v_up;
00407 public:
00409         datalink () {};
00410         datalink ( const RV &rv, const RV &rv_up ) {set_connection ( rv,rv_up );};
00412         void set_connection ( const RV &rv, const RV &rv_up ) {
00413                 downsize = rv._dsize();
00414                 upsize = rv_up._dsize();
00415                 v2v_up= ( rv.dataind ( rv_up ) );
00416 
00417                 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00418         }
00420         void set_connection ( int ds, int us, const ivec &upind ) {
00421                 downsize = ds;
00422                 upsize = us;
00423                 v2v_up= upind;
00424 
00425                 it_assert_debug ( v2v_up.length() ==downsize,"rv is not fully in rv_up" );
00426         }
00428         vec pushdown ( const vec &val_up ) {
00429                 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00430                 return get_vec ( val_up,v2v_up );
00431         }
00433         void pushup ( vec &val_up, const vec &val ) {
00434                 it_assert_debug ( downsize==val.length(),"Wrong val" );
00435                 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00436                 set_subvector ( val_up, v2v_up, val );
00437         }
00438 };
00439 
00441 class datalink_m2e: public datalink {
00442 protected:
00444         int condsize;
00446         ivec v2c_up;
00448         ivec v2c_lo;
00449 
00450 public:
00451         datalink_m2e() {};
00453         void set_connection ( const RV &rv,  const RV &rvc, const RV &rv_up ) {
00454                 datalink::set_connection ( rv,rv_up );
00455                 condsize=  rvc._dsize();
00456                 
00457                 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00458         }
00460         vec get_cond ( const vec &val_up ) {
00461                 vec tmp ( condsize );
00462                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00463                 return tmp;
00464         }
00465         void pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
00466                 it_assert_debug ( downsize==val.length(),"Wrong val" );
00467                 it_assert_debug ( upsize==val_up.length(),"Wrong val_up" );
00468                 set_subvector ( val_up, v2v_up, val );
00469                 set_subvector ( val_up, v2c_up, cond );
00470         }
00471 };
00474 class datalink_m2m: public datalink_m2e {
00475 protected:
00477         ivec c2c_up;
00479         ivec c2c_lo;
00480 public:
00482         datalink_m2m() {};
00483         void set_connection ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) {
00484                 datalink_m2e::set_connection ( rv, rvc, rv_up );
00485                 
00486                 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00487                 it_assert_debug ( c2c_lo.length() +v2c_lo.length() ==condsize, "cond is not fully given" );
00488         }
00490         vec get_cond ( const vec &val_up, const vec &cond_up ) {
00491                 vec tmp ( condsize );
00492                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00493                 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00494                 return tmp;
00495         }
00497 
00498 };
00499 
00505 class logger : public bdmroot {
00506 protected:
00508         Array<RV> entries;
00510         Array<string> names;
00511 public:
00513         logger ( ) : entries ( 0 ),names ( 0 ) {}
00514 
00517         virtual int add ( const RV &rv, string prefix="" ) {
00518                 int id;
00519                 if ( rv._dsize() >0 ) {
00520                         id=entries.length();
00521                         names=concat ( names, prefix); 
00522                         entries.set_length ( id+1,true );
00523                         entries ( id ) = rv;
00524                 }
00525                 else { id =-1;}
00526                 return id; 
00527         }
00528 
00530         virtual void logit ( int id, const vec &v ) =0;
00532         virtual void logit ( int id, const double &d ) =0;
00533 
00535         virtual void step() =0;
00536 
00538         virtual void finalize() {};
00539 
00541         virtual void init() {};
00542 
00543 };
00544 
00548 class mepdf : public mpdf {
00549 public:
00551         mepdf ( epdf* em ) :mpdf ( ) {ep= em ;};
00552         mepdf (const epdf* em ) :mpdf ( ) {ep=const_cast<epdf*>( em );};
00553         void condition ( const vec &cond ) {}
00554 };
00555 
00558 class compositepdf : public bdmroot{
00559 protected:
00561         Array<mpdf*> mpdfs;
00562 public:
00563         compositepdf():mpdfs(0){};
00564         void set_elements (const Array<mpdf*> A0 ) { mpdfs = A0;};
00566         RV getrv ( bool checkoverlap=false );
00568         void setrvc ( const RV &rv, RV &rvc );
00570 };
00571 
00579 class DS : public bdmroot {
00580 protected:
00581         int dtsize;
00582         int utsize;
00584         RV Drv;
00586         RV Urv; 
00588         int L_dt, L_ut;
00589 public:
00591         DS() :Drv ( ),Urv ( ) {};
00593         virtual void getdata ( vec &dt ) {it_error ( "abstract class" );};
00595         virtual void getdata ( vec &dt, const ivec &indeces ) {it_error ( "abstract class" );};
00597         virtual void write ( vec &ut ) {it_error ( "abstract class" );};
00599         virtual void write ( vec &ut, const ivec &indeces ) {it_error ( "abstract class" );};
00600 
00602         virtual void step() =0;
00603 
00605         virtual void log_add ( logger &L ) {
00606                 it_assert_debug ( dtsize==Drv._dsize(),"" );
00607                 it_assert_debug ( utsize==Urv._dsize(),"" );
00608 
00609                 L_dt=L.add ( Drv,"" );
00610                 L_ut=L.add ( Urv,"" );
00611         }
00613         virtual void logit ( logger &L ) {
00614                 vec tmp ( Drv._dsize() +Urv._dsize() );
00615                 getdata ( tmp );
00616                 
00617                 L.logit ( L_dt,tmp.left ( Drv._dsize() ) );
00618                 
00619                 L.logit ( L_ut,tmp.mid ( Drv._dsize(), Urv._dsize() ) );
00620         }
00622         virtual RV _drv() const {return concat ( Drv,Urv );}
00624         const RV& _urv() const {return Urv;}
00626         virtual void set_drv (const  RV &drv, const RV &urv) { Drv=drv;Urv=urv;}
00627 };
00628 
00650 class BM :public bdmroot {
00651 protected:
00653         RV drv;
00655         double ll;
00657         bool evalll;
00658 public:
00661 
00662         BM () :ll ( 0 ),evalll ( true ), LIDs ( 4 ), LFlags(4) {
00663                 LIDs=-1; LFlags=0; LFlags(0)=1;};
00664         BM ( const BM &B ) :  drv ( B.drv ), ll ( B.ll ), evalll ( B.evalll ) {}
00667         virtual BM* _copy_ () const {return NULL;};
00669 
00672 
00676         virtual void bayes ( const vec &dt ) = 0;
00678         virtual void bayesB ( const mat &Dt );
00681         virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00683         vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00684 
00686         virtual epdf* epredictor ( ) const {it_error ( "Not implemented" );return NULL;};
00688         virtual mpdf* predictor ( ) const {it_error ( "Not implemented" );return NULL;};
00690 
00695 
00697         RV rvc;
00699         const RV& _rvc() const {return rvc;}
00700 
00702         virtual void condition ( const vec &val ) {it_error ( "Not implemented!" );};
00703 
00705 
00706 
00709 
00710         const RV& _drv() const {return drv;}
00711         void set_drv ( const RV &rv ) {drv=rv;}
00712         void set_rv ( const RV &rv ) {const_cast<epdf&> ( posterior() ).set_rv ( rv );}
00713         double _ll() const {return ll;}
00714         void set_evalll ( bool evl0 ) {evalll=evl0;}
00715         virtual const epdf& posterior() const =0;
00716         virtual const epdf* _e() const =0;
00718 
00721 
00723         virtual void set_options ( const string &opt ) {
00724                 LFlags(0)=1;
00725                 if ( opt.find ( "logbounds" ) !=string::npos ) {LFlags(1)=1; LFlags(2)=1;}
00726                 if ( opt.find ( "logll" ) !=string::npos ) {LFlags(3)=1;}
00727         }
00729         ivec LIDs;
00730 
00732         ivec LFlags;
00734         virtual void log_add ( logger &L, const string &name="" ) {
00735                 
00736                 RV r;
00737                 if ( posterior().isnamed() ) {r=posterior()._rv();}
00738                 else{r=RV ( "est", posterior().dimension() );};
00739 
00740                 
00741                 if (LFlags(0)) LIDs ( 0 ) =L.add ( r,name+"mean_" );
00742                 if (LFlags(1)) LIDs ( 1 ) =L.add ( r,name+"lb_" );
00743                 if (LFlags(2)) LIDs ( 2 ) =L.add ( r,name+"ub_" );
00744                 if (LFlags(3)) LIDs ( 3 ) =L.add ( RV("ll",1),name ); 
00745         }
00746         virtual void logit ( logger &L ) {
00747                 L.logit ( LIDs ( 0 ), posterior().mean() );
00748                 if ( LFlags(1) || LFlags(2)) { 
00749                         vec ub,lb;
00750                         posterior().qbounds ( lb,ub );
00751                         L.logit ( LIDs ( 1 ), lb ); 
00752                         L.logit ( LIDs ( 2 ), ub );
00753                 }
00754                 if (LFlags(3)) L.logit ( LIDs ( 3 ), ll );
00755         }
00757 };
00758 
00759 
00760 }; 
00761 #endif // BM_H