00001 
00013 #ifndef BM_H
00014 #define BM_H
00015 
00016 #include <itpp/itbase.h>
00017 #include "../itpp_ext.h"
00018 
00019 
00020 using namespace itpp;
00021 
00023 class str {
00024 public:
00026         ivec ids;
00028         ivec times;
00030         str ( ivec ids0, ivec times0 ) :ids ( ids0 ),times ( times0 ) {
00031                 it_assert_debug ( times0.length() ==ids0.length(),"Incompatible input" );
00032         };
00033 };
00034 
00041 class RV {
00042 protected:
00044         int tsize;
00046         int len;
00048         ivec ids;
00050         ivec sizes;
00052         ivec times;
00054         Array<std::string> names;
00055 
00056 private:
00058         void init ( ivec in_ids, Array<std::string> in_names, ivec in_sizes, ivec in_times );
00059 public:
00061         RV ( Array<std::string> in_names, ivec in_sizes, ivec in_times );
00063         RV ( Array<std::string> in_names, ivec in_sizes );
00065         RV ( Array<std::string> in_names );
00067         RV ();
00068 
00070         friend std::ostream &operator<< ( std::ostream &os, const RV &rv );
00071 
00073         int count() const {return tsize;} ;
00075         int length() const {return len;} ;
00076 
00077         
00078 
00080         ivec findself ( const RV &rv2 ) const;
00082         bool equal ( const RV &rv2 ) const;
00084         bool add ( const RV &rv2 );
00086         RV subt ( const RV &rv2 ) const;
00088         RV subselect ( const ivec &ind ) const;
00090         RV operator() ( const ivec &ind ) const;
00092         void t ( int delta );
00094         str tostr() const;
00097         ivec dataind ( const RV &crv ) const;
00100         void dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const;
00101 
00103         Array<std::string>& _names() {return names;};
00104 
00106         int id ( int at ) {return ids ( at );};
00108         int size ( int at ) {return sizes ( at );};
00110         int time ( int at ) {return times ( at );};
00112         std::string name ( int at ) {return names ( at );};
00113 
00115         void set_id ( int at, int id0 ) {ids ( at ) =id0;};
00117         void set_size ( int at, int size0 ) {sizes ( at ) =size0; tsize=sum ( sizes );};
00119         void set_time ( int at, int time0 ) {times ( at ) =time0;};
00120 
00122         void newids();
00123 };
00124 
00126 RV concat ( const RV &rv1, const RV &rv2 );
00127 
00128 
00130 
00131 class fnc {
00132 protected:
00134         int dimy;
00135 public:
00137         fnc ( int dy ) :dimy ( dy ) {};
00139         virtual vec eval ( const vec &cond ) {
00140                 return vec ( 0 );
00141         };
00142 
00144         int _dimy() const{return dimy;}
00145 
00147         virtual ~fnc() {};
00148 };
00149 
00150 class mpdf;
00151 
00153 
00154 class epdf {
00155 protected:
00157         RV rv;
00158 public:
00160         epdf() :rv ( ) {};
00161 
00163         epdf ( const RV &rv0 ) :rv ( rv0 ) {};
00164 
00165 
00166 
00167 
00169         virtual vec sample () const =0;
00171         virtual mat sample_m ( int N ) const;
00172         
00174         virtual double evalpdflog ( const vec &val ) const =0;
00175 
00177         virtual vec evalpdflog_m ( const mat &Val ) const {
00178                 vec x ( Val.cols() );
00179                 for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evalpdflog ( Val.get_col ( i ) ) ;}
00180                 return x;
00181         }
00183         virtual mpdf* condition ( const RV &rv ) const  {it_warning ( "Not implemented" ); return NULL;}
00185         virtual epdf* marginal ( const RV &rv ) const {it_warning ( "Not implemented" ); return NULL;}
00186 
00188         virtual vec mean() const =0;
00189 
00191         virtual ~epdf() {};
00193         const RV& _rv() const {return rv;}
00195         void _renewrv ( const RV &in_rv ) {rv=in_rv;}
00197 };
00198 
00199 
00201 
00202 
00203 class mpdf {
00204 protected:
00206         RV rv;
00208         RV rvc;
00210         epdf* ep;
00211 public:
00212 
00214 
00216         virtual vec samplecond ( const vec &cond, double &ll ) {
00217                 this->condition ( cond );
00218                 vec temp= ep->sample();
00219                 ll=ep->evalpdflog ( temp );return temp;
00220         };
00222         virtual mat samplecond ( const vec &cond, vec &ll, int N ) {
00223                 this->condition ( cond );
00224                 mat temp ( rv.count(),N ); vec smp ( rv.count() );
00225                 for ( int i=0;i<N;i++ ) {smp=ep->sample() ;temp.set_col ( i, smp );ll ( i ) =ep->evalpdflog ( smp );}
00226                 return temp;
00227         };
00229         virtual void condition ( const vec &cond ) {it_error ( "Not implemented" );};
00230 
00232         virtual double evalcond ( const vec &dt, const vec &cond ) {this->condition ( cond );return exp(ep->evalpdflog ( dt ));};
00233 
00234         virtual vec evalcond_m ( const mat &Dt, const vec &cond ) {this->condition ( cond );return exp(ep->evalpdflog_m ( Dt ));};
00235 
00237         virtual ~mpdf() {};
00238 
00240         mpdf ( const RV &rv0, const RV &rvc0 ) :rv ( rv0 ),rvc ( rvc0 ) {};
00242         RV _rvc() const {return rvc;}
00244         RV _rv() const {return rv;}
00246         epdf& _epdf() {return *ep;}
00247 };
00248 
00251 class datalink_e2e {
00252 protected:
00254         int valsize;
00256         int valupsize;
00258         ivec v2v_up;
00259 public:
00261         datalink_e2e ( const RV &rv, const RV &rv_up ) :
00262                         valsize ( rv.count() ), valupsize ( rv_up.count() ), v2v_up ( rv.dataind ( rv_up ) )  {
00263                 it_assert_debug ( v2v_up.length() ==valsize,"rv is not fully in rv_up" );
00264         }
00266         vec get_val ( const vec &val_up ) {it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" ); return get_vec ( val_up,v2v_up );}
00268         void fill_val ( vec &val_up, const vec &val ) {
00269                 it_assert_debug ( valsize==val.length(),"Wrong val" );
00270                 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00271                 set_subvector ( val_up, v2v_up, val );
00272         }
00273 };
00274 
00276 class datalink_m2e: public datalink_e2e {
00277 protected:
00279         int condsize;
00281         ivec v2c_up;
00283         ivec v2c_lo;
00284 
00285 public:
00287         datalink_m2e ( const RV &rv,  const RV &rvc, const RV &rv_up ) :
00288                         datalink_e2e ( rv,rv_up ), condsize ( rvc.count() ) {
00289                 
00290                 rvc.dataind ( rv_up, v2c_lo, v2c_up );
00291         }
00293         vec get_cond ( const vec &val_up ) {
00294                 vec tmp ( condsize );
00295                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00296                 return tmp;
00297         }
00298         void fill_val_cond ( vec &val_up, const vec &val, const vec &cond ) {
00299                 it_assert_debug ( valsize==val.length(),"Wrong val" );
00300                 it_assert_debug ( valupsize==val_up.length(),"Wrong val_up" );
00301                 set_subvector ( val_up, v2v_up, val );
00302                 set_subvector ( val_up, v2c_up, cond );
00303         }
00304 };
00307 class datalink_m2m: public datalink_m2e {
00308 protected:
00310         ivec c2c_up;
00312         ivec c2c_lo;
00313 public:
00315         datalink_m2m ( const RV &rv, const RV &rvc, const RV &rv_up, const RV &rvc_up ) :
00316                         datalink_m2e ( rv, rvc, rv_up) {
00317                 
00318                 rvc.dataind ( rvc_up, c2c_lo, c2c_up );
00319                 it_assert_debug(c2c_lo.length()+v2c_lo.length()==condsize, "cond is not fully given");
00320         }
00322         vec get_cond ( const vec &val_up, const vec &cond_up ) {
00323                 vec tmp ( condsize );
00324                 set_subvector ( tmp,v2c_lo,val_up ( v2c_up ) );
00325                 set_subvector ( tmp,c2c_lo,cond_up ( c2c_up ) );
00326                 return tmp;
00327         }
00329 
00330 };
00331 
00335 class mepdf : public mpdf {
00336 public:
00338         mepdf (const epdf* em ) :mpdf ( em->_rv(),RV() ) {ep=const_cast<epdf*>(em);};
00339         void condition ( const vec &cond ) {}
00340 };
00341 
00344 class compositepdf {
00345 protected:
00347         int n;
00349         Array<mpdf*> mpdfs;
00350 public:
00351         compositepdf ( Array<mpdf*> A0 ) : n ( A0.length() ), mpdfs ( A0 ) {};
00353         RV getrv ( bool checkoverlap=false );
00355         void setrvc ( const RV &rv, RV &rvc );
00356 };
00357 
00365 class DS {
00366 protected:
00368         RV Drv;
00370         RV Urv; 
00371 public:
00373         void getdata ( vec &dt );
00375         void getdata ( vec &dt, ivec &indeces );
00377         void write ( vec &ut );
00379         void write ( vec &ut, ivec &indeces );
00385         void linkrvs ( RV &drv, RV &urv );
00386 
00388         void step();
00389 
00390 };
00391 
00396 class BM {
00397 protected:
00399         RV rv;
00401         double ll;
00403         bool evalll;
00404 public:
00405 
00407         BM ( const RV &rv0, double ll0=0,bool evalll0=true ) :rv ( rv0 ), ll ( ll0 ),evalll ( evalll0 ) {
00408         };
00410         BM ( const BM &B ) : rv ( B.rv ), ll ( B.ll ), evalll ( B.evalll ) {}
00411 
00415         virtual void bayes ( const vec &dt ) = 0;
00417         virtual void bayesB ( const mat &Dt );
00419         virtual const epdf& _epdf() const =0;
00420 
00422         virtual const epdf* _e() const =0;
00423 
00426         virtual double logpred ( const vec &dt ) const{it_error ( "Not implemented" );return 0.0;}
00428         vec logpred_m ( const mat &dt ) const{vec tmp ( dt.cols() );for ( int i=0;i<dt.cols();i++ ) {tmp ( i ) =logpred ( dt.get_col ( i ) );}return tmp;}
00429 
00431         virtual epdf* predictor ( const RV &rv ) const {it_error ( "Not implemented" );return NULL;};
00432 
00434         virtual ~BM() {};
00436         const RV& _rv() const {return rv;}
00438         double _ll() const {return ll;}
00440         void set_evalll ( bool evl0 ) {evalll=evl0;}
00441 
00444         virtual BM* _copy_ ( bool changerv=false ) {it_error ( "function _copy_ not implemented for this BM" ); return NULL;};
00445 };
00446 
00456 class BMcond {
00457 protected:
00459         RV rvc;
00460 public:
00462         virtual void condition ( const vec &val ) =0;
00464         BMcond ( RV &rv0 ) :rvc ( rv0 ) {};
00466         virtual ~BMcond() {};
00468         const RV& _rvc() const {return rvc;}
00469 };
00470 
00471 #endif // BM_H