root/library/bdm/base/bdmbase.cpp @ 896

Revision 896, 15.4 kB (checked in by mido, 14 years ago)

cleanup of MemDS and its descendants
bdmtoolbox/CMakeLists.txt slightly changed to avoid unnecessary MEX condition
"indeces" replaced by "indices"

  • Property svn:eol-style set to native
Line 
1
2#include "bdmbase.h"
3
4//! Space of basic BDM structures
5namespace bdm {
6
7const int RV::BUFFER_STEP = 1;
8
9Array<string> RV::NAMES ( RV::BUFFER_STEP );
10
11ivec RV::SIZES ( RV::BUFFER_STEP );
12
13RV::str2int_map RV::MAP;
14
15void RV::clear_all() {
16        MAP.clear();
17        SIZES.clear();
18        NAMES = Array<string> ( BUFFER_STEP );
19}
20
21string RV::show_all() {
22        ostringstream os;
23        for ( str2int_map::const_iterator iter = MAP.begin(); iter != MAP.end(); iter++ ) {
24                os << "key: " << iter->first << " val: " << iter->second << endl;
25        }
26        return os.str();
27};
28
29int RV::assign_id( const string &name, int size ) {
30        //Refer
31        int id;
32        str2int_map::const_iterator iter = MAP.find ( name );
33        if ( iter == MAP.end() || name.length() == 0 ) { //add new RV
34                id = MAP.size() + 1;
35                //debug
36                /*              {
37                                        cout << endl;
38                                        str2int_map::const_iterator iter = MAP.begin();
39                                        for(str2int_map::const_iterator iter=MAP.begin(); iter!=MAP.end(); iter++){
40                                                cout << "key: " << iter->first << " val: " << iter->second <<endl;
41                                        }
42                                }*/
43
44                MAP.insert ( make_pair ( name, id ) ); //add new rv
45                if ( id >= NAMES.length() ) {
46                        NAMES.set_length ( id + BUFFER_STEP, true );
47                        SIZES.set_length ( id + BUFFER_STEP, true );
48                }
49                NAMES ( id ) = name;
50                SIZES ( id ) = size;
51                bdm_assert ( size > 0, "RV " + name + " does not exists. Default size (-1) can not be assigned " );
52        } else {
53                id = iter->second;
54                if ( size > 0 && name.length() > 0 ) {
55                        bdm_assert ( SIZES ( id ) == size, "RV " + name + " of size " + num2str ( SIZES ( id ) ) + " exists, requested size " + num2str ( size ) + "can not be assigned" );
56                }
57        }
58        return id;
59};
60
61int RV::countsize() const {
62        int tmp = 0;
63        for ( int i = 0; i < len; i++ ) {
64                tmp += SIZES ( ids ( i ) );
65        }
66        return tmp;
67}
68
69ivec RV::cumsizes() const {
70        ivec szs ( len );
71        int tmp = 0;
72        for ( int i = 0; i < len; i++ ) {
73                tmp += SIZES ( ids ( i ) );
74                szs ( i ) = tmp;
75        }
76        return szs;
77}
78
79void RV::init ( const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times ) {
80        len = in_names.length();
81        bdm_assert ( in_names.length() == in_times.length(), "check \"times\" " );
82        bdm_assert ( in_names.length() == in_sizes.length(), "check \"sizes\" " );
83
84        times.set_length ( len );
85        ids.set_length ( len );
86        int id;
87        for ( int i = 0; i < len; i++ ) {
88                id = assign_id ( in_names ( i ), in_sizes ( i ) );
89                ids ( i ) = id;
90        }
91        times = in_times;
92        dsize = countsize();
93}
94
95RV::RV ( string name, int sz, int tm ) {
96        Array<string> A ( 1 );
97        A ( 0 ) = name;
98        init ( A, vec_1 ( sz ), vec_1 ( tm ) );
99}
100
101bool RV::add ( const RV &rv2 ) {
102        if ( rv2.len > 0 ) { //rv2 is nonempty
103                ivec ind = rv2.findself ( *this ); //should be -1 all the time
104                ivec index = itpp::find ( ind == -1 );
105
106                if ( index.length() < rv2.len ) { //conflict
107                        ids = concat ( ids, rv2.ids ( index ) );
108                        times = concat ( times, rv2.times ( index ) );
109                } else {
110                        ids = concat ( ids, rv2.ids );
111                        times = concat ( times, rv2.times );
112                }
113                len = ids.length();
114                dsize = countsize();
115                return ( index.length() == rv2.len ); //conflict or not
116        } else { //rv2 is empty
117                return true; // no conflict
118        }
119};
120
121RV RV::subselect ( const ivec &ind ) const {
122        RV ret;
123        ret.ids = ids ( ind );
124        ret.times = times ( ind );
125        ret.len = ind.length();
126        ret.dsize = ret.countsize();
127        return ret;
128}
129
130RV RV::operator() ( int di1, int di2 ) const {
131        ivec sz = cumsizes();
132        int i1 = 0;
133        while ( sz ( i1 ) < di1 ) i1++;
134        int i2 = i1;
135        while ( sz ( i2 ) < di2 ) i2++;
136        return subselect ( linspace ( i1, i2 ) );
137}
138
139void RV::t_plus ( int delta ) {
140        times += delta;
141}
142
143bool RV::equal ( const RV &rv2 ) const {
144        return ( ids == rv2.ids ) && ( times == rv2.times );
145}
146
147shared_ptr<pdf> epdf::condition ( const RV &rv ) const NOT_IMPLEMENTED( shared_ptr<pdf>() );
148
149
150shared_ptr<epdf> epdf::marginal ( const RV &rv ) const NOT_IMPLEMENTED( shared_ptr<epdf>() );
151
152mat epdf::sample_mat ( int N ) const {
153        mat X = zeros ( dim, N );
154        for ( int i = 0; i < N; i++ ) X.set_col ( i, this->sample() );
155        return X;
156}
157
158vec epdf::evallog_mat ( const mat &Val ) const {
159        vec x ( Val.cols() );
160        for ( int i = 0; i < Val.cols(); i++ ) {
161                x ( i ) = evallog ( Val.get_col ( i ) );
162        }
163
164        return x;
165}
166
167vec epdf::evallog_mat ( const Array<vec> &Avec ) const {
168        vec x ( Avec.size() );
169        for ( int i = 0; i < Avec.size(); i++ ) {
170                x ( i ) = evallog ( Avec ( i ) );
171        }
172
173        return x;
174}
175
176mat pdf::samplecond_mat ( const vec &cond, int N ) {
177        mat M ( dimension(), N );
178        for ( int i = 0; i < N; i++ ) {
179                M.set_col ( i, samplecond ( cond ) );
180        }
181
182        return M;
183}
184
185void pdf::from_setting ( const Setting &set ) {
186        root::from_setting( set );
187        shared_ptr<RV> r = UI::build<RV> ( set, "rv", UI::optional );
188        if ( r ) {
189                set_rv ( *r );
190        }
191
192        r = UI::build<RV> ( set, "rvc", UI::optional );
193        if ( r ) {
194                set_rvc ( *r );
195        }
196}
197
198void pdf::to_setting ( Setting &set ) const {   
199        root::to_setting( set );
200        UI::save( &rv, set, "rv" );
201        UI::save( &rvc, set, "rvc" );
202}
203
204void datalink::set_connection ( const RV &rv, const RV &rv_up ) {
205        downsize = rv._dsize();
206        upsize = rv_up._dsize();
207        v2v_up = rv.dataind ( rv_up );
208        bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
209}
210
211void datalink::set_connection ( int ds, int us, const ivec &upind ) {
212        downsize = ds;
213        upsize = us;
214        v2v_up = upind;
215        bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
216}
217
218void datalink_part::set_connection ( const RV &rv, const RV &rv_up ) {
219        rv.dataind ( rv_up, v2v_down, v2v_up );
220        downsize = v2v_down.length();
221        upsize = v2v_up.length();
222}
223
224void datalink_m2e::set_connection ( const RV &rv, const RV &rvc, const RV &rv_up ) {
225        datalink::set_connection ( rv, rv_up );
226        condsize = rvc._dsize();
227        //establish v2c connection
228        rvc.dataind ( rv_up, v2c_lo, v2c_up );
229}
230
231vec datalink_m2e::get_cond ( const vec &val_up ) {
232        vec tmp ( condsize );
233        set_subvector ( tmp, v2c_lo, val_up ( v2c_up ) );
234        return tmp;
235}
236
237void datalink_m2e::pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
238        bdm_assert_debug ( downsize == val.length(), "Wrong val" );
239        bdm_assert_debug ( upsize == val_up.length(), "Wrong val_up" );
240        set_subvector ( val_up, v2v_up, val );
241        set_subvector ( val_up, v2c_up, cond );
242}
243
244std::ostream &operator<< ( std::ostream &os, const RV &rv ) {
245        int id;
246        for ( int i = 0; i < rv.len ; i++ ) {
247                id = rv.ids ( i );
248                os << id << "(" << RV::SIZES ( id ) << ")" <<  // id(size)=
249                "=" << RV::NAMES ( id )  << "_{"  << rv.times ( i ) << "}; "; //name_{time}
250        }
251        return os;
252}
253
254RV RV::expand_delayes() const {
255        RV rvt = this->remove_time(); //rv at t=0
256        RV tmp = rvt;
257        int td = mint();
258        for ( int i = -1; i >= td; i-- ) {
259                rvt.t_plus ( -1 );
260                tmp.add ( rvt ); //shift u1
261        }
262        return tmp;
263}
264
265str RV::tostr() const {
266        ivec idlist ( dsize );
267        ivec tmlist ( dsize );
268        int i;
269        int pos = 0;
270        for ( i = 0; i < len; i++ ) {
271                idlist.set_subvector ( pos, pos + size ( i ) - 1, ids ( i ) );
272                tmlist.set_subvector ( pos, pos + size ( i ) - 1, times ( i ) );
273                pos += size ( i );
274        }
275        return str ( idlist, tmlist );
276}
277
278ivec RV::dataind ( const RV &rv2 ) const {
279        ivec res ( 0 );
280        if ( rv2._dsize() > 0 ) {
281                str str2 = rv2.tostr();
282                ivec part;
283                int i;
284                for ( i = 0; i < len; i++ ) {
285                        part = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
286                        res = concat ( res, part );
287                }
288        }
289
290        //bdm_assert_debug ( res.length() == dsize, "this rv is not fully present in crv!" );
291        return res;
292
293}
294
295void RV::dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const {
296        //clean results
297        selfi.set_size ( 0 );
298        rv2i.set_size ( 0 );
299
300        // just in case any rv is empty
301        if ( ( len == 0 ) || ( rv2.length() == 0 ) ) {
302                return;
303        }
304
305        //find comon rv
306        ivec cids = itpp::find ( this->findself ( rv2 ) >= 0 );
307
308        // index of
309        if ( cids.length() > 0 ) {
310                str str1 = tostr();
311                str str2 = rv2.tostr();
312
313                ivec part1;
314                ivec part2;
315                int i, j;
316                // find common rv in strs
317                for ( j = 0; j < cids.length(); j++ ) {
318                        i = cids ( j );
319                        part1 = itpp::find ( ( str1.ids == ids ( i ) ) & ( str1.times == times ( i ) ) );
320                        part2 = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
321                        selfi = concat ( selfi, part1 );
322                        rv2i = concat ( rv2i, part2 );
323                }
324        }
325        bdm_assert_debug ( selfi.length() == rv2i.length(), "this should not happen!" );
326}
327
328RV RV::subt ( const RV &rv2 ) const {
329        ivec res = this->findself ( rv2 ); // nonzeros
330        ivec valid;
331        if ( dsize > 0 ) {
332                valid = itpp::find ( res == -1 );    //-1 => value not found => it remains
333        }
334        return ( *this ) ( valid ); //keep those that were not found in rv2
335}
336
337std::string RV::scalarname ( int scalat ) const {
338        bdm_assert ( scalat < dsize, "Wrong input index" );
339        int id = 0;
340        int scalid = 0;
341        while ( scalid + SIZES ( ids ( id ) ) <= scalat )  {
342                scalid += SIZES ( ids ( id ) );
343                id++;
344        };
345        //now id is the id of variable of interest
346        if ( size ( id ) == 1 )
347                return  NAMES ( ids ( id ) );
348        else
349                return  NAMES ( ids ( id ) ) + "_" + num2str ( scalat - scalid );
350
351}
352
353ivec RV::findself ( const RV &rv2 ) const {
354        int i, j;
355        ivec tmp = -ones_i ( len );
356        for ( i = 0; i < len; i++ ) {
357                for ( j = 0; j < rv2.length(); j++ ) {
358                        if ( ( ids ( i ) == rv2.ids ( j ) ) & ( times ( i ) == rv2.times ( j ) ) ) {
359                                tmp ( i ) = j;
360                                break;
361                        }
362                }
363        }
364        return tmp;
365}
366
367ivec RV::findself_ids ( const RV &rv2 ) const {
368        int i, j;
369        ivec tmp = -ones_i ( len );
370        for ( i = 0; i < len; i++ ) {
371                for ( j = 0; j < rv2.length(); j++ ) {
372                        if ( ( ids ( i ) == rv2.ids ( j ) ) ) {
373                                tmp ( i ) = j;
374                                break;
375                        }
376                }
377        }
378        return tmp;
379}
380
381void RV::from_setting ( const Setting &set ) {
382        Array<string> A;
383        UI::get ( A, set, "names" );
384
385        ivec szs;
386        if ( !UI::get ( szs, set, "sizes" ) )
387                szs = ones_i ( A.length() );
388
389        ivec tms;
390        if ( !UI::get ( tms, set, "times" ) )
391                tms = zeros_i ( A.length() );
392
393        init ( A, szs, tms );
394}
395
396void RV::to_setting ( Setting &set ) const {
397        Array<string> names ( len );
398        ivec sizes ( len );
399        for ( int i = 0; i < len; i++ ) {
400                names ( i ) = name ( i );
401                sizes ( i ) = size ( i );
402        }
403        UI::save ( names, set, "names" );
404        UI::save ( sizes, set, "sizes" );
405        UI::save ( times, set, "times" );
406}
407
408RV concat ( const RV &rv1, const RV &rv2 ) {
409        RV pom = rv1;
410        pom.add ( rv2 );
411        return pom;
412}
413
414RV get_composite_rv ( const Array<shared_ptr<pdf> > &pdfs,
415                      bool checkoverlap ) {
416        RV rv; //empty rv
417        bool rvaddok;
418        for ( int i = 0; i < pdfs.length(); i++ ) {
419                bdm_assert( pdfs(i)->isnamed(), "Can not extract RV from pdf no. " + num2str(i));
420                rvaddok = rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
421                // If rvaddok==false, pdfs overlap => assert error.
422                bdm_assert_debug ( rvaddok || !checkoverlap, "mprod::mprod() input pdfs overlap in rv!" );
423        }
424
425        return rv;
426}
427
428int logger::add_vector ( const RV &rv, const string &prefix, const string &name ) {
429        int id;
430        if ( rv._dsize() > 0 ) {
431                id = entries.length();
432                names = concat ( names, prefix ); // diff
433                if( name.length() > 0 )
434                {
435                        concat( names, separator );
436                        concat( names, name );
437                }
438                entries.set_length ( id + 1, true );
439                entries ( id ) = rv;
440        } else {
441                id = -1;
442        }
443        return id; // identifier of the last entry
444}
445
446int logger::add_setting ( const string &prefix ) {
447        Setting &root = setting_conf.getRoot();
448        int id = root.getLength(); //root must be group!!
449        if ( prefix.length() > 0 ) {
450                settings.set_length ( id + 1, true );
451                settings ( id ) = &root.add ( prefix, Setting::TypeList );
452        } else {
453                id = -1;
454        }
455        return id;
456}
457
458void epdf::log_register ( logger &L, const string &prefix ) {
459        RV r;
460        if ( isnamed() ) {
461                r = _rv();
462        } else {
463                r = RV ( "", dimension() );
464        };
465        root::log_register ( L, prefix );
466
467        // log full data
468        if ( log_level[logfull] ) {
469                logrec->ids.set_size ( 1 );
470                logrec->ids ( 0 ) = logrec->L.add_setting ( prefix );
471        } else {
472                // log only
473                logrec->ids.set_size ( 3 );
474                if ( log_level[logmean] ) {
475                        logrec->ids ( 0 ) = logrec->L.add_vector ( r, prefix, "mean" );
476                }
477                if ( log_level[loglbound]  ) {
478                        logrec->ids ( 1 ) = logrec->L.add_vector ( r, prefix, "lb" );
479                }       
480                if ( log_level[logubound]  ) {
481                        logrec->ids ( 2 ) = logrec->L.add_vector ( r, prefix, "ub" );
482                }
483       
484        }
485}
486
487void epdf::log_write() const {
488        if ( log_level[logfull] ) {
489                UI::save(this,  logrec->L.log_to_setting ( logrec->ids ( 0 ) ) );
490        } else {
491                if ( log_level[logmean] ) {
492                        logrec->L.log_vector ( logrec->ids ( 0 ), mean() );
493                }
494                if ( log_level[loglbound] || log_level[logubound] ) {
495                                vec lb;
496                                vec ub;
497                                qbounds ( lb, ub );
498                                if (log_level[loglbound])
499                                        logrec->L.log_vector ( logrec->ids ( 1 ), lb );
500                                if (log_level[logubound])
501                                        logrec->L.log_vector ( logrec->ids ( 2 ), ub );
502                        }
503                }
504        }
505
506
507void datalink_buffered::set_connection ( const RV &rv, const RV &rv_up ) {
508        // create link between up and down
509        datalink_part::set_connection ( rv, rv_up); // only non-delayed version
510
511        RV needed_from_hist = rv.subt(rv_up); //rv_up already copied by v2v
512       
513        // we can store only what we get in rv_up - everything else is removed
514        ivec valid_ids = needed_from_hist.findself_ids ( rv_up ); // return on which position the required id is in rv_up
515        RV rv_hist = needed_from_hist.subselect ( find ( valid_ids >= 0 ) ); // select only rvs that are in rv_up, ie ind>0
516        RV rv_hist0 = rv_hist.remove_time(); // these RVs will form history at time =0
517        // now we need to know what is needed from Up
518        rv_hist = rv_hist.expand_delayes(); // full regressor - including time 0
519        Hrv = rv_hist.subt ( rv_hist0 );   // remove time 0
520        history = zeros ( Hrv._dsize() );
521
522        // decide if we need to copy val to history
523        if ( Hrv._dsize() > 0 ) {
524                v2h_up = rv_hist0.dataind ( rv_up ); // indices of elements of rv_up to be copied
525        } // else v2h_up is empty
526
527        Hrv.dataind ( rv, h2v_hist, h2v_down );
528
529        downsize = v2v_down.length() + h2v_down.length();
530        upsize = v2v_up.length();
531}
532
533void datalink_buffered::set_history ( const RV& rv1, const vec &hist0 ) {
534        bdm_assert ( rv1._dsize() == hist0.length(), "hist is not compatible with given rv1" );
535        ivec ind_H;
536        ivec ind_h0;
537        Hrv.dataind ( rv1, ind_H, ind_h0 ); // find indices of rv in
538        set_subvector ( history, ind_H, hist0 ( ind_h0 ) ); // copy given hist to appropriate places
539}
540
541void DS::log_register ( logger &L,  const string &prefix ) {
542        bdm_assert ( dtsize == Drv._dsize(), "invalid DS: dtsize (" + num2str ( dtsize ) + ") different from Drv " + num2str ( Drv._dsize() ) );
543        //bdm_assert ( utsize == Urv._dsize(), "invalid DS: utsize (" + num2str ( utsize ) + ") different from Urv " + num2str ( Urv._dsize() ) );
544
545        root::log_register ( L, prefix );
546        //we know that
547        if ( log_level.any() ) {
548                logrec->ids.set_size ( 1 );
549                logrec->ids ( 0 ) = logrec->L.add_vector ( Drv, prefix );
550        //      logrec->ids ( 1 ) = logrec->L.add_vector ( Urv, prefix );
551        }
552}
553
554void DS::log_write ( ) const {
555        if ( log_level.any() ) {
556                vec tmp ( Drv._dsize());
557                getdata ( tmp );
558                // d is first in getdata
559                logrec->L.log_vector ( logrec->ids ( 0 ), tmp );
560        }
561}
562
563void BM::log_register ( logger &L, const string &prefix ) {
564        root::log_register ( L, prefix );
565
566        if ( log_level.any() ) {
567                logrec->ids.set_size ( 1 );
568                logrec->ids ( 0) = L.add_vector ( RV ( "", 1 ), prefix, "ll" );
569        }
570       
571        if (log_level[logbounds]){
572                prior().log_level[epdf::loglbound]=true;
573                prior().log_level[epdf::logubound]=true;
574        }
575        if (log_level[logfull]){
576                prior().log_level[epdf::logfull]=true;
577        }
578        const_cast<epdf&> ( posterior() ).log_register ( L, prefix + L.separator + "apost" );
579}
580
581void BM::log_write ( ) const {
582        posterior().log_write();
583        if ( log_level.any() ) {
584                logrec->L.logit ( logrec->ids ( 0 ), ll );
585        }
586}
587
588void BM::bayes_batch ( const mat &Data, const vec &cond ) {
589        for ( int t = 0; t < Data.cols(); t++ ) {
590                bayes ( Data.get_col ( t ), cond );
591        }
592}
593
594void BM::bayes_batch ( const mat &Data, const mat &Cond ) {
595        for ( int t = 0; t < Data.cols(); t++ ) {
596                bayes ( Data.get_col ( t ), Cond.get_col ( t ) );
597        }
598}
599
600}
Note: See TracBrowser for help on using the browser.