#include "emix.h"

namespace bdm{

void emix::set_parameters ( const vec &w0, const Array<epdf*> &Coms0, bool copy ) {
	w = w0/sum ( w0 );
	dim = Coms0(0)->dimension();
	int i;
	for ( i=0;i<w.length();i++ ) {
		it_assert_debug ( dim== ( Coms0 ( i )->dimension() ),"Component sizes do not match!" );
	}
	if ( copy ) {
		Coms.set_length(Coms0.length());
		for ( i=0;i<w.length();i++ ) {it_error("Not imp...");
			*Coms ( i ) =*Coms0 ( i );}
		destroyComs=true;
	}
	else {
		Coms = Coms0;
		destroyComs=false;
	}
}

vec emix::sample() const {
	//Sample which component
	vec cumDist = cumsum ( w );
	double u0;
	#pragma omp critical
	u0 = UniRNG.sample();

	int i=0;
	while ( ( cumDist ( i ) <u0 ) && ( i< ( w.length()-1 ) ) ) {i++;}

	return Coms ( i )->sample();
}

emix* emix::marginal(const RV &rv) const{
	it_assert_debug(isnamed(), "rvs are not assigned");
			
	Array<epdf*> Cn(Coms.length());
	for(int i=0;i<Coms.length();i++){Cn(i)=Coms(i)->marginal(rv);}
	emix* tmp = new emix();
	tmp->set_parameters(w,Cn,false);
	tmp->ownComs();
	return tmp;
}

mratio* emix::condition(const RV &rv) const{
	it_assert_debug(isnamed(), "rvs are not assigned");
	return new mratio(this,rv);
};

void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
	w = w0/sum ( w0 );
	dim = Coms0(0)->dimension();
	int i;
	for ( i=0;i<w.length();i++ ) {
		it_assert_debug ( dim== ( Coms0 ( i )->dimension() ),"Component sizes do not match!" );
	}
	if ( copy ) {
		Coms.set_length(Coms0.length());
		for ( i=0;i<w.length();i++ ) {it_error("Not imp...");
			*Coms ( i ) =*Coms0 ( i );}
		destroyComs=true;
	}
	else {
		Coms = Coms0;
		destroyComs=false;
	}
}

vec egiwmix::sample() const {
	//Sample which component
	vec cumDist = cumsum ( w );
	double u0;
	#pragma omp critical
	u0 = UniRNG.sample();

	int i=0;
	while ( ( cumDist ( i ) <u0 ) && ( i< ( w.length()-1 ) ) ) {i++;}

	return Coms ( i )->sample();
}

vec egiwmix::mean() const {
	int i; vec mu = zeros ( dim );
	for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
	return mu;
}

vec egiwmix::variance() const {
	// non-central moment
	vec mom2 = zeros ( dim );
	for ( int i = 0;i < w.length();i++ ) {
		// pow is overloaded, we have to use another approach
		mom2 += w ( i ) * (Coms(i)->variance() + elem_mult ( Coms(i)->mean(), Coms(i)->mean() )); 
	}
	// central moment
	// pow is overloaded, we have to use another approach
	return mom2 - elem_mult( mean(), mean() );
}

emix* egiwmix::marginal(const RV &rv) const{
	it_assert_debug(isnamed(), "rvs are not assigned");
			
	Array<epdf*> Cn(Coms.length());
	for(int i=0;i<Coms.length();i++){Cn(i)=Coms(i)->marginal(rv);}
	emix* tmp = new emix();
	tmp->set_parameters(w,Cn,false);
	tmp->ownComs();
	return tmp;
}

