#include "emix.h"

namespace bdm {

void emix::validate (){
	bdm_assert ( Coms.length() > 0, "There has to be at least one component." );

	bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );

	double sum_w = sum ( w );
	bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
	w = w / sum_w;

	int dim = Coms ( 0 )->dimension();
	for ( int i = 1; i < Coms.length(); i++ ) {
		bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
		bdm_assert ( Coms(i)->isnamed() , "An unnamed component is forbidden here!" );
	}

	set_rv ( Coms ( 0 )->_rv() ); 
}

void emix::from_setting ( const Setting &set ) {
	UI::get ( Coms, set, "pdfs", UI::compulsory );

	if ( !UI::get ( w, set, "weights", UI::optional ) ) {
		int len = Coms.length();
		w.set_length ( len );
		w = 1.0 / len;
	}

	validate();
}


vec emix::sample() const {
	//Sample which component
	vec cumDist = cumsum ( w );
	double u0;
#pragma omp critical
	u0 = UniRNG.sample();

	int i = 0;
	while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
		i++;
	}

	return Coms ( i )->sample();
}

vec emix::mean() const {
	int i;
	vec mu = zeros ( dim );
	for ( i = 0; i < w.length(); i++ ) {
		mu += w ( i ) * Coms ( i )->mean();
	}
	return mu;
}

vec emix::variance() const {
	//non-central moment
	vec mom2 = zeros ( dim );
	for ( int i = 0; i < w.length(); i++ ) {
		mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
	}
	//central moment
	return mom2 - pow ( mean(), 2 );
}

double emix::evallog ( const vec &val ) const {
	int i;
	double sum = 0.0;
	for ( i = 0; i < w.length(); i++ ) {
		sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
	}
	if ( sum == 0.0 ) {
		sum = std::numeric_limits<double>::epsilon();
	}
	double tmp = log ( sum );
	bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
	return tmp;
}

vec emix::evallog_mat ( const mat &Val ) const {
	vec x = zeros ( Val.cols() );
	for ( int i = 0; i < w.length(); i++ ) {
		x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
	}
	return log ( x );
};

mat emix::evallog_coms ( const mat &Val ) const {
	mat X ( w.length(), Val.cols() );
	for ( int i = 0; i < w.length(); i++ ) {
		X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
	}
	return X;
}

shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
	emix *tmp = new emix();
	shared_ptr<epdf> narrow ( tmp );
	marginal ( rv, *tmp );
	return narrow;
}

void emix::marginal ( const RV &rv, emix &target ) const {
	bdm_assert ( isnamed(), "rvs are not assigned" );

	Array<shared_ptr<epdf> > Cn ( Coms.length() );
	for ( int i = 0; i < Coms.length(); i++ ) {
		Cn ( i ) = Coms ( i )->marginal ( rv );
	}

	target._w() = w;
	target._Coms() = Cn;
	target.validate();
}

shared_ptr<pdf> emix::condition ( const RV &rv ) const {
	bdm_assert ( isnamed(), "rvs are not assigned" );
	mratio *tmp = new mratio ( this, rv );
	return shared_ptr<pdf> ( tmp );
}

void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
	w = w0 / sum ( w0 );
	dim = Coms0 ( 0 )->dimension();
	int i;
	for ( i = 0; i < w.length(); i++ ) {
		bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
	}
	if ( copy ) {
		Coms.set_length ( Coms0.length() );
		for ( i = 0; i < w.length(); i++ ) {
			bdm_error ( "Not implemented" );
			// *Coms ( i ) = *Coms0 ( i );
		}
		destroyComs = true;
	} else {
		Coms = Coms0;
		destroyComs = false;
	}
}

vec egiwmix::sample() const {
	//Sample which component
	vec cumDist = cumsum ( w );
	double u0;
#pragma omp critical
	u0 = UniRNG.sample();

	int i = 0;
	while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
		i++;
	}

	return Coms ( i )->sample();
}

vec egiwmix::mean() const {
	int i;
	vec mu = zeros ( dim );
	for ( i = 0; i < w.length(); i++ ) {
		mu += w ( i ) * Coms ( i )->mean();
	}
	return mu;
}

