root/library/bdm/base/bdmbase.cpp @ 896

Revision 896, 15.4 kB (checked in by mido, 14 years ago)

cleanup of MemDS and its descendants
bdmtoolbox/CMakeLists.txt slightly changed to avoid unnecessary MEX condition
"indeces" replaced by "indices"

  • Property svn:eol-style set to native
RevLine 
[262]1
[384]2#include "bdmbase.h"
[2]3
[262]4//! Space of basic BDM structures
[254]5namespace bdm {
[271]6
[600]7const int RV::BUFFER_STEP = 1;
[2]8
[600]9Array<string> RV::NAMES ( RV::BUFFER_STEP );
[211]10
[600]11ivec RV::SIZES ( RV::BUFFER_STEP );
12
13RV::str2int_map RV::MAP;
14
[477]15void RV::clear_all() {
[600]16        MAP.clear();
17        SIZES.clear();
18        NAMES = Array<string> ( BUFFER_STEP );
[436]19}
[477]20
[737]21string RV::show_all() {
22        ostringstream os;
23        for ( str2int_map::const_iterator iter = MAP.begin(); iter != MAP.end(); iter++ ) {
24                os << "key: " << iter->first << " val: " << iter->second << endl;
25        }
26        return os.str();
[604]27};
28
[766]29int RV::assign_id( const string &name, int size ) {
[162]30        //Refer
[270]31        int id;
[600]32        str2int_map::const_iterator iter = MAP.find ( name );
[737]33        if ( iter == MAP.end() || name.length() == 0 ) { //add new RV
[600]34                id = MAP.size() + 1;
[280]35                //debug
[477]36                /*              {
37                                        cout << endl;
[600]38                                        str2int_map::const_iterator iter = MAP.begin();
39                                        for(str2int_map::const_iterator iter=MAP.begin(); iter!=MAP.end(); iter++){
[477]40                                                cout << "key: " << iter->first << " val: " << iter->second <<endl;
41                                        }
42                                }*/
43
[600]44                MAP.insert ( make_pair ( name, id ) ); //add new rv
45                if ( id >= NAMES.length() ) {
46                        NAMES.set_length ( id + BUFFER_STEP, true );
47                        SIZES.set_length ( id + BUFFER_STEP, true );
[270]48                }
[600]49                NAMES ( id ) = name;
50                SIZES ( id ) = size;
[737]51                bdm_assert ( size > 0, "RV " + name + " does not exists. Default size (-1) can not be assigned " );
[477]52        } else {
[270]53                id = iter->second;
[737]54                if ( size > 0 && name.length() > 0 ) {
55                        bdm_assert ( SIZES ( id ) == size, "RV " + name + " of size " + num2str ( SIZES ( id ) ) + " exists, requested size " + num2str ( size ) + "can not be assigned" );
[624]56                }
[270]57        }
58        return id;
[2]59};
60
[477]61int RV::countsize() const {
62        int tmp = 0;
63        for ( int i = 0; i < len; i++ ) {
[600]64                tmp += SIZES ( ids ( i ) );
[477]65        }
66        return tmp;
67}
[2]68
[271]69ivec RV::cumsizes() const {
70        ivec szs ( len );
[477]71        int tmp = 0;
72        for ( int i = 0; i < len; i++ ) {
[600]73                tmp += SIZES ( ids ( i ) );
[271]74                szs ( i ) = tmp;
75        }
76        return szs;
77}
78
[477]79void RV::init ( const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times ) {
[270]80        len = in_names.length();
[620]81        bdm_assert ( in_names.length() == in_times.length(), "check \"times\" " );
82        bdm_assert ( in_names.length() == in_sizes.length(), "check \"sizes\" " );
[271]83
[270]84        times.set_length ( len );
85        ids.set_length ( len );
86        int id;
[477]87        for ( int i = 0; i < len; i++ ) {
[766]88                id = assign_id ( in_names ( i ), in_sizes ( i ) );
[270]89                ids ( i ) = id;
90        }
91        times = in_times;
92        dsize = countsize();
[162]93}
[8]94
[270]95RV::RV ( string name, int sz, int tm ) {
[477]96        Array<string> A ( 1 );
97        A ( 0 ) = name;
98        init ( A, vec_1 ( sz ), vec_1 ( tm ) );
[162]99}
100
[145]101bool RV::add ( const RV &rv2 ) {
[477]102        if ( rv2.len > 0 ) { //rv2 is nonempty
[162]103                ivec ind = rv2.findself ( *this ); //should be -1 all the time
[477]104                ivec index = itpp::find ( ind == -1 );
[162]105
106                if ( index.length() < rv2.len ) { //conflict
107                        ids = concat ( ids, rv2.ids ( index ) );
108                        times = concat ( times, rv2.times ( index ) );
[477]109                } else {
[162]110                        ids = concat ( ids, rv2.ids );
111                        times = concat ( times, rv2.times );
112                }
113                len = ids.length();
[270]114                dsize = countsize();
[477]115                return ( index.length() == rv2.len ); //conflict or not
116        } else { //rv2 is empty
[270]117                return true; // no conflict
[145]118        }
[32]119};
120
[270]121RV RV::subselect ( const ivec &ind ) const {
[162]122        RV ret;
[271]123        ret.ids = ids ( ind );
[477]124        ret.times = times ( ind );
[270]125        ret.len = ind.length();
[477]126        ret.dsize = ret.countsize();
[162]127        return ret;
[5]128}
129
[477]130RV RV::operator() ( int di1, int di2 ) const {
131        ivec sz = cumsizes();
132        int i1 = 0;
133        while ( sz ( i1 ) < di1 ) i1++;
134        int i2 = i1;
135        while ( sz ( i2 ) < di2 ) i2++;
136        return subselect ( linspace ( i1, i2 ) );
[422]137}
138
[604]139void RV::t_plus ( int delta ) {
[477]140        times += delta;
141}
[8]142
[145]143bool RV::equal ( const RV &rv2 ) const {
[270]144        return ( ids == rv2.ids ) && ( times == rv2.times );
[102]145}
[5]146
[766]147shared_ptr<pdf> epdf::condition ( const RV &rv ) const NOT_IMPLEMENTED( shared_ptr<pdf>() );
[504]148
149
[766]150shared_ptr<epdf> epdf::marginal ( const RV &rv ) const NOT_IMPLEMENTED( shared_ptr<epdf>() );
151
[713]152mat epdf::sample_mat ( int N ) const {
[270]153        mat X = zeros ( dim, N );
[477]154        for ( int i = 0; i < N; i++ ) X.set_col ( i, this->sample() );
[102]155        return X;
[461]156}
[102]157
[713]158vec epdf::evallog_mat ( const mat &Val ) const {
[600]159        vec x ( Val.cols() );
160        for ( int i = 0; i < Val.cols(); i++ ) {
161                x ( i ) = evallog ( Val.get_col ( i ) );
[502]162        }
[102]163
[502]164        return x;
165}
166
[713]167vec epdf::evallog_mat ( const Array<vec> &Avec ) const {
[600]168        vec x ( Avec.size() );
169        for ( int i = 0; i < Avec.size(); i++ ) {
170                x ( i ) = evallog ( Avec ( i ) );
[502]171        }
172
173        return x;
174}
175
[713]176mat pdf::samplecond_mat ( const vec &cond, int N ) {
[600]177        mat M ( dimension(), N );
[532]178        for ( int i = 0; i < N; i++ ) {
[600]179                M.set_col ( i, samplecond ( cond ) );
[532]180        }
181
182        return M;
183}
184
[693]185void pdf::from_setting ( const Setting &set ) {
[746]186        root::from_setting( set );
[527]187        shared_ptr<RV> r = UI::build<RV> ( set, "rv", UI::optional );
[477]188        if ( r ) {
189                set_rv ( *r );
190        }
[461]191
[527]192        r = UI::build<RV> ( set, "rvc", UI::optional );
[477]193        if ( r ) {
194                set_rvc ( *r );
195        }
[461]196}
197
[746]198void pdf::to_setting ( Setting &set ) const {   
199        root::to_setting( set );
200        UI::save( &rv, set, "rv" );
201        UI::save( &rvc, set, "rvc" );
202}
203
[600]204void datalink::set_connection ( const RV &rv, const RV &rv_up ) {
[545]205        downsize = rv._dsize();
[600]206        upsize = rv_up._dsize();
207        v2v_up = rv.dataind ( rv_up );
208        bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
[545]209}
210
[600]211void datalink::set_connection ( int ds, int us, const ivec &upind ) {
[545]212        downsize = ds;
213        upsize = us;
214        v2v_up = upind;
[600]215        bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
[598]216}
[545]217
[600]218void datalink_part::set_connection ( const RV &rv, const RV &rv_up ) {
219        rv.dataind ( rv_up, v2v_down, v2v_up );
220        downsize = v2v_down.length();
221        upsize = v2v_up.length();
[545]222}
223
[600]224void datalink_m2e::set_connection ( const RV &rv, const RV &rvc, const RV &rv_up ) {
225        datalink::set_connection ( rv, rv_up );
[545]226        condsize = rvc._dsize();
227        //establish v2c connection
[600]228        rvc.dataind ( rv_up, v2c_lo, v2c_up );
[545]229}
230
[600]231vec datalink_m2e::get_cond ( const vec &val_up ) {
232        vec tmp ( condsize );
233        set_subvector ( tmp, v2c_lo, val_up ( v2c_up ) );
[545]234        return tmp;
235}
236
[600]237void datalink_m2e::pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
238        bdm_assert_debug ( downsize == val.length(), "Wrong val" );
239        bdm_assert_debug ( upsize == val_up.length(), "Wrong val_up" );
240        set_subvector ( val_up, v2v_up, val );
241        set_subvector ( val_up, v2c_up, cond );
[545]242}
243
[102]244std::ostream &operator<< ( std::ostream &os, const RV &rv ) {
[270]245        int id;
[477]246        for ( int i = 0; i < rv.len ; i++ ) {
247                id = rv.ids ( i );
[600]248                os << id << "(" << RV::SIZES ( id ) << ")" <<  // id(size)=
249                "=" << RV::NAMES ( id )  << "_{"  << rv.times ( i ) << "}; "; //name_{time}
[5]250        }
251        return os;
252}
253
[738]254RV RV::expand_delayes() const {
255        RV rvt = this->remove_time(); //rv at t=0
256        RV tmp = rvt;
257        int td = mint();
258        for ( int i = -1; i >= td; i-- ) {
259                rvt.t_plus ( -1 );
260                tmp.add ( rvt ); //shift u1
261        }
262        return tmp;
263}
264
[145]265str RV::tostr() const {
[270]266        ivec idlist ( dsize );
267        ivec tmlist ( dsize );
[19]268        int i;
269        int pos = 0;
[477]270        for ( i = 0; i < len; i++ ) {
[280]271                idlist.set_subvector ( pos, pos + size ( i ) - 1, ids ( i ) );
272                tmlist.set_subvector ( pos, pos + size ( i ) - 1, times ( i ) );
273                pos += size ( i );
[19]274        }
[145]275        return str ( idlist, tmlist );
[19]276}
[32]277
[270]278ivec RV::dataind ( const RV &rv2 ) const {
[145]279        ivec res ( 0 );
[477]280        if ( rv2._dsize() > 0 ) {
[165]281                str str2 = rv2.tostr();
282                ivec part;
283                int i;
[477]284                for ( i = 0; i < len; i++ ) {
[165]285                        part = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
286                        res = concat ( res, part );
287                }
[145]288        }
[565]289
[842]290        //bdm_assert_debug ( res.length() == dsize, "this rv is not fully present in crv!" );
[145]291        return res;
[165]292
[145]293}
294
[270]295void RV::dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const {
[182]296        //clean results
[270]297        selfi.set_size ( 0 );
298        rv2i.set_size ( 0 );
299
[182]300        // just in case any rv is empty
[477]301        if ( ( len == 0 ) || ( rv2.length() == 0 ) ) {
302                return;
303        }
[270]304
[182]305        //find comon rv
[477]306        ivec cids = itpp::find ( this->findself ( rv2 ) >= 0 );
[270]307
308        // index of
[477]309        if ( cids.length() > 0 ) {
[182]310                str str1 = tostr();
[270]311                str str2 = rv2.tostr();
312
[182]313                ivec part1;
314                ivec part2;
[477]315                int i, j;
[182]316                // find common rv in strs
[477]317                for ( j = 0; j < cids.length(); j++ ) {
[270]318                        i = cids ( j );
[182]319                        part1 = itpp::find ( ( str1.ids == ids ( i ) ) & ( str1.times == times ( i ) ) );
320                        part2 = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
321                        selfi = concat ( selfi, part1 );
322                        rv2i = concat ( rv2i, part2 );
323                }
324        }
[565]325        bdm_assert_debug ( selfi.length() == rv2i.length(), "this should not happen!" );
[182]326}
327
[176]328RV RV::subt ( const RV &rv2 ) const {
[145]329        ivec res = this->findself ( rv2 ); // nonzeros
[178]330        ivec valid;
[477]331        if ( dsize > 0 ) {
332                valid = itpp::find ( res == -1 );    //-1 => value not found => it remains
333        }
[145]334        return ( *this ) ( valid ); //keep those that were not found in rv2
335}
336
[738]337std::string RV::scalarname ( int scalat ) const {
338        bdm_assert ( scalat < dsize, "Wrong input index" );
339        int id = 0;
340        int scalid = 0;
341        while ( scalid + SIZES ( ids ( id ) ) <= scalat )  {
342                scalid += SIZES ( ids ( id ) );
343                id++;
344        };
345        //now id is the id of variable of interest
346        if ( size ( id ) == 1 )
347                return  NAMES ( ids ( id ) );
348        else
349                return  NAMES ( ids ( id ) ) + "_" + num2str ( scalat - scalid );
350
351}
352
[145]353ivec RV::findself ( const RV &rv2 ) const {
354        int i, j;
355        ivec tmp = -ones_i ( len );
[477]356        for ( i = 0; i < len; i++ ) {
357                for ( j = 0; j < rv2.length(); j++ ) {
[145]358                        if ( ( ids ( i ) == rv2.ids ( j ) ) & ( times ( i ) == rv2.times ( j ) ) ) {
359                                tmp ( i ) = j;
360                                break;
361                        }
362                }
363        }
364        return tmp;
365}
366
[598]367ivec RV::findself_ids ( const RV &rv2 ) const {
368        int i, j;
369        ivec tmp = -ones_i ( len );
370        for ( i = 0; i < len; i++ ) {
371                for ( j = 0; j < rv2.length(); j++ ) {
372                        if ( ( ids ( i ) == rv2.ids ( j ) ) ) {
373                                tmp ( i ) = j;
374                                break;
375                        }
376                }
377        }
378        return tmp;
379}
380
[477]381void RV::from_setting ( const Setting &set ) {
[357]382        Array<string> A;
[479]383        UI::get ( A, set, "names" );
[477]384
[357]385        ivec szs;
[477]386        if ( !UI::get ( szs, set, "sizes" ) )
387                szs = ones_i ( A.length() );
388
[357]389        ivec tms;
[477]390        if ( !UI::get ( tms, set, "times" ) )
391                tms = zeros_i ( A.length() );
392
393        init ( A, szs, tms );
[357]394}
395
[733]396void RV::to_setting ( Setting &set ) const {
[737]397        Array<string> names ( len );
398        ivec sizes ( len );
399        for ( int i = 0; i < len; i++ ) {
400                names ( i ) = name ( i );
401                sizes ( i ) = size ( i );
402        }
403        UI::save ( names, set, "names" );
404        UI::save ( sizes, set, "sizes" );
405        UI::save ( times, set, "times" );
[733]406}
407
[102]408RV concat ( const RV &rv1, const RV &rv2 ) {
[32]409        RV pom = rv1;
[102]410        pom.add ( rv2 );
[32]411        return pom;
412}
[170]413
[693]414RV get_composite_rv ( const Array<shared_ptr<pdf> > &pdfs,
[600]415                      bool checkoverlap ) {
[175]416        RV rv; //empty rv
417        bool rvaddok;
[693]418        for ( int i = 0; i < pdfs.length(); i++ ) {
[770]419                bdm_assert( pdfs(i)->isnamed(), "Can not extract RV from pdf no. " + num2str(i));
[693]420                rvaddok = rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
421                // If rvaddok==false, pdfs overlap => assert error.
422                bdm_assert_debug ( rvaddok || !checkoverlap, "mprod::mprod() input pdfs overlap in rv!" );
[507]423        }
424
[175]425        return rv;
426}
427
[889]428int logger::add_vector ( const RV &rv, const string &prefix, const string &name ) {
[738]429        int id;
430        if ( rv._dsize() > 0 ) {
431                id = entries.length();
432                names = concat ( names, prefix ); // diff
[889]433                if( name.length() > 0 )
434                {
435                        concat( names, separator );
436                        concat( names, name );
437                }
[738]438                entries.set_length ( id + 1, true );
439                entries ( id ) = rv;
440        } else {
441                id = -1;
442        }
443        return id; // identifier of the last entry
444}
445
446int logger::add_setting ( const string &prefix ) {
447        Setting &root = setting_conf.getRoot();
448        int id = root.getLength(); //root must be group!!
449        if ( prefix.length() > 0 ) {
450                settings.set_length ( id + 1, true );
451                settings ( id ) = &root.add ( prefix, Setting::TypeList );
452        } else {
453                id = -1;
454        }
455        return id;
456}
457
458void epdf::log_register ( logger &L, const string &prefix ) {
459        RV r;
460        if ( isnamed() ) {
461                r = _rv();
462        } else {
463                r = RV ( "", dimension() );
464        };
465        root::log_register ( L, prefix );
466
467        // log full data
[863]468        if ( log_level[logfull] ) {
[738]469                logrec->ids.set_size ( 1 );
470                logrec->ids ( 0 ) = logrec->L.add_setting ( prefix );
471        } else {
472                // log only
473                logrec->ids.set_size ( 3 );
[863]474                if ( log_level[logmean] ) {
[889]475                        logrec->ids ( 0 ) = logrec->L.add_vector ( r, prefix, "mean" );
[738]476                }
[871]477                if ( log_level[loglbound]  ) {
[889]478                        logrec->ids ( 1 ) = logrec->L.add_vector ( r, prefix, "lb" );
[863]479                }       
[871]480                if ( log_level[logubound]  ) {
[889]481                        logrec->ids ( 2 ) = logrec->L.add_vector ( r, prefix, "ub" );
[863]482                }
483       
[738]484        }
485}
486
487void epdf::log_write() const {
[863]488        if ( log_level[logfull] ) {
[802]489                UI::save(this,  logrec->L.log_to_setting ( logrec->ids ( 0 ) ) );
[738]490        } else {
[863]491                if ( log_level[logmean] ) {
[738]492                        logrec->L.log_vector ( logrec->ids ( 0 ), mean() );
[863]493                }
[871]494                if ( log_level[loglbound] || log_level[logubound] ) {
[850]495                                vec lb;
496                                vec ub;
497                                qbounds ( lb, ub );
[871]498                                if (log_level[loglbound])
[863]499                                        logrec->L.log_vector ( logrec->ids ( 1 ), lb );
[871]500                                if (log_level[logubound])
[863]501                                        logrec->L.log_vector ( logrec->ids ( 2 ), ub );
[850]502                        }
[738]503                }
504        }
505
[863]506
[738]507void datalink_buffered::set_connection ( const RV &rv, const RV &rv_up ) {
508        // create link between up and down
[744]509        datalink_part::set_connection ( rv, rv_up); // only non-delayed version
[738]510
[881]511        RV needed_from_hist = rv.subt(rv_up); //rv_up already copied by v2v
512       
[738]513        // we can store only what we get in rv_up - everything else is removed
[881]514        ivec valid_ids = needed_from_hist.findself_ids ( rv_up ); // return on which position the required id is in rv_up
515        RV rv_hist = needed_from_hist.subselect ( find ( valid_ids >= 0 ) ); // select only rvs that are in rv_up, ie ind>0
[738]516        RV rv_hist0 = rv_hist.remove_time(); // these RVs will form history at time =0
517        // now we need to know what is needed from Up
518        rv_hist = rv_hist.expand_delayes(); // full regressor - including time 0
519        Hrv = rv_hist.subt ( rv_hist0 );   // remove time 0
520        history = zeros ( Hrv._dsize() );
521
522        // decide if we need to copy val to history
523        if ( Hrv._dsize() > 0 ) {
[896]524                v2h_up = rv_hist0.dataind ( rv_up ); // indices of elements of rv_up to be copied
[738]525        } // else v2h_up is empty
526
527        Hrv.dataind ( rv, h2v_hist, h2v_down );
528
529        downsize = v2v_down.length() + h2v_down.length();
530        upsize = v2v_up.length();
531}
532
533void datalink_buffered::set_history ( const RV& rv1, const vec &hist0 ) {
534        bdm_assert ( rv1._dsize() == hist0.length(), "hist is not compatible with given rv1" );
535        ivec ind_H;
536        ivec ind_h0;
[896]537        Hrv.dataind ( rv1, ind_H, ind_h0 ); // find indices of rv in
[738]538        set_subvector ( history, ind_H, hist0 ( ind_h0 ) ); // copy given hist to appropriate places
539}
540
541void DS::log_register ( logger &L,  const string &prefix ) {
[854]542        bdm_assert ( dtsize == Drv._dsize(), "invalid DS: dtsize (" + num2str ( dtsize ) + ") different from Drv " + num2str ( Drv._dsize() ) );
543        //bdm_assert ( utsize == Urv._dsize(), "invalid DS: utsize (" + num2str ( utsize ) + ") different from Urv " + num2str ( Urv._dsize() ) );
[738]544
545        root::log_register ( L, prefix );
546        //we know that
[850]547        if ( log_level.any() ) {
[854]548                logrec->ids.set_size ( 1 );
549                logrec->ids ( 0 ) = logrec->L.add_vector ( Drv, prefix );
550        //      logrec->ids ( 1 ) = logrec->L.add_vector ( Urv, prefix );
[738]551        }
552}
553
554void DS::log_write ( ) const {
[850]555        if ( log_level.any() ) {
[854]556                vec tmp ( Drv._dsize());
[738]557                getdata ( tmp );
558                // d is first in getdata
[854]559                logrec->L.log_vector ( logrec->ids ( 0 ), tmp );
[738]560        }
561}
562
563void BM::log_register ( logger &L, const string &prefix ) {
564        root::log_register ( L, prefix );
565
[850]566        if ( log_level.any() ) {
[738]567                logrec->ids.set_size ( 1 );
[889]568                logrec->ids ( 0) = L.add_vector ( RV ( "", 1 ), prefix, "ll" );
[738]569        }
[863]570       
[870]571        if (log_level[logbounds]){
[871]572                prior().log_level[epdf::loglbound]=true;
573                prior().log_level[epdf::logubound]=true;
[863]574        }
[870]575        if (log_level[logfull]){
[863]576                prior().log_level[epdf::logfull]=true;
577        }
[889]578        const_cast<epdf&> ( posterior() ).log_register ( L, prefix + L.separator + "apost" );
[738]579}
580
581void BM::log_write ( ) const {
582        posterior().log_write();
[850]583        if ( log_level.any() ) {
[863]584                logrec->L.logit ( logrec->ids ( 0 ), ll );
[738]585        }
586}
587
[690]588void BM::bayes_batch ( const mat &Data, const vec &cond ) {
[477]589        for ( int t = 0; t < Data.cols(); t++ ) {
[690]590                bayes ( Data.get_col ( t ), cond );
[477]591        }
[182]592}
[738]593
[690]594void BM::bayes_batch ( const mat &Data, const mat &Cond ) {
595        for ( int t = 0; t < Data.cols(); t++ ) {
[737]596                bayes ( Data.get_col ( t ), Cond.get_col ( t ) );
[690]597        }
[262]598}
[738]599
[690]600}
Note: See TracBrowser for help on using the browser.