#include "emix.h"

namespace bdm {

void emix_base::validate (){
	epdf::validate();
	bdm_assert ( no_coms() > 0, "There has to be at least one component." );

	bdm_assert ( no_coms() == w.length(), "It is obligatory to define weights of all the components." );

	double sum_w = sum ( w );
	bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
	w = w / sum_w;

	dim = component ( 0 )->dimension();
	RV rv_tmp = component ( 0 )->_rv() ;
	bool isnamed = component( 0 )->isnamed();
	for ( int i = 1; i < no_coms(); i++ ) {
		bdm_assert ( dim == ( component ( i )->dimension() ), "Component sizes do not match!" );
		isnamed &= component(i)->isnamed() & component(i)->_rv().equal(rv_tmp);
	}
	if (isnamed)
		epdf::set_rv ( rv_tmp); 
}



vec emix_base::sample() const {
	//Sample which component
	vec cumDist = cumsum ( w );
	double u0;
#pragma omp critical
	u0 = UniRNG.sample();

	int i = 0;
	while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
		i++;
	}

	return component ( i )->sample();
}

vec emix_base::mean() const {
	int i;
	vec mu = zeros ( dim );
	for ( i = 0; i < w.length(); i++ ) {
		mu += w ( i ) * component ( i )->mean();
	}
	return mu;
}

vec emix_base::variance() const {
	//non-central moment
	vec mom2 = zeros ( dim );
	vec mom1 = zeros ( dim );
	
	for ( int i = 0; i < w.length(); i++ ) {
		vec vi=component( i )->variance();
		vec mi=component ( i )->mean();
		mom2 += w ( i ) * ( vi + pow ( mi, 2 ) );
		mom1 += w(i) * mi;
	}
	//central moment
	return mom2 - pow ( mom1, 2 );
}

double emix_base::evallog ( const vec &val ) const {
	int i;
	double sum = 0.0;
	for ( i = 0; i < w.length(); i++ ) {
		sum += w ( i ) * exp ( component ( i )->evallog ( val ) );
	}
	if ( sum == 0.0 ) {
		sum = std::numeric_limits<double>::epsilon();
	}
	double tmp = log ( sum );
	bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
	return tmp;
}

vec emix_base::evallog_mat ( const mat &Val ) const {
	vec x = zeros ( Val.cols() );
	for ( int i = 0; i < w.length(); i++ ) {
		x += w ( i ) * exp ( component( i )->evallog_mat ( Val ) );
	}
	return log ( x );
};

mat emix_base::evallog_coms ( const mat &Val ) const {
	mat X ( w.length(), Val.cols() );
	for ( int i = 0; i < w.length(); i++ ) {
		X.set_row ( i, w ( i ) *exp ( component( i )->evallog_mat ( Val ) ) );
	}
	return X;
}

shared_ptr<epdf> emix_base::marginal ( const RV &rv ) const {
	emix *tmp = new emix();
	shared_ptr<epdf> narrow ( tmp );
	marginal ( rv, *tmp );
	return narrow;
}

void emix_base::marginal ( const RV &rv, emix &target ) const {
	bdm_assert ( isnamed(), "rvs are not assigned" );

	Array<shared_ptr<epdf> > Cn ( no_coms() );
	for ( int i = 0; i < no_coms(); i++ ) {
		Cn ( i ) = component ( i )->marginal ( rv );
	}

	target._w() = w;
	target._Coms() = Cn;
	target.validate();
}

shared_ptr<pdf> emix_base::condition ( const RV &rv ) const {
	bdm_assert ( isnamed(), "rvs are not assigned" );
	mratio *tmp = new mratio ( this, rv );
	return shared_ptr<pdf> ( tmp );
}

void emix::from_setting ( const Setting &set ) {
	emix_base::from_setting(set);
	UI::get ( Coms, set, "pdfs", UI::compulsory );
	UI::get ( w, set, "weights", UI::compulsory );
}
void emix::to_setting  (Setting  &set) const {
	emix_base::to_setting(set);
	UI::save(Coms, set, "pdfs");
	UI::save( w, set, "weights");
}


void 	emix::validate (){
	emix_base::validate();
	dim = Coms ( 0 )->dimension();
}