vec egiwmix::variance() const {
	// non-central moment
	vec mom2 = zeros ( dim );
	for ( int i = 0; i < w.length(); i++ ) {
		// pow is overloaded, we have to use another approach
		mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
	}
	// central moment
	// pow is overloaded, we have to use another approach
	return mom2 - elem_mult ( mean(), mean() );
}

shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
	emix *tmp = new emix();
	shared_ptr<epdf> narrow ( tmp );
	marginal ( rv, *tmp );
	return narrow;
}

void egiwmix::marginal ( const RV &rv, emix &target ) const {
	bdm_assert_debug ( isnamed(), "rvs are not assigned" );

	Array<shared_ptr<epdf> > Cn ( Coms.length() );
	for ( int i = 0; i < Coms.length(); i++ ) {
		Cn ( i ) = Coms ( i )->marginal ( rv );
	}

	target._w() = w;
	target._Coms() = Cn;
	target.validate();
}

egiw* 	egiwmix::approx() {
	// NB: dimx == 1 !!!
	// The following code might look a bit spaghetti-like,
	// consult Dedecius, K. et al.: Partial forgetting in AR models.

	double sumVecCommon;			      	// common part for many terms in eq.
	int len = w.length();				// no. of mix components
	int dimLS = Coms ( 1 )->_V()._D().length() - 1; 	// dim of LS
	vec vecNu ( len );					// vector of dfms of components
	vec vecD ( len );					// vector of LS reminders of comps.
	vec vecCommon ( len );				// vector of common parts
	mat matVecsTheta;				// matrix which rows are theta vects.

	// fill in the vectors vecNu, vecD and matVecsTheta
	for ( int i = 0; i < len; i++ ) {
		vecNu.shift_left ( Coms ( i )->_nu() );
		vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
		matVecsTheta.append_row ( Coms ( i )->est_theta() );
	}

	// calculate the common parts and their sum
	vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
	sumVecCommon = sum ( vecCommon );

	// LS estimator of theta
	vec aprEstTheta ( dimLS );
	aprEstTheta.zeros();
	for ( int i = 0; i < len; i++ ) {
		aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
	}
	aprEstTheta /= sumVecCommon;


	// LS estimator of dfm
	double aprNu;
	double A = log ( sumVecCommon );		// Term 'A' in equation

	for ( int i = 0; i < len; i++ ) {
		A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
	}

	aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );


	// LS reminder (term D(0,0) in C-syntax)
	double aprD = aprNu / sumVecCommon;

	// Aproximation of cov
	// the following code is very numerically sensitive, thus
	// we have to eliminate decompositions etc. as much as possible
	mat aprC = zeros ( dimLS, dimLS );
	for ( int i = 0; i < len; i++ ) {
		aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
		vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
		aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
	}

	// Construct GiW pdf :: BEGIN
	ldmat aprCinv ( inv ( aprC ) );
	vec D = concat ( aprD, aprCinv._D() );
	mat L = eye ( dimLS + 1 );
	L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
	L.set_submatrix ( 1, 1, aprCinv._L() );
	ldmat aprLD ( L, D );

	egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
	return aprgiw;
};

