00001
00013 #ifndef DATASOURCE_H
00014 #define DATASOURCE_H
00015
00016
00017 #include "../base/bdmbase.h"
00018 #include "../stat/exp_family.h"
00019 #include "../base/user_info.h"
00020
00021 namespace bdm {
00029 class MemDS : public DS {
00030 protected:
00032 mat Data;
00034 int time;
00036 ivec rowid;
00037
00038 public:
00039 int max_length() {return Data.cols();}
00040 void getdata ( vec &dt );
00041 void getdata ( vec &dt, const ivec &indeces );
00042 void set_rvs ( RV &drv, RV &urv );
00043
00044 void write ( vec &ut ) {
00045 bdm_error ( "MemDS::write is not supported" );
00046 }
00047
00048 void write ( vec &ut, ivec &indices ) {
00049 bdm_error ( "MemDS::write is not supported" );
00050 }
00051
00052 void step();
00054 MemDS () {};
00055 MemDS ( mat &Dat, ivec &rowid0);
00080 void from_setting(const Setting &set){
00081 UI::get(Data, set, "Data", UI::compulsory);
00082 if(!UI::get(time, set,"time", UI::optional)) {time =0;}
00083 if(!UI::get(rowid, set, "rowid",UI::optional)) {rowid =linspace(0,Data.rows()-1);}
00084 shared_ptr<RV> r=UI::build<RV>(set,"drv",UI::optional);
00085 if (!r) {r=new RV();
00086 for (int i=0; i<rowid.length(); i++){ r->add(RV("ch"+num2str(rowid(i)), 1, 0));}
00087 }
00088 set_drv(*r,RV());
00089 dtsize=r->_dsize();
00090 utsize=0;
00091 }
00092 };
00093 UIREGISTER(MemDS);
00094
00100 class EpdfDS: public DS {
00101 protected:
00103 shared_ptr<epdf> iepdf;
00105 vec dt;
00106 public:
00107 void step() {
00108 dt=iepdf->sample();
00109 }
00110 void getdata ( vec &dt_out ) {
00111 dt_out = dt;
00112 }
00113 void getdata ( vec &dt_out, const ivec &ids ) {
00114 dt_out = dt ( ids );
00115 }
00116 const RV& _drv() {
00117 return iepdf->_rv();
00118 }
00119
00127 void from_setting ( const Setting &set ) {
00128 iepdf=UI::build<epdf> ( set,"epdf",UI::compulsory );
00129 bdm_assert(iepdf->isnamed(), "Input epdf must be named, check if RV is given correctly");
00130 dt = zeros(iepdf->dimension());
00131 dtsize=dt.length();
00132 set_drv(iepdf->_rv(),RV());
00133 utsize =0;
00134 validate();
00135 }
00136 void validate() {
00137 dt = iepdf->sample();
00138 }
00139 };
00140 UIREGISTER ( EpdfDS );
00141
00145 class MpdfDS :public DS {
00146 protected:
00148 shared_ptr<mpdf> impdf;
00150 vec yt;
00152 vec ut;
00154 datalink_buffered ut2rgr;
00156 datalink_buffered yt2rgr;
00158 vec rgr;
00159
00160 public:
00161 void step() {
00162 yt2rgr.step(yt);
00163 ut2rgr.filldown ( ut,rgr );
00164 yt2rgr.filldown ( yt,rgr );
00165 yt=impdf->samplecond ( rgr );
00166 ut2rgr.step(ut);
00167 }
00168 void getdata ( vec &dt_out ) {
00169 bdm_assert_debug(dt_out.length()>=utsize+ytsize,"Short output vector");
00170 dt_out.set_subvector(0, yt);
00171 dt_out.set_subvector(ytsize, ut);
00172 }
00173 void write(const vec &ut0){ut=ut0;}
00174
00186 void from_setting ( const Setting &set ) {
00187 impdf=UI::build<mpdf> ( set,"mpdf",UI::compulsory );
00188
00189 Yrv = impdf->_rv();
00190
00191 RV rgrv0=impdf->_rvc().remove_time();
00192
00193 Urv=rgrv0.subt(Yrv);
00194 set_drv(Yrv, Urv);
00195
00196 ut2rgr.set_connection(impdf->_rvc(), Urv);
00197 yt2rgr.set_connection(impdf->_rvc(), Yrv);
00198
00199
00200 shared_ptr<RV> rv_ini=UI::build<RV>(set,"init_rv",UI::optional);
00201 if(rv_ini){
00202 vec val;
00203 UI::get(val, set, "init_values", UI::optional);
00204 if (val.length()!=rv_ini->_dsize()){
00205 bdm_error("init_rv and init_values fields have incompatible sizes");
00206 } else {
00207 ut2rgr.set_history(*rv_ini, val);
00208 yt2rgr.set_history(*rv_ini, val);
00209 }
00210 }
00211
00212 yt = zeros ( impdf->dimension() );
00213 rgr = zeros ( impdf->dimensionc() );
00214 ut = zeros(Urv._dsize());
00215
00216 ytsize=yt.length();
00217 utsize=ut.length();
00218 dtsize = ytsize+utsize;
00219 validate();
00220 }
00221 void validate() {
00222
00223 ut2rgr.filldown ( ut,rgr );
00224 yt2rgr.filldown ( yt,rgr );
00225 yt=impdf->samplecond ( rgr );
00226 }
00227 };
00228 UIREGISTER ( MpdfDS );
00229
00233 class FileDS: public MemDS {
00234
00235 public:
00236 void getdata ( vec &dt ) {
00237 dt = Data.get_col ( time );
00238 }
00239
00240 void getdata ( vec &dt, const ivec &indices ) {
00241 vec tmp = Data.get_col ( time );
00242 dt = tmp ( indices );
00243 }
00244
00246 int ndat() {
00247 return Data.cols();
00248 }
00250 void log_add ( logger &L ) {};
00252 void logit ( logger &L ) {};
00253 };
00254
00261 class ITppFileDS: public FileDS {
00262
00263 public:
00264 ITppFileDS ( const string &fname, const string &varname ) : FileDS() {
00265 it_file it ( fname );
00266 it << Name ( varname );
00267 it >> Data;
00268 time = 0;
00269
00270 };
00271
00272 ITppFileDS () : FileDS() {
00273 };
00274
00275 void from_setting ( const Setting &set );
00276
00277
00278
00279 };
00280
00281 UIREGISTER ( ITppFileDS );
00282 SHAREDPTR ( ITppFileDS );
00283
00291 class CsvFileDS: public FileDS {
00292
00293 public:
00295 CsvFileDS ( const string& fname, const string& orientation = "BY_COL" );
00296 };
00297
00298
00299
00304 class ArxDS : public DS {
00305 protected:
00307 RV Rrv;
00309 vec H;
00311 vec U;
00313 vec rgr;
00315 datalink rgrlnk;
00317 mlnorm<chmat> model;
00319 bool opt_L_theta;
00321 int L_theta;
00322 int L_R;
00323 int dt_size;
00324 public:
00325 void getdata ( vec &dt ) {
00326 dt = H;
00327 }
00328
00329 void getdata ( vec &dt, const ivec &indices ) {
00330 dt = H ( indices );
00331 }
00332
00333 void write ( vec &ut ) {
00334 U = ut;
00335 }
00336
00337 void write ( vec &ut, const ivec &indices ) {
00338 bdm_assert_debug ( ut.length() == indices.length(), "ArxDS" );
00339 set_subvector ( U, indices, ut );
00340 }
00341
00342 void step();
00343
00345 ArxDS ( ) {};
00347 void set_parameters ( const mat &Th0, const vec mu0, const chmat &sqR0 ) {
00348 model.set_parameters ( Th0, mu0, sqR0 );
00349 };
00351 void set_drv ( const RV &yrv, const RV &urv, const RV &rrv ) {
00352 Rrv = rrv;
00353 Urv = urv;
00354 dt_size = yrv._dsize() + urv._dsize();
00355
00356 RV drv = concat ( yrv, urv );
00357 Drv = drv;
00358 int td = rrv.mint();
00359 H.set_size ( drv._dsize() * ( -td + 1 ) );
00360 U.set_size ( Urv._dsize() );
00361 for ( int i = -1; i >= td; i-- ) {
00362 drv.t_plus ( -1 );
00363 Drv.add ( drv );
00364 }
00365 rgrlnk.set_connection ( rrv, Drv );
00366
00367 dtsize = Drv._dsize();
00368 utsize = Urv._dsize();
00369 }
00371 void set_options ( const string &s ) {
00372 opt_L_theta = ( s.find ( "L_theta" ) != string::npos );
00373 };
00374 virtual void log_add ( logger &L ) {
00375
00376 L_dt = L.add ( Drv ( 0, dt_size ), "" );
00377 L_ut = L.add ( Urv, "" );
00378
00379 const mat &A = model._A();
00380 const mat R = model._R();
00381 if ( opt_L_theta ) {
00382 L_theta = L.add ( RV ( "{th }", vec_1 ( A.rows() * A.cols() ) ), "t" );
00383 }
00384 if ( opt_L_theta ) {
00385 L_R = L.add ( RV ( "{R }", vec_1 ( R.rows() * R.cols() ) ), "r" );
00386 }
00387 }
00388 virtual void logit ( logger &L ) {
00389
00390 L.logit ( L_dt, H.left ( dt_size ) );
00391 L.logit ( L_ut, U );
00392
00393 const mat &A = model._A();
00394 const mat R = model._R();
00395 if ( opt_L_theta ) {
00396 L.logit ( L_theta, vec ( A._data(), A.rows() *A.cols() ) );
00397 };
00398 if ( opt_L_theta ) {
00399 L.logit ( L_R, vec ( R._data(), R.rows() *R.rows() ) );
00400 };
00401 }
00402
00403
00435 void from_setting ( const Setting &set );
00436
00437
00438 };
00439
00440 UIREGISTER ( ArxDS );
00441 SHAREDPTR ( ArxDS );
00442
00443 class stateDS : public DS {
00444 private:
00446 shared_ptr<mpdf> IM;
00447
00449 shared_ptr<mpdf> OM;
00450
00451 protected:
00453 vec dt;
00455 vec xt;
00457 vec ut;
00459 int L_xt;
00460
00461 public:
00462 void getdata ( vec &dt0 ) {
00463 dt0 = dt;
00464 }
00465
00466 void getdata ( vec &dt0, const ivec &indices ) {
00467 dt0 = dt ( indices );
00468 }
00469
00470 stateDS ( const shared_ptr<mpdf> &IM0, const shared_ptr<mpdf> &OM0, int usize ) : IM ( IM0 ), OM ( OM0 ),
00471 dt ( OM0->dimension() ), xt ( IM0->dimension() ),
00472 ut ( usize ), L_xt ( 0 ) { }
00473
00474 stateDS() : L_xt ( 0 ) { }
00475
00476 virtual void step() {
00477 xt = IM->samplecond ( concat ( xt, ut ) );
00478 dt = OM->samplecond ( concat ( xt, ut ) );
00479 }
00480
00481 virtual void log_add ( logger &L ) {
00482 DS::log_add ( L );
00483 L_xt = L.add ( IM->_rv(), "true" );
00484 }
00485 virtual void logit ( logger &L ) {
00486 DS::logit ( L );
00487 L.logit ( L_xt, xt );
00488 }
00489
00519 void from_setting ( const Setting &set );
00520
00521
00522
00523 };
00524
00525 UIREGISTER ( stateDS );
00526 SHAREDPTR ( stateDS );
00527
00528 };
00529
#endif // DATASOURCE_H