egiw* 	egiwmix::approx() {
	// NB: dimx == 1 !!!
	// The following code might look a bit spaghetti-like,
	// consult Dedecius, K. et al.: Partial forgetting in AR models.

	double sumVecCommon;			      	// common part for many terms in eq.
	int len = w.length();				// no. of mix components	
	int dimLS = Coms(1)->_V()._D().length() - 1; 	// dim of LS
	vec vecNu(len);					// vector of dfms of components
	vec vecD(len);					// vector of LS reminders of comps.
	vec vecCommon(len);				// vector of common parts
	mat matVecsTheta;				// matrix which rows are theta vects.

	// fill in the vectors vecNu, vecD and matVecsTheta
	for ( int i=0; i<len; i++ ) {
		vecNu.shift_left( Coms(i)->_nu() );
		vecD.shift_left( Coms(i)->_V()._D()(0) );
		matVecsTheta.append_row( Coms(i)->est_theta()  );
	}

	// calculate the common parts and their sum
	vecCommon = elem_mult ( w, elem_div(vecNu, vecD) );
	sumVecCommon = sum(vecCommon);

	// LS estimator of theta
	vec aprEstTheta(dimLS);  aprEstTheta.zeros();
	for ( int i=0; i<len; i++ ) {
		aprEstTheta +=  matVecsTheta.get_row( i ) * vecCommon ( i );
	}
	aprEstTheta /= sumVecCommon;
	
	
	// LS estimator of dfm
	double aprNu;
	double A = log( sumVecCommon );		// Term 'A' in equation

	for ( int i=0; i<len; i++ ) {
		A += w(i) * ( log( vecD(i) ) - psi( 0.5 * vecNu(i) ) );
	}

	aprNu = ( 1 + sqrt(1 + 2 * (A - LOG2)/3 ) ) / ( 2 * (A - LOG2) );


	// LS reminder (term D(0,0) in C-syntax)
	double aprD = aprNu / sumVecCommon;

	// Aproximation of cov
	// the following code is very numerically sensitive, thus
	// we have to eliminate decompositions etc. as much as possible
	mat aprC = zeros(dimLS, dimLS);
	for ( int i=0; i<len; i++ ) {
		aprC += Coms(i)->est_theta_cov().to_mat() * w(i); 
		vec tmp = ( matVecsTheta.get_row(i) - aprEstTheta );
		aprC += vecCommon(i) * outer_product( tmp, tmp);
	}

	// Construct GiW pdf :: BEGIN
	ldmat aprCinv ( inv(aprC) );
	vec D = concat( aprD, aprCinv._D() );
	mat L = eye(len+1);
	L.set_submatrix(1,0, aprCinv._L() * aprEstTheta);
	L.set_submatrix(1,1, aprCinv._L());
	ldmat aprLD (L, D);

	egiw* aprgiw = new egiw(1, aprLD, aprNu);
	return aprgiw;
};

}
// mprod::mprod ( Array<mpdf*> mFacs, bool overlap) : mpdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), mpdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
// 		int i;
// 		bool rvaddok;
// 		// Create rv
// 		for ( i = 0;i < n;i++ ) {
// 			rvaddok=rv.add ( mpdfs ( i )->_rv() ); //add rv to common rvs.
// 			// If rvaddok==false, mpdfs overlap => assert error.
// 			it_assert_debug(rvaddok||overlap,"mprod::mprod() input mpdfs overlap in rv!");
// 			epdfs ( i ) = & ( mpdfs ( i )->posterior() ); // add pointer to epdf
// 		};
// 		// Create rvc
// 		for ( i = 0;i < n;i++ ) {
// 			rvc.add ( mpdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
// 		};
//
// //		independent = true;
// 		//test rvc of mpdfs and fill rvinds
// 		for ( i = 0;i < n;i++ ) {
// 			// find ith rv in common rv
// 			rvsinrv ( i ) = mpdfs ( i )->_rv().dataind ( rv );
// 			// find ith rvc in common rv
// 			rvcinrv ( i ) = mpdfs ( i )->_rvc().dataind ( rv );
// 			// find ith rvc in common rv
// 			irvcs_rvc ( i ) = mpdfs ( i )->_rvc().dataind ( rvc );
// 			//
// /*			if ( rvcinrv ( i ).length() >0 ) {independent = false;}
// 			if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
// 		}
// 	};
