root/library/bdm/base/bdmbase.cpp @ 768

Revision 768, 15.4 kB (checked in by mido, 14 years ago)

minor changes, BM::to_setting improved (BUT NOT TESTED)

  • Property svn:eol-style set to native
Line 
1
2#include "bdmbase.h"
3
4//! Space of basic BDM structures
5namespace bdm {
6
7const int RV::BUFFER_STEP = 1;
8
9Array<string> RV::NAMES ( RV::BUFFER_STEP );
10
11ivec RV::SIZES ( RV::BUFFER_STEP );
12
13RV::str2int_map RV::MAP;
14
15void RV::clear_all() {
16        MAP.clear();
17        SIZES.clear();
18        NAMES = Array<string> ( BUFFER_STEP );
19}
20
21string RV::show_all() {
22        ostringstream os;
23        for ( str2int_map::const_iterator iter = MAP.begin(); iter != MAP.end(); iter++ ) {
24                os << "key: " << iter->first << " val: " << iter->second << endl;
25        }
26        return os.str();
27};
28
29int RV::assign_id( const string &name, int size ) {
30        //Refer
31        int id;
32        str2int_map::const_iterator iter = MAP.find ( name );
33        if ( iter == MAP.end() || name.length() == 0 ) { //add new RV
34                id = MAP.size() + 1;
35                //debug
36                /*              {
37                                        cout << endl;
38                                        str2int_map::const_iterator iter = MAP.begin();
39                                        for(str2int_map::const_iterator iter=MAP.begin(); iter!=MAP.end(); iter++){
40                                                cout << "key: " << iter->first << " val: " << iter->second <<endl;
41                                        }
42                                }*/
43
44                MAP.insert ( make_pair ( name, id ) ); //add new rv
45                if ( id >= NAMES.length() ) {
46                        NAMES.set_length ( id + BUFFER_STEP, true );
47                        SIZES.set_length ( id + BUFFER_STEP, true );
48                }
49                NAMES ( id ) = name;
50                SIZES ( id ) = size;
51                bdm_assert ( size > 0, "RV " + name + " does not exists. Default size (-1) can not be assigned " );
52        } else {
53                id = iter->second;
54                if ( size > 0 && name.length() > 0 ) {
55                        bdm_assert ( SIZES ( id ) == size, "RV " + name + " of size " + num2str ( SIZES ( id ) ) + " exists, requested size " + num2str ( size ) + "can not be assigned" );
56                }
57        }
58        return id;
59};
60
61int RV::countsize() const {
62        int tmp = 0;
63        for ( int i = 0; i < len; i++ ) {
64                tmp += SIZES ( ids ( i ) );
65        }
66        return tmp;
67}
68
69ivec RV::cumsizes() const {
70        ivec szs ( len );
71        int tmp = 0;
72        for ( int i = 0; i < len; i++ ) {
73                tmp += SIZES ( ids ( i ) );
74                szs ( i ) = tmp;
75        }
76        return szs;
77}
78
79void RV::init ( const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times ) {
80        len = in_names.length();
81        bdm_assert ( in_names.length() == in_times.length(), "check \"times\" " );
82        bdm_assert ( in_names.length() == in_sizes.length(), "check \"sizes\" " );
83
84        times.set_length ( len );
85        ids.set_length ( len );
86        int id;
87        for ( int i = 0; i < len; i++ ) {
88                id = assign_id ( in_names ( i ), in_sizes ( i ) );
89                ids ( i ) = id;
90        }
91        times = in_times;
92        dsize = countsize();
93}
94
95RV::RV ( string name, int sz, int tm ) {
96        Array<string> A ( 1 );
97        A ( 0 ) = name;
98        init ( A, vec_1 ( sz ), vec_1 ( tm ) );
99}
100
101bool RV::add ( const RV &rv2 ) {
102        if ( rv2.len > 0 ) { //rv2 is nonempty
103                ivec ind = rv2.findself ( *this ); //should be -1 all the time
104                ivec index = itpp::find ( ind == -1 );
105
106                if ( index.length() < rv2.len ) { //conflict
107                        ids = concat ( ids, rv2.ids ( index ) );
108                        times = concat ( times, rv2.times ( index ) );
109                } else {
110                        ids = concat ( ids, rv2.ids );
111                        times = concat ( times, rv2.times );
112                }
113                len = ids.length();
114                dsize = countsize();
115                return ( index.length() == rv2.len ); //conflict or not
116        } else { //rv2 is empty
117                return true; // no conflict
118        }
119};
120
121RV RV::subselect ( const ivec &ind ) const {
122        RV ret;
123        ret.ids = ids ( ind );
124        ret.times = times ( ind );
125        ret.len = ind.length();
126        ret.dsize = ret.countsize();
127        return ret;
128}
129
130RV RV::operator() ( int di1, int di2 ) const {
131        ivec sz = cumsizes();
132        int i1 = 0;
133        while ( sz ( i1 ) < di1 ) i1++;
134        int i2 = i1;
135        while ( sz ( i2 ) < di2 ) i2++;
136        return subselect ( linspace ( i1, i2 ) );
137}
138
139void RV::t_plus ( int delta ) {
140        times += delta;
141}
142
143bool RV::equal ( const RV &rv2 ) const {
144        return ( ids == rv2.ids ) && ( times == rv2.times );
145}
146
147shared_ptr<pdf> epdf::condition ( const RV &rv ) const NOT_IMPLEMENTED( shared_ptr<pdf>() );
148
149
150shared_ptr<epdf> epdf::marginal ( const RV &rv ) const NOT_IMPLEMENTED( shared_ptr<epdf>() );
151
152mat epdf::sample_mat ( int N ) const {
153        mat X = zeros ( dim, N );
154        for ( int i = 0; i < N; i++ ) X.set_col ( i, this->sample() );
155        return X;
156}
157
158vec epdf::evallog_mat ( const mat &Val ) const {
159        vec x ( Val.cols() );
160        for ( int i = 0; i < Val.cols(); i++ ) {
161                x ( i ) = evallog ( Val.get_col ( i ) );
162        }
163
164        return x;
165}
166
167vec epdf::evallog_mat ( const Array<vec> &Avec ) const {
168        vec x ( Avec.size() );
169        for ( int i = 0; i < Avec.size(); i++ ) {
170                x ( i ) = evallog ( Avec ( i ) );
171        }
172
173        return x;
174}
175
176mat pdf::samplecond_mat ( const vec &cond, int N ) {
177        mat M ( dimension(), N );
178        for ( int i = 0; i < N; i++ ) {
179                M.set_col ( i, samplecond ( cond ) );
180        }
181
182        return M;
183}
184
185void pdf::from_setting ( const Setting &set ) {
186        root::from_setting( set );
187        shared_ptr<RV> r = UI::build<RV> ( set, "rv", UI::optional );
188        if ( r ) {
189                set_rv ( *r );
190        }
191
192        r = UI::build<RV> ( set, "rvc", UI::optional );
193        if ( r ) {
194                set_rvc ( *r );
195        }
196}
197
198void pdf::to_setting ( Setting &set ) const {   
199        root::to_setting( set );
200        UI::save( &rv, set, "rv" );
201        UI::save( &rvc, set, "rvc" );
202}
203
204void datalink::set_connection ( const RV &rv, const RV &rv_up ) {
205        downsize = rv._dsize();
206        upsize = rv_up._dsize();
207        v2v_up = rv.dataind ( rv_up );
208        bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
209}
210
211void datalink::set_connection ( int ds, int us, const ivec &upind ) {
212        downsize = ds;
213        upsize = us;
214        v2v_up = upind;
215        bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
216}
217
218void datalink_part::set_connection ( const RV &rv, const RV &rv_up ) {
219        rv.dataind ( rv_up, v2v_down, v2v_up );
220        downsize = v2v_down.length();
221        upsize = v2v_up.length();
222}
223
224void datalink_m2e::set_connection ( const RV &rv, const RV &rvc, const RV &rv_up ) {
225        datalink::set_connection ( rv, rv_up );
226        condsize = rvc._dsize();
227        //establish v2c connection
228        rvc.dataind ( rv_up, v2c_lo, v2c_up );
229}
230
231vec datalink_m2e::get_cond ( const vec &val_up ) {
232        vec tmp ( condsize );
233        set_subvector ( tmp, v2c_lo, val_up ( v2c_up ) );
234        return tmp;
235}
236
237void datalink_m2e::pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
238        bdm_assert_debug ( downsize == val.length(), "Wrong val" );
239        bdm_assert_debug ( upsize == val_up.length(), "Wrong val_up" );
240        set_subvector ( val_up, v2v_up, val );
241        set_subvector ( val_up, v2c_up, cond );
242}
243
244std::ostream &operator<< ( std::ostream &os, const RV &rv ) {
245        int id;
246        for ( int i = 0; i < rv.len ; i++ ) {
247                id = rv.ids ( i );
248                os << id << "(" << RV::SIZES ( id ) << ")" <<  // id(size)=
249                "=" << RV::NAMES ( id )  << "_{"  << rv.times ( i ) << "}; "; //name_{time}
250        }
251        return os;
252}
253
254RV RV::expand_delayes() const {
255        RV rvt = this->remove_time(); //rv at t=0
256        RV tmp = rvt;
257        int td = mint();
258        for ( int i = -1; i >= td; i-- ) {
259                rvt.t_plus ( -1 );
260                tmp.add ( rvt ); //shift u1
261        }
262        return tmp;
263}
264
265str RV::tostr() const {
266        ivec idlist ( dsize );
267        ivec tmlist ( dsize );
268        int i;
269        int pos = 0;
270        for ( i = 0; i < len; i++ ) {
271                idlist.set_subvector ( pos, pos + size ( i ) - 1, ids ( i ) );
272                tmlist.set_subvector ( pos, pos + size ( i ) - 1, times ( i ) );
273                pos += size ( i );
274        }
275        return str ( idlist, tmlist );
276}
277
278ivec RV::dataind ( const RV &rv2 ) const {
279        ivec res ( 0 );
280        if ( rv2._dsize() > 0 ) {
281                str str2 = rv2.tostr();
282                ivec part;
283                int i;
284                for ( i = 0; i < len; i++ ) {
285                        part = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
286                        res = concat ( res, part );
287                }
288        }
289
290        bdm_assert_debug ( res.length() == dsize, "this rv is not fully present in crv!" );
291        return res;
292
293}
294
295void RV::dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const {
296        //clean results
297        selfi.set_size ( 0 );
298        rv2i.set_size ( 0 );
299
300        // just in case any rv is empty
301        if ( ( len == 0 ) || ( rv2.length() == 0 ) ) {
302                return;
303        }
304
305        //find comon rv
306        ivec cids = itpp::find ( this->findself ( rv2 ) >= 0 );
307
308        // index of
309        if ( cids.length() > 0 ) {
310                str str1 = tostr();
311                str str2 = rv2.tostr();
312
313                ivec part1;
314                ivec part2;
315                int i, j;
316                // find common rv in strs
317                for ( j = 0; j < cids.length(); j++ ) {
318                        i = cids ( j );
319                        part1 = itpp::find ( ( str1.ids == ids ( i ) ) & ( str1.times == times ( i ) ) );
320                        part2 = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
321                        selfi = concat ( selfi, part1 );
322                        rv2i = concat ( rv2i, part2 );
323                }
324        }
325        bdm_assert_debug ( selfi.length() == rv2i.length(), "this should not happen!" );
326}
327
328RV RV::subt ( const RV &rv2 ) const {
329        ivec res = this->findself ( rv2 ); // nonzeros
330        ivec valid;
331        if ( dsize > 0 ) {
332                valid = itpp::find ( res == -1 );    //-1 => value not found => it remains
333        }
334        return ( *this ) ( valid ); //keep those that were not found in rv2
335}
336
337std::string RV::scalarname ( int scalat ) const {
338        bdm_assert ( scalat < dsize, "Wrong input index" );
339        int id = 0;
340        int scalid = 0;
341        while ( scalid + SIZES ( ids ( id ) ) <= scalat )  {
342                scalid += SIZES ( ids ( id ) );
343                id++;
344        };
345        //now id is the id of variable of interest
346        if ( size ( id ) == 1 )
347                return  NAMES ( ids ( id ) );
348        else
349                return  NAMES ( ids ( id ) ) + "_" + num2str ( scalat - scalid );
350
351}
352
353ivec RV::findself ( const RV &rv2 ) const {
354        int i, j;
355        ivec tmp = -ones_i ( len );
356        for ( i = 0; i < len; i++ ) {
357                for ( j = 0; j < rv2.length(); j++ ) {
358                        if ( ( ids ( i ) == rv2.ids ( j ) ) & ( times ( i ) == rv2.times ( j ) ) ) {
359                                tmp ( i ) = j;
360                                break;
361                        }
362                }
363        }
364        return tmp;
365}
366
367ivec RV::findself_ids ( const RV &rv2 ) const {
368        int i, j;
369        ivec tmp = -ones_i ( len );
370        for ( i = 0; i < len; i++ ) {
371                for ( j = 0; j < rv2.length(); j++ ) {
372                        if ( ( ids ( i ) == rv2.ids ( j ) ) ) {
373                                tmp ( i ) = j;
374                                break;
375                        }
376                }
377        }
378        return tmp;
379}
380
381void RV::from_setting ( const Setting &set ) {
382        Array<string> A;
383        UI::get ( A, set, "names" );
384
385        ivec szs;
386        if ( !UI::get ( szs, set, "sizes" ) )
387                szs = ones_i ( A.length() );
388
389        ivec tms;
390        if ( !UI::get ( tms, set, "times" ) )
391                tms = zeros_i ( A.length() );
392
393        init ( A, szs, tms );
394}
395
396void RV::to_setting ( Setting &set ) const {
397        Array<string> names ( len );
398        ivec sizes ( len );
399        for ( int i = 0; i < len; i++ ) {
400                names ( i ) = name ( i );
401                sizes ( i ) = size ( i );
402        }
403        UI::save ( names, set, "names" );
404        UI::save ( sizes, set, "sizes" );
405        UI::save ( times, set, "times" );
406}
407
408RV concat ( const RV &rv1, const RV &rv2 ) {
409        RV pom = rv1;
410        pom.add ( rv2 );
411        return pom;
412}
413
414RV get_composite_rv ( const Array<shared_ptr<pdf> > &pdfs,
415                      bool checkoverlap ) {
416        RV rv; //empty rv
417        bool rvaddok;
418        for ( int i = 0; i < pdfs.length(); i++ ) {
419                rvaddok = rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
420                // If rvaddok==false, pdfs overlap => assert error.
421                bdm_assert_debug ( rvaddok || !checkoverlap, "mprod::mprod() input pdfs overlap in rv!" );
422        }
423
424        return rv;
425}
426
427int logger::add_vector ( const RV &rv, string prefix ) {
428        int id;
429        if ( rv._dsize() > 0 ) {
430                id = entries.length();
431                names = concat ( names, prefix ); // diff
432                entries.set_length ( id + 1, true );
433                entries ( id ) = rv;
434        } else {
435                id = -1;
436        }
437        return id; // identifier of the last entry
438}
439
440int logger::add_setting ( const string &prefix ) {
441        Setting &root = setting_conf.getRoot();
442        int id = root.getLength(); //root must be group!!
443        if ( prefix.length() > 0 ) {
444                settings.set_length ( id + 1, true );
445                settings ( id ) = &root.add ( prefix, Setting::TypeList );
446        } else {
447                id = -1;
448        }
449        return id;
450}
451
452void epdf::log_register ( logger &L, const string &prefix ) {
453        RV r;
454        if ( isnamed() ) {
455                r = _rv();
456        } else {
457                r = RV ( "", dimension() );
458        };
459        root::log_register ( L, prefix );
460
461        // log full data
462        if ( log_level == 10 ) {
463                logrec->ids.set_size ( 1 );
464                logrec->ids ( 0 ) = logrec->L.add_setting ( prefix );
465        } else {
466                // log only
467                logrec->ids.set_size ( 3 );
468                if ( log_level > 0 ) {
469                        logrec->ids ( 0 ) = logrec->L.add_vector ( r, prefix + logrec->L.prefix_sep() + "mean" );
470                }
471                if ( log_level > 1 ) {
472                        logrec->ids ( 1 ) = logrec->L.add_vector ( r, prefix + logrec->L.prefix_sep() + "lb" );
473                        logrec->ids ( 2 ) = logrec->L.add_vector ( r, prefix + logrec->L.prefix_sep() + "ub" );
474                }
475        }
476}
477
478void epdf::log_write() const {
479        if ( log_level == 10 ) {
480                to_setting ( logrec->L.log_to_setting ( logrec->ids ( 0 ) ) );
481        } else {
482                if ( log_level > 0 ) {
483                        logrec->L.log_vector ( logrec->ids ( 0 ), mean() );
484                }
485                if ( log_level > 1 ) {
486                        vec lb;
487                        vec ub;
488                        qbounds ( lb, ub );
489                        logrec->L.log_vector ( logrec->ids ( 1 ), lb );
490                        logrec->L.log_vector ( logrec->ids ( 2 ), ub );
491                }
492        }
493}
494
495void datalink_buffered::set_connection ( const RV &rv, const RV &rv_up ) {
496        // create link between up and down
497        datalink_part::set_connection ( rv, rv_up); // only non-delayed version
498
499        // create rvs of history
500        // we can store only what we get in rv_up - everything else is removed
501        ivec valid_ids = rv.findself_ids ( rv_up ); // return on which position each id is
502        RV rv_hist = rv.subselect ( find ( valid_ids >= 0 ) ); // select only rvs that are in rv_up, ie ind>0
503        RV rv_hist0 = rv_hist.remove_time(); // these RVs will form history at time =0
504        // now we need to know what is needed from Up
505        rv_hist = rv_hist.expand_delayes(); // full regressor - including time 0
506        Hrv = rv_hist.subt ( rv_hist0 );   // remove time 0
507        history = zeros ( Hrv._dsize() );
508
509        // decide if we need to copy val to history
510        if ( Hrv._dsize() > 0 ) {
511                v2h_up = rv_hist0.dataind ( rv_up ); // indeces of elements of rv_up to be copied
512        } // else v2h_up is empty
513
514        Hrv.dataind ( rv, h2v_hist, h2v_down );
515
516        downsize = v2v_down.length() + h2v_down.length();
517        upsize = v2v_up.length();
518}
519
520void datalink_buffered::set_history ( const RV& rv1, const vec &hist0 ) {
521        bdm_assert ( rv1._dsize() == hist0.length(), "hist is not compatible with given rv1" );
522        ivec ind_H;
523        ivec ind_h0;
524        Hrv.dataind ( rv1, ind_H, ind_h0 ); // find indeces of rv in
525        set_subvector ( history, ind_H, hist0 ( ind_h0 ) ); // copy given hist to appropriate places
526}
527
528void DS::log_register ( logger &L,  const string &prefix ) {
529        bdm_assert ( ytsize == Yrv._dsize(), "invalid DS: ytsize (" + num2str ( ytsize ) + ") different from Drv " + num2str ( Yrv._dsize() ) );
530        bdm_assert ( utsize == Urv._dsize(), "invalid DS: utsize (" + num2str ( utsize ) + ") different from Urv " + num2str ( Urv._dsize() ) );
531
532        root::log_register ( L, prefix );
533        //we know that
534        if ( log_level > 0 ) {
535                logrec->ids.set_size ( 2 );
536                logrec->ids ( 0 ) = logrec->L.add_vector ( Yrv, prefix );
537                logrec->ids ( 1 ) = logrec->L.add_vector ( Urv, prefix );
538        }
539}
540
541void DS::log_write ( ) const {
542        if ( log_level > 0 ) {
543                vec tmp ( Yrv._dsize() + Urv._dsize() );
544                getdata ( tmp );
545                // d is first in getdata
546                logrec->L.log_vector ( logrec->ids ( 0 ), tmp.left ( Yrv._dsize() ) );
547                // u follows after d in getdata
548                logrec->L.log_vector ( logrec->ids ( 1 ), tmp.mid ( Yrv._dsize(), Urv._dsize() ) );
549        }
550}
551
552void BM::set_options ( const string &opt ) {
553        if ( opt.find ( "logfull" ) != string::npos ) {
554                const_cast<epdf&> ( posterior() ).set_log_level ( 10 ) ;
555        } else {
556                if ( opt.find ( "logbounds" ) != string::npos ) {
557                        const_cast<epdf&> ( posterior() ).set_log_level ( 2 ) ;
558                } else {
559                        const_cast<epdf&> ( posterior() ).set_log_level ( 1 ) ;
560                }
561                if ( opt.find ( "logll" ) != string::npos ) {
562                        log_level = 1;
563                }
564        }
565}
566
567void BM::log_register ( logger &L, const string &prefix ) {
568        root::log_register ( L, prefix );
569
570        const_cast<epdf&> ( posterior() ).log_register ( L, prefix + L.prefix_sep() + "apost" );
571
572        if ( ( log_level ) > 0 ) {
573                logrec->ids.set_size ( 1 );
574                logrec->ids ( 0 ) = L.add_vector ( RV ( "", 1 ), prefix + L.prefix_sep() + "ll" );
575        }
576}
577
578void BM::log_write ( ) const {
579        posterior().log_write();
580        if ( log_level > 0 ) {
581                logrec->L.logit ( logrec->ids ( 0 ), ll );
582        }
583}
584
585void BM::bayes_batch ( const mat &Data, const vec &cond ) {
586        for ( int t = 0; t < Data.cols(); t++ ) {
587                bayes ( Data.get_col ( t ), cond );
588        }
589}
590
591void BM::bayes_batch ( const mat &Data, const mat &Cond ) {
592        for ( int t = 0; t < Data.cols(); t++ ) {
593                bayes ( Data.get_col ( t ), Cond.get_col ( t ) );
594        }
595}
596
597}
Note: See TracBrowser for help on using the browser.