00001 
00017 #ifndef BM_H
00018 #define BM_H
00019 
00020 #include <itpp/itbase.h>
00021 #include "../itpp_ext.h"
00022 
00023 
00024 using namespace itpp;
00025 
00027 class str {
00028 public:
00030         ivec ids;
00032         ivec times;
00034         str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00035                 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00036         };
00037 };
00038 
00045 class RV {
00046 protected:
00048         int tsize;
00050         int len;
00052         ivec ids;
00054         ivec sizes;
00056         ivec times;
00058         Array<std::string> names;
00059 
00060 private:
00062         void init ( ivec in_ids, Array<std::string> in_names, ivec in_sizes, ivec in_times );
00063 public:
00065         RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00067         RV ( Array<std::string> in_names, ivec in_sizes );
00069         RV ( Array<std::string> in_names );
00071         RV ();
00072 
00074         friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00075 
00077         int count() const {return tsize;} ;
00079         int length() const {return len;} ;
00080 
00081         
00082 
00084         ivec findself ( const RV &rv2 ) const;
00086         bool equal ( const RV &rv2 ) const;
00088         bool add ( const RV &rv2 );
00090         RV subt ( const RV &rv2 ) const;
00092         RV subselect ( const ivec &ind ) const;
00094         RV operator() ( const ivec &ind ) const;
00096         void t ( int delta );
00098         str tostr() const;
00101         ivec dataind ( const RV &crv ) const;
00104         void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00105 
00107         Array<std::string>& _names() {return names;};
00108 
00110         int id ( int at ) {return ids ( at );};
00112         int size ( int at ) {return sizes ( at );};
00114         int time ( int at ) {return times ( at );};
00116         std::string name ( int at ) {return names ( at );};
00117 
00119         void set_id ( int at, int id0 ) {ids ( at ) =id0;};
00121         void set_size ( int at, int size0 ) {sizes ( at ) =size0; tsize=sum ( sizes );};
00123         void set_time ( int at, int time0 ) {times ( at ) =time0;};
00124 
00126         void newids();
00127 };
00128 
00130 RV concat ( const RV &rv1, const RV &rv2 );
00131 
00133 extern RV RV0;
00134 
00136 
00137 class fnc {
00138 protected:
00140         int dimy;
00141 public:
00143         fnc ( int dy ) :dimy ( dy ) {};
00145         virtual vec eval ( const vec &cond ) {
00146                 return vec ( 0 );
00147         };
00148         
00150         virtual void condition(const vec &val){};
00151 
00153         int _dimy() const{return dimy;}
00154 
00156         virtual ~fnc() {};
00157 };
00158 
00159 class mpdf;
00160 
00162 
00163 class epdf {
00164 protected:
00166         RV rv;
00167 public:
00169         epdf() :rv ( ) {};
00170 
00172         epdf ( const RV &rv0 ) :rv ( rv0 ) {};
00173 
00174 
00175 
00176 
00178         virtual vec sample () const =0;
00180         virtual mat sample_m ( int N ) const;
00181         
00183         virtual double evallog ( const vec &val ) const =0;
00184 
00186         virtual vec evallog_m ( const mat &Val ) const {
00187                 vec x ( Val.cols() );
00188                 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evallog ( Val.get_col ( i ) ) ;}
00189                 return x;
00190         }
00192         virtual mpdf* condition ( const RV &rv ) const  {it_warning ( "Not implemented" ); return NULL;}
00194         virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00195 
00197         virtual vec mean() const =0;
00198 
00200         virtual vec variance() const = 0;
00201         
00203         virtual ~epdf() {};
00205         const RV& _rv() const {return rv;}
00207         void _renewrv ( const RV &in_rv ) {rv=in_rv;}
00209 };
00210 
00211 
00213 
00214 
00215 class mpdf {
00216 protected:
00218         RV rv;
00220         RV rvc;
00222         epdf* ep;
00223 public:
00224 
00226         virtual vec samplecond ( const vec &cond, double &ll ) {
00227                 this->condition ( cond );
00228                 vec temp= ep->sample();
00229                 ll=ep->evallog ( temp );return temp;
00230         };
00232         virtual mat samplecond_m ( const vec &cond, vec &ll, int N ) {
00233                 this->condition ( cond );
00234                 mat temp ( rv.count(),N ); vec smp ( rv.count() );
00235                 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );ll ( i ) =ep->evallog ( smp );}
00236                 return temp;
00237         };
00239         virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00240 
00242         virtual double evallogcond ( const vec &dt, const vec &cond ) {double tmp; this->condition ( cond );tmp = ep->evallog ( dt );           it_assert_debug(std::isfinite(tmp),"Infinite value"); return tmp;
00243         };
00244 
00246         virtual vec evallogcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return ep->evallog_m ( Dt );};
00247 
00249         virtual ~mpdf() {};
00250 
00252         mpdf ( const RV &rv0, const RV &rvc0 ) :rv ( rv0 ),rvc ( rvc0 ) {};
00254         RV _rvc() const {return rvc;}
00256         RV _rv() const {return rv;}
00258         epdf& _epdf() {return *ep;}
00260         epdf* _e() {return ep;}
00261 };
00262 
00265 class datalink_e2e {
00266 protected:
00268         int valsize;
00270         int valupsize;
00272         ivec v2v_up;
00273 public:
00275         datalink_e2e ( const RV &rv, const RV &rv_up ) :
00276                         valsize ( rv.count() ), valupsize ( rv_up.count() ), v2v_up ( rv.dataind ( rv_up ) )  {
00277                 it_assert_debug ( v2v_up.length() ==valsize,"rv is not fully in rv_up" );
00278         }
00280         vec get_val ( const vec &val_up ) {it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" ); return get_vec ( val_up,v2v_up );}
00282         void fill_val ( vec &val_up, const vec &val ) {
00283                 it_assert_debug ( valsize==val.length(),"Wrong val" );
00284                 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00285                 set_subvector ( val_up, v2v_up, val );
00286         }
00287 };
00288 
00290 class datalink_m2e: public datalink_e2e {
00291 protected:
00293         int condsize;
00295         ivec v2c_up;
00297         ivec v2c_lo;
00298 
00299 public:
00301         datalink_m2e ( const RV &rv,  const RV &rvc, const RV &rv_up ) :
00302                         datalink_e2e ( rv,rv_up ), condsize ( rvc.count() ) {
00303                 
00304                 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00305         }
00307         vec get_cond ( const vec &val_up ) {
00308                 vec tmp ( condsize );
00309                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00310                 return tmp;
00311         }
00312         void fill_val_cond ( vec &val_up, const vec &val, const vec &cond ) {
00313                 it_assert_debug ( valsize==val.length(),"Wrong val" );
00314                 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00315                 set_subvector ( val_up, v2v_up, val );
00316                 set_subvector ( val_up, v2c_up, cond );
00317         }
00318 };
00321 class datalink_m2m: public datalink_m2e {
00322 protected:
00324         ivec c2c_up;
00326         ivec c2c_lo;
00327 public:
00329         datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00330                         datalink_m2e ( rv, rvc, rv_up) {
00331                 
00332                 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00333                 it_assert_debug(c2c_lo.length()+v2c_lo.length()==condsize, "cond is not fully given");
00334         }
00336         vec get_cond ( const vec &val_up, const vec &cond_up ) {
00337                 vec tmp ( condsize );
00338                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00339                 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00340                 return tmp;
00341         }
00343 
00344 };
00345 
00349 class mepdf : public mpdf {
00350 public:
00352         mepdf (const epdf* em ) :mpdf ( em->_rv(),RV() ) {ep=const_cast<epdf*>(em);};
00353         void condition ( const vec &cond ) {}
00354 };
00355 
00358 class compositepdf {
00359 protected:
00361         int n;
00363         Array<mpdf*> mpdfs;
00364 public:
00365         compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00367         RV getrv ( bool checkoverlap=false );
00369         void setrvc ( const RV &rv, RV &rvc );
00370 };
00371 
00379 class DS {
00380 protected:
00382         RV Drv;
00384         RV Urv; 
00385 public:
00387         void getdata ( vec &dt );
00389         void getdata ( vec &dt, ivec &indeces );
00391         void write ( vec &ut );
00393         void write ( vec &ut, ivec &indeces );
00399         void linkrvs ( RV &drv, RV &urv );
00400 
00402         void step();
00403 
00404 };
00405 
00410 class BM {
00411 protected:
00413         RV rv;
00415         double ll;
00417         bool evalll;
00418 public:
00419 
00421         BM ( const RV &rv0, double ll0=0,bool evalll0=true ) :rv ( rv0 ), ll ( ll0 ),evalll ( evalll0 ) {
00422         };
00424         BM ( const BM &B ) : rv ( B.rv ), ll ( B.ll ), evalll ( B.evalll ) {}
00425 
00429         virtual void bayes ( const vec &dt ) = 0;
00431         virtual void bayesB ( const mat &Dt );
00433         virtual const epdf& _epdf() const =0;
00434 
00436         virtual const epdf* _e() const =0;
00437 
00440         virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00442         vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00443 
00445         virtual epdf* predictor ( const RV &rv ) const {it_error ( "Not implemented" );return NULL;};
00446 
00448         virtual ~BM() {};
00450         const RV& _rv() const {return rv;}
00452         double _ll() const {return ll;}
00454         void set_evalll ( bool evl0 ) {evalll=evl0;}
00455 
00458         virtual BM* _copy_ ( bool changerv=false ) {it_error ( "function _copy_ not implemented for this BM" ); return NULL;};
00459 };
00460 
00470 class BMcond {
00471 protected:
00473         RV rvc;
00474 public:
00476         virtual void condition ( const vec &val ) =0;
00478         BMcond ( RV &rv0 ) :rvc ( rv0 ) {};
00480         virtual ~BMcond() {};
00482         const RV& _rvc() const {return rvc;}
00483 };
00484 
00486 #endif // BM_H