#include "mixef.h"
#include <vector>

using namespace itpp;


void MixEF::init ( BMEF* Com0, const mat &Data, int c ) {
	//prepare sizes
	Coms.set_size ( c );
	n=c;
	weights.set_parameters ( 1/ ( double ) c*ones ( c ) ); //assume at least one observation in each comp.
	//est will be done at the end
	//
	int i;
	int ndat = Data.cols();
	//Estimate  Com0 from all data
	Coms ( 0 ) = ( BMEF* ) Com0->_copy_();
//	Coms(0)->set_evalll(false);
	Coms ( 0 )->bayesB ( Data );
	// Flatten it to its original shape
	Coms ( 0 )->flatten ( Com0 );

	//Copy it to the rest 
	for ( i=1;i<n;i++ ) {
		//copy Com0 and create new rvs for them
		Coms ( i ) = ( BMEF* ) Coms ( 0 )->_copy_ ( true );
	}
	//Pick some data for each component and update it
	for ( i=0;i<n;i++ ) {
		//pick one datum
		int ind=ndat*UniRNG.sample();
		Coms ( i )->bayes ( Data.get_col ( ind ),ndat/n );
	}

	//est already exists - must be deleted before build_est() can be used
	delete est;
	build_est();

}
void MixEF::bayesB ( const mat &Data ) {
	this->bayes ( Data );
}

void MixEF::bayes ( const vec &data ) {

};

void MixEF::bayes ( const mat &data ) {
	int ndat=data.cols();
	int t,i,niter;
	bool converged;

	multiBM weights0 ( weights );

	Array<BMEF*> Coms0 ( n );
	for ( i=0;i<n;i++ ) {Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy_();}

	niter=0;
	mat W=ones ( n,ndat ) / n;
	mat Wlast=ones ( n,ndat ) / n;
	vec w ( n );
	vec ll ( n );
	// tmp for weights
	vec wtmp = zeros ( n );
	//Estim
	while ( !converged ) {
		// Copy components back to their initial values
		// All necessary information is now in w and Coms0.
		Wlast = W;
		//
		for ( t=0;t<ndat;t++ ) {
			for ( i=0;i<n;i++ ) {
				ll ( i ) =Coms ( i )->logpred ( data.get_col ( t ) );
				wtmp =0.0; wtmp ( i ) =1.0;
				ll ( i ) += weights.logpred ( wtmp );
			}
			w = exp ( ll-max ( ll ) );
			W.set_col ( t, w/sum ( w ) );
		}

		for ( i=0;i<n;i++ ) {
			Coms ( i )-> set_statistics ( Coms0 ( i ) );
		}
		weights.set_statistics ( &weights0 );

		for ( t=0;t<ndat;t++ ) {
			for ( i=0;i<n;i++ ) {
				Coms ( i )-> bayes ( data.get_col ( t ),W ( i,t ) );
			}
			weights.bayes ( W.get_col ( t ) );
		}

		niter++;
		//TODO better convergence rule.
		converged = ( sumsum ( abs ( W-Wlast ) ) /n<0.001 );
	}

	//Clean Coms0
	for ( i=0;i<n;i++ ) {delete Coms0 ( i );}
};


double MixEF::logpred ( const vec &dt ) const {

	vec w=weights._epdf().mean();
	double exLL=0.0;
	for ( int i=0;i<n;i++ ) {
		exLL+=w ( i ) *exp ( Coms ( i )->logpred ( dt ) );
	}
	return log ( exLL );
}
