#include #include "mixtures.h" namespace bdm { void MixEF::init ( BMEF* Com0, const mat &Data, const int c ) { //prepare sizes Coms.set_size ( c ); weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp. multiBM weights0(weights); //est will be done at the end // int i; int ndat = Data.cols(); //Estimate Com0 from all data Coms ( 0 ) = (BMEF*) Com0->_copy(); // Coms(0)->set_evalll(false); Coms ( 0 )->bayes_batch ( Data ); Coms ( 0 )->flatten ( Com0 ); //Copy it to the rest for ( i = 1; i < Coms.length(); i++ ) { //copy Com0 and create new rvs for them Coms ( i ) = (BMEF*) Coms ( 0 )->_copy ( ); } //Pick some data for each component and update it for ( i = 0; i < Coms.length(); i++ ) { //pick one datum if (ndat==Coms.length()) { //take the ith vector Coms ( i )->bayes ( Data.get_col ( i ), empty_vec ); } else { // pick at random int ind = (int) floor ( ndat * UniRNG.sample() ); Coms ( i )->bayes_weighted ( Data.get_col ( ind ), empty_vec, ndat ); Coms (i)->flatten(Com0,ndat/Coms.length()); } //sharpen to the sharp component //Coms ( i )->flatten ( SharpCom.get(), 1.0/Coms.length() ); } Options old_opt =options; Options ini_opt=options; ini_opt.method = EM; ini_opt.max_niter= 1; bayes_batch(Data, empty_vec); for ( i = 0; i < Coms.length(); i++ ) { Coms (i)->flatten(Com0,ndat/Coms.length()); } options = old_opt; } double MixEF::bayes_batch_weighted ( const mat &data , const mat &cond, const vec &wData ) { int ndat = data.cols(); int t, i, niter; bool converged = false; multiBM weights0 ( weights ); int n = Coms.length(); Array Coms0 ( n ); for ( i = 0; i < n; i++ ) { Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy(); } niter = 0; mat W = ones ( n, ndat ) / n; mat Wlast = ones ( n, ndat ) / n; vec w ( n ); vec ll ( n ); // tmp for weights vec wtmp = zeros ( n ); int maxi; double maxll; double levid=0.0; //Estim while ( !converged ) { levid=0.0; // Copy components back to their initial values // All necessary information is now in w and Coms0. Wlast = W; // //#pragma omp parallel for for ( t = 0; t < ndat; t++ ) { //#pragma omp parallel for for ( i = 0; i < n; i++ ) { ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) , empty_vec); wtmp = 0.0; wtmp ( i ) = 1.0; ll ( i ) += weights.logpred ( wtmp ); } maxll = max ( ll, maxi ); switch ( options.method ) { case QB: w = exp ( ll - maxll ); w /= sum ( w ); break; case EM: w = 0.0; w ( maxi ) = 1.0; break; } W.set_col ( t, w ); } // copy initial statistics //#pragma omp parallel for for ( i = 0; i < n; i++ ) { Coms ( i )-> set_statistics ( Coms0 ( i ) ); } weights.set_statistics ( &weights0 ); // Update statistics // !!!! note wData ==> this is extra weight of the data record // !!!! For typical cases wData=1. vec logevid(n); for ( t = 0; t < ndat; t++ ) { //#pragma omp parallel for for ( i = 0; i < n; i++ ) { Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) ); logevid(i) = Coms(i)->_ll(); } weights.bayes ( W.get_col ( t ) * wData ( t ) ); } levid += weights._ll()+log(weights.posterior().mean() * exp(logevid)); // inner product w*exp(evid) niter++; //TODO better convergence rule. converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 ); } //Clean Coms0 for ( i = 0; i < n; i++ ) { delete Coms0 ( i ); } return levid; } void MixEF::bayes ( const vec &data, const vec &cond = empty_vec ) { }; double MixEF::logpred ( const vec &yt, const vec &cond =empty_vec) const { vec w = weights.posterior().mean(); double exLL = 0.0; for ( int i = 0; i < Coms.length(); i++ ) { exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt , cond ) ); } return log ( exLL ); } emix* MixEF::epredictor ( const vec &vec) const { Array > pC ( Coms.length() ); for ( int i = 0; i < Coms.length(); i++ ) { pC ( i ) = Coms ( i )->epredictor ( ); pC (i) -> set_rv(_yrv()); } emix* tmp; tmp = new emix( ); tmp->_w() = weights.posterior().mean(); tmp->_Coms() = pC; tmp->validate(); return tmp; } void MixEF::flatten ( const BMEF* M2, double weight=1.0 ) { const MixEF* Mix2 = dynamic_cast ( M2 ); bdm_assert_debug ( Mix2->Coms.length() == Coms.length(), "Different no of coms" ); //Flatten each component for ( int i = 0; i < Coms.length(); i++ ) { Coms ( i )->flatten ( Mix2->Coms ( i ) , weight); } //Flatten weights = make them equal!! weights.set_statistics ( & ( Mix2->weights ) ); } }