#include <itpp/itbase.h>
#include "merger.h"
#include "arx.h"

vec merger::lognorm_merge ( mat &lW ) {
	int nu=lW.rows();
	vec mu = sum ( lW ) /nu; //mean of logs
	vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 );
	double coef=0.0;
	switch ( nu ) {
		case 2:
			coef=sqrt ( beta*2 ) * ( 1-0.5*sqrt ( ( 4*beta-3 ) /beta ) );
			return exp ( coef*sqrt ( lam ) + mu );
			break;
		case 3://Ration of Bessel
			break;
		case 4:
			break;
		default: // Approximate conditional density
			break;
	}
	return vec ( 0 );
}

void merger::merge ( const epdf* g0 ) {
	it_file dbg ( "merger_debug.it" );

	it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
	//Empirical density - samples
	eEmp eSmp ( rv,Ns );
	eSmp.set_parameters ( ones ( Ns ), g0 );
	Array<vec> &Smp = eSmp._samples(); //aux
	vec &w = eSmp._w(); //aux

	mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones
	for ( int i=0;i<Ns;i++ ) {	set_col_part ( Smp_ex,i,Smp ( i ) );}

	dbg << Name ( "Smp_0" ) << Smp_ex;

	// Stuff for merging
	vec lw_src ( Ns );
	vec lw_mix ( Ns );
	mat lW=zeros ( n,Ns );
	vec vec0 ( 0 );

	// Initial component in the mixture model
	mat V0=1e-8*eye ( rv.count() +1 );
	ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
	         V0, rv.count() *rv.count() +5.0 ); //initial guess of Mix: zero mean, large variance

	Mix.init ( &A0, Smp_ex, Nc );
	//Preserve initial mixture for repetitive estimation via flattening
	MixEF Mix_init(Mix);

	// ============= MAIN LOOP ==================
	bool converged=false;
	int niter = 0;
	char str[100];

	epdf* Mpred;
	vec Mix_pdf ( Ns );
	while ( !converged ) {
		//Re-estimate Mix
		//Re-Initialize Mixture model
		Mix.flatten(&Mix_init);
		Mix.bayesB ( Smp_ex, w*Ns );
		Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!!

		sprintf ( str,"Mpred_mean%d",niter );
		dbg << Name ( str ) << Mpred->mean();

		// Generate new samples
		eSmp.set_samples ( Mpred );
		for ( int i=0;i<Ns;i++ ) {	set_col_part ( Smp_ex,i,Smp ( i ) );}

		sprintf ( str,"Mpdf%d",niter );
		for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
		dbg << Name ( str ) << Mix_pdf;

		sprintf ( str,"Smp%d",niter );
		dbg << Name ( str ) << Smp_ex;

		//Importace weighting
		for ( int i=0;i<n;i++ ) {
			lw_src=0.0;
			//======== Same RVs ===========
			//Split according to dependency in rvs
			if ( mpdfs ( i )->_rv().count() ==rv.count() ) {
				// no need for conditioning or marginalization
				for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
					lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) );
				}
			}
			else {
				// compute likelihood of marginal on the conditional variable
				if ( mpdfs ( i )->_rvc().count() >0 ) {
					// Make marginal on rvc_i
					epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
					//compute vector of lw_src
					for ( int k=0;k<Ns;k++ ) {
						lw_src ( k ) += tmp_marg->evalpdflog ( dls ( i )->get_val ( Smp ( i ) ) );
					}
					delete tmp_marg;
				}
				// Compute likelihood of the missing variable
				if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) {
					///////////////
					// There are variales unknown to mpdfs(i) : rvzs
					mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
					// Compute likelihood
					vec lw_dbg=lw_src;
					for ( int k= 0; k<Ns; k++ ) {
						lw_src ( k ) += log (
						                    tmp_cond->evalcond (
						                        zdls ( i )->get_val ( Smp ( k ) ),
						                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
					}
					delete tmp_cond;
				}
				// Compute likelihood of the partial source
				for ( int k= 0; k<Ns; k++ ) {
					mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
					lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( dls ( i )->get_val ( Smp ( k ) ) );
				}

			}
			lW.set_row ( i, lw_src ); // do not divide by mix
		}
		//Importance of the mixture
		for ( int j=0;j<Ns;j++ ) {
			lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
		}
		sprintf ( str,"lW%d",niter );
		dbg << Name ( str ) << lW;

		w = lognorm_merge ( lW ); //merge

		sprintf ( str,"w%d",niter );
		dbg << Name ( str ) << w;
		sprintf ( str,"lw_m%d",niter );
		dbg << Name ( str ) << lw_mix;

		//Importance weighting
		w /=exp ( lw_mix ); // hoping that it is not numerically sensitive...
		//renormalize
		w /=sum ( w );

		sprintf ( str,"w_is_%d",niter );
		dbg << Name ( str ) << w;

// 		eSmp.resample(); // So that it can be used in bayes
// 		for ( int i=0;i<Ns;i++ ) {	set_col_part ( Smp_ex,i,Smp ( i ) );}

		sprintf ( str,"Smp_res%d",niter );
		dbg << Name ( str ) << Smp;

		// ==== stopping rule ===
		niter++;
		converged = ( niter>9 );
	}

}

