| 396 | |
| 397 | //! Base class for all BM running as parallel update of internal BMs |
| 398 | |
| 399 | class ProdBMBase : public BM { |
| 400 | protected : |
| 401 | Array<vec_from_vec> bm_yt; |
| 402 | Array<vec_from_2vec> bm_cond; |
| 403 | class eprod_bm : public eprod_base { |
| 404 | ProdBMBase & pb; |
| 405 | public : |
| 406 | eprod_bm(ProdBMBase &pb0): pb(pb0){} |
| 407 | const epdf* factor(int i ) const {return &(pb.bm(i)->posterior());} |
| 408 | const int no_factors() const {return pb.no_bms();} |
| 409 | } est; |
| 410 | public: |
| 411 | ProdBMBase():est(*this){} |
| 412 | virtual BM* bm(int i) NOT_IMPLEMENTED(NULL); |
| 413 | virtual int no_bms() const {return 0;} |
| 414 | const epdf& posterior() const {return est;} |
| 415 | void set_prior(const epdf *pri){ |
| 416 | const eprod_base* ep=dynamic_cast<const eprod_base*>(pri); |
| 417 | if (ep){ |
| 418 | bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+ |
| 419 | num2str(no_bms()) + "BMs"); |
| 420 | for (int i=0; i<no_bms(); i++){ |
| 421 | bm(i)->set_prior(ep->factor(i)); |
| 422 | } |
| 423 | } |
| 424 | } |
| 425 | |
| 426 | void validate() { |
| 427 | est.validate(); |
| 428 | BM::validate(); |
| 429 | // set links |
| 430 | bm_yt.set_length(no_bms()); |
| 431 | bm_cond.set_length(no_bms()); |
| 432 | |
| 433 | // |
| 434 | |
| 435 | for (int i=0; i<no_bms(); i++){ |
| 436 | yrv.add(bm(i)->_yrv()); |
| 437 | rvc.add(bm(i)->_rvc()); |
| 438 | } |
| 439 | rvc=rvc.subt(yrv); |
| 440 | |
| 441 | dimy = yrv._dsize(); |
| 442 | dimc = rvc._dsize(); |
| 443 | |
| 444 | for (int i=0; i<no_bms(); i++){ |
| 445 | bm_yt(i).connect(bm(i)->_yrv(), yrv); |
| 446 | bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc); |
| 447 | } |
| 448 | } |
| 449 | void bayes(const vec &dt, const vec &cond){ |
| 450 | ll=0; |
| 451 | for(int i=0;i<no_bms(); i++){ |
| 452 | bm_yt(i).update(dt); |
| 453 | bm_cond(i).update(dt,cond); |
| 454 | bm(i)->bayes(bm_yt(i), bm_cond(i)); |
| 455 | } |
| 456 | } |
| 457 | |
| 458 | }; |
| 459 | |
| 460 | class ProdBM: public ProdBMBase{ |
| 461 | protected: |
| 462 | Array<shared_ptr<BM> > BMs; |
| 463 | public: |
| 464 | virtual BM* bm(int i) {return BMs(i).get();} |
| 465 | virtual int no_bms() const {return BMs.length();} |
| 466 | void from_setting(const Setting &set){ |
| 467 | BM::from_setting(set); |
| 468 | UI::get(BMs,set,"BMs"); |
| 469 | } |
| 470 | void to_setting(Setting &set) const{ |
| 471 | BM::to_setting(set); |
| 472 | UI::save(BMs,set,"BMs"); |
| 473 | } |
| 474 | |
| 475 | }; |
| 476 | UIREGISTER(ProdBM); |
| 477 | |