00001
00013 #ifndef DATASOURCE_H
00014 #define DATASOURCE_H
00015
00016
00017 #include "../base/bdmbase.h"
00018 #include "../stat/exp_family.h"
00019 #include "../base/user_info.h"
00020
00021 namespace bdm {
00029 class MemDS : public DS {
00030 protected:
00032 mat Data;
00034 int time;
00036 ivec rowid;
00037
00038 public:
00039 int max_length() {return Data.cols();}
00040 void getdata ( vec &dt );
00041 void getdata ( vec &dt, const ivec &indeces );
00042 void set_rvs ( RV &drv, RV &urv );
00043
00044 void write ( vec &ut ) {
00045 bdm_error ( "MemDS::write is not supported" );
00046 }
00047
00048 void write ( vec &ut, ivec &indices ) {
00049 bdm_error ( "MemDS::write is not supported" );
00050 }
00051
00052 void step();
00054 MemDS () {};
00055 MemDS ( mat &Dat, ivec &rowid0);
00080 void from_setting(const Setting &set){
00081 UI::get(Data, set, "Data", UI::compulsory);
00082 if(!UI::get(time, set,"time", UI::optional)) {time =0;}
00083 if(!UI::get(rowid, set, "rowid",UI::optional)) {rowid =linspace(0,Data.rows()-1);}
00084 shared_ptr<RV> r=UI::build<RV>(set,"drv",UI::optional);
00085 if (!r) {r=new RV();
00086 for (int i=0; i<rowid.length(); i++){ r->add(RV("ch"+num2str(rowid(i)), 1, 0));}
00087 }
00088 set_drv(*r,RV());
00089 dtsize=r->_dsize();
00090 utsize=0;
00091 }
00092 };
00093 UIREGISTER(MemDS);
00094
00100 class EpdfDS: public DS {
00101 protected:
00103 shared_ptr<epdf> iepdf;
00105 vec dt;
00106 public:
00107 void step() {
00108 dt=iepdf->sample();
00109 }
00110 void getdata ( vec &dt_out ) {
00111 dt_out = dt;
00112 }
00113 void getdata ( vec &dt_out, const ivec &ids ) {
00114 dt_out = dt ( ids );
00115 }
00116 const RV& _drv() {
00117 return iepdf->_rv();
00118 }
00119
00127 void from_setting ( const Setting &set ) {
00128 iepdf=UI::build<epdf> ( set,"epdf",UI::compulsory );
00129 bdm_assert(iepdf->isnamed(), "Input epdf must be named, check if RV is given correctly");
00130 dt = zeros(iepdf->dimension());
00131 dtsize=dt.length();
00132 set_drv(iepdf->_rv(),RV());
00133 utsize =0;
00134 }
00135 void validate() {
00136 dt = iepdf->sample();
00137 }
00138 };
00139 UIREGISTER ( EpdfDS );
00140
00144 class MpdfDS :public DS {
00145 protected:
00147 shared_ptr<mpdf> impdf;
00149 vec yt;
00151 vec ut;
00153 datalink_buffered ut2rgr;
00155 datalink_buffered yt2rgr;
00157 vec rgr;
00158
00159 public:
00160 void step() {
00161 yt2rgr.step(yt);
00162 ut2rgr.filldown ( ut,rgr );
00163 yt2rgr.filldown ( yt,rgr );
00164 yt=impdf->samplecond ( rgr );
00165 ut2rgr.step(ut);
00166 }
00167 void getdata ( vec &dt_out ) {
00168 bdm_assert_debug(dt_out.length()>=utsize+ytsize,"Short output vector");
00169 dt_out.set_subvector(0, yt);
00170 dt_out.set_subvector(ytsize, ut);
00171 }
00172 void write(const vec &ut0){ut=ut0;}
00173
00185 void from_setting ( const Setting &set ) {
00186 impdf=UI::build<mpdf> ( set,"mpdf",UI::compulsory );
00187
00188 Yrv = impdf->_rv();
00189
00190 RV rgrv0=impdf->_rvc().remove_time();
00191
00192 Urv=rgrv0.subt(Yrv);
00193 set_drv(Yrv, Urv);
00194
00195 ut2rgr.set_connection(impdf->_rvc(), Urv);
00196 yt2rgr.set_connection(impdf->_rvc(), Yrv);
00197
00198
00199 shared_ptr<RV> rv_ini=UI::build<RV>(set,"init_rv",UI::optional);
00200 if(rv_ini){
00201 vec val;
00202 UI::get(val, set, "init_values", UI::optional);
00203 if (val.length()!=rv_ini->_dsize()){
00204 bdm_error("init_rv and init_values fields have incompatible sizes");
00205 } else {
00206 ut2rgr.set_history(*rv_ini, val);
00207 yt2rgr.set_history(*rv_ini, val);
00208 }
00209 }
00210
00211 yt = zeros ( impdf->dimension() );
00212 rgr = zeros ( impdf->dimensionc() );
00213 ut = zeros(Urv._dsize());
00214
00215 ytsize=yt.length();
00216 utsize=ut.length();
00217 dtsize = ytsize+utsize;
00218 validate();
00219 }
00220 void validate() {
00221
00222 ut2rgr.filldown ( ut,rgr );
00223 yt2rgr.filldown ( yt,rgr );
00224 yt=impdf->samplecond ( rgr );
00225 }
00226 };
00227 UIREGISTER ( MpdfDS );
00228
00232 class FileDS: public MemDS {
00233
00234 public:
00235 void getdata ( vec &dt ) {
00236 dt = Data.get_col ( time );
00237 }
00238
00239 void getdata ( vec &dt, const ivec &indices ) {
00240 vec tmp = Data.get_col ( time );
00241 dt = tmp ( indices );
00242 }
00243
00245 int ndat() {
00246 return Data.cols();
00247 }
00249 void log_add ( logger &L ) {};
00251 void logit ( logger &L ) {};
00252 };
00253
00260 class ITppFileDS: public FileDS {
00261
00262 public:
00263 ITppFileDS ( const string &fname, const string &varname ) : FileDS() {
00264 it_file it ( fname );
00265 it << Name ( varname );
00266 it >> Data;
00267 time = 0;
00268
00269 };
00270
00271 ITppFileDS () : FileDS() {
00272 };
00273
00274 void from_setting ( const Setting &set );
00275
00276
00277
00278 };
00279
00280 UIREGISTER ( ITppFileDS );
00281 SHAREDPTR ( ITppFileDS );
00282
00290 class CsvFileDS: public FileDS {
00291
00292 public:
00294 CsvFileDS ( const string& fname, const string& orientation = "BY_COL" );
00295 };
00296
00297
00298
00303 class ArxDS : public DS {
00304 protected:
00306 RV Rrv;
00308 vec H;
00310 vec U;
00312 vec rgr;
00314 datalink rgrlnk;
00316 mlnorm<chmat> model;
00318 bool opt_L_theta;
00320 int L_theta;
00321 int L_R;
00322 int dt_size;
00323 public:
00324 void getdata ( vec &dt ) {
00325 dt = H;
00326 }
00327
00328 void getdata ( vec &dt, const ivec &indices ) {
00329 dt = H ( indices );
00330 }
00331
00332 void write ( vec &ut ) {
00333 U = ut;
00334 }
00335
00336 void write ( vec &ut, const ivec &indices ) {
00337 bdm_assert_debug ( ut.length() == indices.length(), "ArxDS" );
00338 set_subvector ( U, indices, ut );
00339 }
00340
00341 void step();
00342
00344 ArxDS ( ) {};
00346 void set_parameters ( const mat &Th0, const vec mu0, const chmat &sqR0 ) {
00347 model.set_parameters ( Th0, mu0, sqR0 );
00348 };
00350 void set_drv ( const RV &yrv, const RV &urv, const RV &rrv ) {
00351 Rrv = rrv;
00352 Urv = urv;
00353 dt_size = yrv._dsize() + urv._dsize();
00354
00355 RV drv = concat ( yrv, urv );
00356 Drv = drv;
00357 int td = rrv.mint();
00358 H.set_size ( drv._dsize() * ( -td + 1 ) );
00359 U.set_size ( Urv._dsize() );
00360 for ( int i = -1; i >= td; i-- ) {
00361 drv.t_plus ( -1 );
00362 Drv.add ( drv );
00363 }
00364 rgrlnk.set_connection ( rrv, Drv );
00365
00366 dtsize = Drv._dsize();
00367 utsize = Urv._dsize();
00368 }
00370 void set_options ( const string &s ) {
00371 opt_L_theta = ( s.find ( "L_theta" ) != string::npos );
00372 };
00373 virtual void log_add ( logger &L ) {
00374
00375 L_dt = L.add ( Drv ( 0, dt_size ), "" );
00376 L_ut = L.add ( Urv, "" );
00377
00378 const mat &A = model._A();
00379 const mat R = model._R();
00380 if ( opt_L_theta ) {
00381 L_theta = L.add ( RV ( "{th }", vec_1 ( A.rows() * A.cols() ) ), "t" );
00382 }
00383 if ( opt_L_theta ) {
00384 L_R = L.add ( RV ( "{R }", vec_1 ( R.rows() * R.cols() ) ), "r" );
00385 }
00386 }
00387 virtual void logit ( logger &L ) {
00388
00389 L.logit ( L_dt, H.left ( dt_size ) );
00390 L.logit ( L_ut, U );
00391
00392 const mat &A = model._A();
00393 const mat R = model._R();
00394 if ( opt_L_theta ) {
00395 L.logit ( L_theta, vec ( A._data(), A.rows() *A.cols() ) );
00396 };
00397 if ( opt_L_theta ) {
00398 L.logit ( L_R, vec ( R._data(), R.rows() *R.rows() ) );
00399 };
00400 }
00401
00402
00434 void from_setting ( const Setting &set );
00435
00436
00437 };
00438
00439 UIREGISTER ( ArxDS );
00440 SHAREDPTR ( ArxDS );
00441
00442 class stateDS : public DS {
00443 private:
00445 shared_ptr<mpdf> IM;
00446
00448 shared_ptr<mpdf> OM;
00449
00450 protected:
00452 vec dt;
00454 vec xt;
00456 vec ut;
00458 int L_xt;
00459
00460 public:
00461 void getdata ( vec &dt0 ) {
00462 dt0 = dt;
00463 }
00464
00465 void getdata ( vec &dt0, const ivec &indices ) {
00466 dt0 = dt ( indices );
00467 }
00468
00469 stateDS ( const shared_ptr<mpdf> &IM0, const shared_ptr<mpdf> &OM0, int usize ) : IM ( IM0 ), OM ( OM0 ),
00470 dt ( OM0->dimension() ), xt ( IM0->dimension() ),
00471 ut ( usize ), L_xt ( 0 ) { }
00472
00473 stateDS() : L_xt ( 0 ) { }
00474
00475 virtual void step() {
00476 xt = IM->samplecond ( concat ( xt, ut ) );
00477 dt = OM->samplecond ( concat ( xt, ut ) );
00478 }
00479
00480 virtual void log_add ( logger &L ) {
00481 DS::log_add ( L );
00482 L_xt = L.add ( IM->_rv(), "true" );
00483 }
00484 virtual void logit ( logger &L ) {
00485 DS::logit ( L );
00486 L.logit ( L_xt, xt );
00487 }
00488
00518 void from_setting ( const Setting &set );
00519
00520
00521
00522 };
00523
00524 UIREGISTER ( stateDS );
00525 SHAREDPTR ( stateDS );
00526
00527 };
00528
00529 #endif // DATASOURCE_H