double mprod::evallogcond ( const vec &val, const vec &cond ) {
	int i;
	double res = 0.0;
	for ( i = pdfs.length() - 1; i >= 0; i-- ) {
		/*			if ( pdfs(i)->_rvc().count() >0) {
						pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
					}
					// add logarithms
					res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
		res += pdfs ( i )->evallogcond (
		           dls ( i )->pushdown ( val ),
		           dls ( i )->get_cond ( val, cond )
		       );
	}
	return res;
}

vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
	vec tmp ( Dt.cols() );
	for ( int i = 0; i < Dt.cols(); i++ ) {
		tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
	}
	return tmp;
}

vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
	vec tmp ( Dt.length() );
	for ( int i = 0; i < Dt.length(); i++ ) {
		tmp ( i ) = evallogcond ( Dt ( i ), cond );
	}
	return tmp;
}

void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
	pdfs = mFacs;
	dls.set_size ( mFacs.length() );

	rv = get_composite_rv ( pdfs, true );
	dim = rv._dsize();

	for ( int i = 0; i < pdfs.length(); i++ ) {
		RV rvx = pdfs ( i )->_rvc().subt ( rv );
		rvc.add ( rvx ); // add rv to common rvc
	}
	dimc = rvc._dsize();

	// rv and rvc established = > we can link them with pdfs
	for ( int i = 0; i < pdfs.length(); i++ ) {
		dls ( i ) = new datalink_m2m;
		dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
	}
}

void mprod::from_setting ( const Setting &set ) {
		pdf::from_setting(set);
		Array<shared_ptr<pdf> > temp_array; 
		UI::get ( temp_array, set, "pdfs", UI::compulsory );
		set_elements ( temp_array );
	}
void 	mprod::to_setting  (Setting  &set) const {
		pdf::to_setting(set);
		UI::save( pdfs, set, "pdfs");
	}

void mmix::validate()
{	pdf::validate();
	bdm_assert ( Coms.length() > 0, "There has to be at least one component." );

	bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );

	double sum_w = sum ( w );
	bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
	w = w / sum_w;


	dim = Coms ( 0 )->dimension();
	dimc = Coms ( 0 )->dimensionc();
	RV rv_tmp = Coms ( 0 )->_rv();
	RV rvc_tmp = Coms ( 0 )->_rvc();
	bool isnamed = Coms ( 0 )->isnamed();
	for ( int i = 1; i < Coms.length(); i++ ) {
		bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
		bdm_assert ( dimc >= ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
		isnamed &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp) & Coms(i)->_rvc().equal(rvc_tmp);
	}
	if (isnamed)
	{
		pdf::set_rv ( rv_tmp );
		pdf::set_rvc ( rvc_tmp );
	}
}

void mmix::from_setting ( const Setting &set ) {
	
	pdf::from_setting(set);
	UI::get ( Coms, set, "pdfs", UI::compulsory );

	if ( !UI::get ( w, set, "weights", UI::optional ) ) {
		int len = Coms.length();
		w.set_length ( len );
		w = 1.0 / len;
	}
}

void 	mmix::to_setting  (Setting  &set) const {
	pdf::to_setting(set);
	UI::save( Coms, set, "pdfs");
	UI::save( w, set, "weights");
}

vec mmix::samplecond ( const vec &cond ) {
	//Sample which component
	vec cumDist = cumsum ( w );
	double u0;
#pragma omp critical
	u0 = UniRNG.sample();

	int i = 0;
	while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
		i++;
	}

	return Coms ( i )->samplecond ( cond );
}

vec eprod_base::mean() const {
	vec tmp ( dim );
	for ( int i = 0; i < no_factors(); i++ ) {
		vec pom = factor( i )->mean();
		dls ( i )->pushup ( tmp, pom );
	}
	return tmp;
}

vec eprod_base::variance() const {
	vec tmp ( dim ); //second moment
	for ( int i = 0; i < no_factors(); i++ ) {
		vec pom = factor ( i )->variance();
		dls ( i )->pushup ( tmp, pom ); //
	}
	return tmp;
}
vec eprod_base::sample() const {
	vec tmp ( dim );
	for ( int i = 0; i < no_factors(); i++ ) {
		vec pom = factor ( i )->sample();
		dls ( i )->pushup ( tmp, pom );
	}
	return tmp;
}
double eprod_base::evallog ( const vec &val ) const {
	double tmp = 0;
	for ( int i = 0; i < no_factors(); i++ ) {
		tmp += factor ( i )->evallog ( dls ( i )->pushdown ( val ) );
	}
	//bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
	return tmp;
}

}

