00001
00013 #ifndef DATASOURCE_H
00014 #define DATASOURCE_H
00015
00016
00017 #include "../base/bdmbase.h"
00018 #include "../stat/exp_family.h"
00019 #include "../base/user_info.h"
00020
00021 namespace bdm {
00029 class MemDS : public DS {
00030 protected:
00032 mat Data;
00034 int time;
00036 ivec rowid;
00038 ivec delays;
00039
00040 public:
00041 void getdata ( vec &dt );
00042 void getdata ( vec &dt, const ivec &indeces );
00043 void set_rvs ( RV &drv, RV &urv );
00044 void write ( vec &ut ) {
00045 it_error ( "MemDS::write is not supported" );
00046 }
00047 void write ( vec &ut, ivec &indices ) {
00048 it_error ( "MemDS::write is not supported" );
00049 }
00050 void step();
00052 MemDS () {};
00053 MemDS ( mat &Dat, ivec &rowid0, ivec &delays0 );
00054 };
00055
00059 class FileDS: public MemDS {
00060
00061 public:
00062 void getdata ( vec &dt ) {
00063 it_assert_debug ( dt.length() == Data.rows(), "" );
00064 dt = Data.get_col ( time );
00065 };
00066 void getdata ( vec &dt, const ivec &indeces ) {
00067 it_assert_debug ( dt.length() == indeces.length(), "" );
00068 vec tmp ( indeces.length() );
00069 tmp = Data.get_col ( time );
00070 dt = tmp ( indeces );
00071 };
00073 int ndat() {
00074 return Data.cols();
00075 }
00077 void log_add ( logger &L ) {};
00079 void logit ( logger &L ) {};
00080 };
00081
00088 class ITppFileDS: public FileDS {
00089
00090 public:
00091 ITppFileDS ( const string &fname, const string &varname ) : FileDS() {
00092 it_file it ( fname );
00093 it << Name ( varname );
00094 it >> Data;
00095 time = 0;
00096
00097 };
00098
00099 ITppFileDS () : FileDS() {
00100 };
00101
00102 void from_setting ( const Setting &set );
00103
00104
00105
00106 };
00107
00108 UIREGISTER ( ITppFileDS );
00109 SHAREDPTR ( ITppFileDS );
00110
00118 class CsvFileDS: public FileDS {
00119
00120 public:
00122 CsvFileDS ( const string& fname, const string& orientation = "BY_COL" );
00123 };
00124
00125
00126
00131 class ArxDS : public DS {
00132 protected:
00134 RV Rrv;
00136 vec H;
00138 vec U;
00140 vec rgr;
00142 datalink rgrlnk;
00144 mlnorm<chmat> model;
00146 bool opt_L_theta;
00148 int L_theta;
00149 int L_R;
00150 int dt_size;
00151 public:
00152 void getdata ( vec &dt ) {
00153
00154 dt = H;
00155 };
00156 void getdata ( vec &dt, const ivec &indices ) {
00157 it_assert_debug ( dt.length() == indices.length(), "ArxDS" );
00158 dt = H ( indices );
00159 };
00160 void write ( vec &ut ) {
00161
00162 U = ut;
00163 };
00164 void write ( vec &ut, const ivec &indices ) {
00165 it_assert_debug ( ut.length() == indices.length(), "ArxDS" );
00166 set_subvector ( U, indices, ut );
00167 };
00168 void step();
00170 ArxDS ( ) {};
00172 void set_parameters ( const mat &Th0, const vec mu0, const chmat &sqR0 ) {
00173 model.set_parameters ( Th0, mu0, sqR0 );
00174 };
00176 void set_drv ( const RV &yrv, const RV &urv, const RV &rrv ) {
00177 Rrv = rrv;
00178 Urv = urv;
00179 dt_size = yrv._dsize() + urv._dsize();
00180
00181 RV drv = concat ( yrv, urv );
00182 Drv = drv;
00183 int td = rrv.mint();
00184 H.set_size ( drv._dsize() * ( -td + 1 ) );
00185 U.set_size ( Urv._dsize() );
00186 for ( int i = -1; i >= td; i-- ) {
00187 drv.t ( -1 );
00188 Drv.add ( drv );
00189 }
00190 rgrlnk.set_connection ( rrv, Drv );
00191
00192 dtsize = Drv._dsize();
00193 utsize = Urv._dsize();
00194 }
00196 void set_options ( const string &s ) {
00197 opt_L_theta = ( s.find ( "L_theta" ) != string::npos );
00198 };
00199 virtual void log_add ( logger &L ) {
00200
00201 L_dt = L.add ( Drv ( 0, dt_size ), "" );
00202 L_ut = L.add ( Urv, "" );
00203
00204 mat &A = model._A();
00205 mat R = model._R();
00206 if ( opt_L_theta ) {
00207 L_theta = L.add ( RV ( "{th }", vec_1 ( A.rows() * A.cols() ) ), "t" );
00208 }
00209 if ( opt_L_theta ) {
00210 L_R = L.add ( RV ( "{R }", vec_1 ( R.rows() * R.cols() ) ), "r" );
00211 }
00212 }
00213 virtual void logit ( logger &L ) {
00214
00215 L.logit ( L_dt, H.left ( dt_size ) );
00216 L.logit ( L_ut, U );
00217
00218 mat &A = model._A();
00219 mat R = model._R();
00220 if ( opt_L_theta ) {
00221 L.logit ( L_theta, vec ( A._data(), A.rows() *A.cols() ) );
00222 };
00223 if ( opt_L_theta ) {
00224 L.logit ( L_R, vec ( R._data(), R.rows() *R.rows() ) );
00225 };
00226 }
00227
00228
00260 void from_setting ( const Setting &set );
00261
00262
00263 };
00264
00265 UIREGISTER ( ArxDS );
00266 SHAREDPTR ( ArxDS );
00267
00268 class stateDS : public DS {
00269 private:
00271 shared_ptr<mpdf> IM;
00272
00274 shared_ptr<mpdf> OM;
00275
00276 protected:
00278 vec dt;
00280 vec xt;
00282 vec ut;
00284 int L_xt;
00285
00286 public:
00287 void getdata ( vec &dt0 ) {
00288 dt0 = dt;
00289 }
00290 void getdata ( vec &dt0, const ivec &indeces ) {
00291 dt0 = dt ( indeces );
00292 }
00293
00294 stateDS ( const shared_ptr<mpdf> &IM0, const shared_ptr<mpdf> &OM0, int usize ) : IM ( IM0 ), OM ( OM0 ),
00295 dt ( OM0->dimension() ), xt ( IM0->dimension() ),
00296 ut ( usize ), L_xt(0) { }
00297
00298 stateDS() : L_xt(0) { }
00299
00300 virtual void step() {
00301 xt = IM->samplecond ( concat ( xt, ut ) );
00302 dt = OM->samplecond ( concat ( xt, ut ) );
00303 }
00304
00305 virtual void log_add ( logger &L ) {
00306 DS::log_add ( L );
00307 L_xt = L.add ( IM->_rv(), "true" );
00308 }
00309 virtual void logit ( logger &L ) {
00310 DS::logit ( L );
00311 L.logit ( L_xt, xt );
00312 }
00313
00343 void from_setting ( const Setting &set );
00344
00345
00346
00347 };
00348
00349 UIREGISTER ( stateDS );
00350 SHAREDPTR ( stateDS );
00351
00352 };
00353
00354 #endif // DS_H