double mprod::evallogcond ( const vec &val, const vec &cond ) {
	int i;
	double res = 0.0;
	for ( i = pdfs.length() - 1; i >= 0; i-- ) {
		/*			if ( pdfs(i)->_rvc().count() >0) {
						pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
					}
					// add logarithms
					res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
		res += pdfs ( i )->evallogcond (
		           dls ( i )->pushdown ( val ),
		           dls ( i )->get_cond ( val, cond )
		       );
	}
	return res;
}

vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
	vec tmp ( Dt.cols() );
	for ( int i = 0; i < Dt.cols(); i++ ) {
		tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
	}
	return tmp;
}

vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
	vec tmp ( Dt.length() );
	for ( int i = 0; i < Dt.length(); i++ ) {
		tmp ( i ) = evallogcond ( Dt ( i ), cond );
	}
	return tmp;
}

void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
	pdfs = mFacs;
	dls.set_size ( mFacs.length() );

	rv = get_composite_rv ( pdfs, true );
	dim = rv._dsize();

	for ( int i = 0; i < pdfs.length(); i++ ) {
		RV rvx = pdfs ( i )->_rvc().subt ( rv );
		rvc.add ( rvx ); // add rv to common rvc
	}
	dimc = rvc._dsize();

	// rv and rvc established = > we can link them with pdfs
	for ( int i = 0; i < pdfs.length(); i++ ) {
		dls ( i ) = new datalink_m2m;
		dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
	}
}

void mmix::validate()
{	
	bdm_assert ( Coms.length() > 0, "There has to be at least one component." );

	bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );

	double sum_w = sum ( w );
	bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
	w = w / sum_w;

	int dim = Coms ( 0 )->dimension();
	int dimc = Coms ( 0 )->dimensionc();
	for ( int i = 1; i < Coms.length(); i++ ) {
		bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
		bdm_assert ( dimc == ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
		bdm_assert ( Coms(i)->isnamed() , "An unnamed component is forbidden here!" );
	}

	set_rv ( Coms ( 0 )->_rv() );
	set_rvc ( Coms ( 0 )->_rvc() );
}

void mmix::from_setting ( const Setting &set ) {
	UI::get ( Coms, set, "pdfs", UI::compulsory );

	if ( !UI::get ( w, set, "weights", UI::optional ) ) {
		int len = Coms.length();
		w.set_length ( len );
		w = 1.0 / len;
	}
}

vec mmix::samplecond ( const vec &cond ) {
	//Sample which component
	vec cumDist = cumsum ( w );
	double u0;
#pragma omp critical
	u0 = UniRNG.sample();

	int i = 0;
	while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
		i++;
	}

	return Coms ( i )->samplecond ( cond );
}

vec eprod::mean() const {
	vec tmp ( dim );
	for ( int i = 0; i < epdfs.length(); i++ ) {
		vec pom = epdfs ( i )->mean();
		dls ( i )->pushup ( tmp, pom );
	}
	return tmp;
}

vec eprod::variance() const {
	vec tmp ( dim ); //second moment
	for ( int i = 0; i < epdfs.length(); i++ ) {
		vec pom = epdfs ( i )->mean();
		dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
	}
	return tmp - pow ( mean(), 2 );
}
vec eprod::sample() const {
	vec tmp ( dim );
	for ( int i = 0; i < epdfs.length(); i++ ) {
		vec pom = epdfs ( i )->sample();
		dls ( i )->pushup ( tmp, pom );
	}
	return tmp;
}
double eprod::evallog ( const vec &val ) const {
	double tmp = 0;
	for ( int i = 0; i < epdfs.length(); i++ ) {
		tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
	}
	bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
	return tmp;
}

}
// mprod::mprod ( Array<pdf*> mFacs, bool overlap) : pdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), pdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
// 		int i;
// 		bool rvaddok;
// 		// Create rv
// 		for ( i = 0;i < n;i++ ) {
// 			rvaddok=rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
// 			// If rvaddok==false, pdfs overlap => assert error.
// 			epdfs ( i ) = & ( pdfs ( i )->posterior() ); // add pointer to epdf
// 		};
// 		// Create rvc
// 		for ( i = 0;i < n;i++ ) {
// 			rvc.add ( pdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
// 		};
//
// //		independent = true;
// 		//test rvc of pdfs and fill rvinds
// 		for ( i = 0;i < n;i++ ) {
// 			// find ith rv in common rv
// 			rvsinrv ( i ) = pdfs ( i )->_rv().dataind ( rv );
// 			// find ith rvc in common rv
// 			rvcinrv ( i ) = pdfs ( i )->_rvc().dataind ( rv );
// 			// find ith rvc in common rv
// 			irvcs_rvc ( i ) = pdfs ( i )->_rvc().dataind ( rvc );
// 			//
// /*			if ( rvcinrv ( i ).length() >0 ) {independent = false;}
// 			if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
// 		}
// 	};

