
#include "merger.h"
#include "arx.h"

namespace bdm
{
	vec merger::lognorm_merge ( mat &lW )
	{
		int nu=lW.rows();
		vec mu = sum ( lW ) /nu; //mean of logs
		// vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe!
		vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
		double coef=0.0;
		vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
		switch ( nu )
		{
			case 2:
				coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
				return  coef*sq2bl + mu ;
				break;
			case 3://Ratio of Bessel
				coef = sqrt ( ( 3*beta-2 ) /3*beta );
				return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu;
				break;
			case 4:
				break;
			default: // Approximate conditional density
				break;
		}
		return vec ( 0 );
	}

	void merger::merge ( const epdf* g0 )
	{

		it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
		//Empirical density - samples
		if ( !fix_smp )
		{
			eSmp.set_statistics ( ones ( Ns ), g0 );
		}

		Array<vec> &Smp = eSmp._samples(); //aux
		vec &w = eSmp._w(); //aux

		mat Smp_ex =ones ( dim +1,Ns ); // Extended samples for the ARX model - the last row is ones
		for ( int i=0;i<Ns;i++ ) {	set_col_part ( Smp_ex,i,Smp ( i ) );}

		if ( DBG )	*dbg << Name ( "Smp_0" ) << Smp_ex;

		// Stuff for merging
		vec lw_src ( Ns );
		vec lw_mix ( Ns );
		vec lw ( Ns );
		mat lW=zeros ( n,Ns );
		vec vec0 ( 0 );

		//initialize importance weights
		if ( !fix_smp )
			for ( int i=0;i<Ns;i++ )
			{
				lw_mix ( i ) =g0->evallog ( Smp ( i ) );
			}

		// Initial component in the mixture model
		mat V0=1e-8*eye ( dim +1 );
		ARX A0;	A0.set_statistics ( dim, V0 ); //initial guess of Mix:

		Mix.init ( &A0, Smp_ex, Nc );
		//Preserve initial mixture for repetitive estimation via flattening
		MixEF Mix_init ( Mix );

		// ============= MAIN LOOP ==================
		bool converged=false;
		int niter = 0;
		char str[100];

		emix* Mpred=Mix.epredictor ( );
		vec Mix_pdf ( Ns );
		while ( !converged )
		{
			//Re-estimate Mix
			//Re-Initialize Mixture model
			Mix.flatten ( &Mix_init );
			Mix.bayesB ( Smp_ex, w*Ns );
			delete Mpred;
			Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
			Mpred->set_rv ( rv ); //the predictor predicts rv of this merger

			// This will be active only later in iterations!!!
			if ( ( !fix_smp ) & ( 1./sum_sqr ( w ) <effss_coef*Ns ) )
			{
				// Generate new samples
				eSmp.set_samples ( Mpred );
				for ( int i=0;i<Ns;i++ )
				{
					//////////// !!!!!!!!!!!!!
					if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
					set_col_part ( Smp_ex,i,Smp ( i ) );
					//Importance of the mixture
					//lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
					lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
				}
				if ( 0 )
				{
					cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
					cout << Mix._e()->mean() <<endl;
					cout << sum ( Smp_ex,2 ) /Ns <<endl;
					cout << Smp_ex*Smp_ex.T() /Ns << endl;
				}
			}
			if ( DBG )
			{
				sprintf ( str,"Mpred_mean%d",niter );
				*dbg << Name ( str ) << Mpred->mean();
				sprintf ( str,"Mpred_var%d",niter );
				*dbg << Name ( str ) << Mpred->variance();


				sprintf ( str,"Mpdf%d",niter );
				for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
				*dbg << Name ( str ) << Mix_pdf;

				sprintf ( str,"Smp%d",niter );
				*dbg << Name ( str ) << Smp_ex;

			}
			//Importace weighting
			for ( int i=0;i<n;i++ )
			{
				lw_src=0.0;
				//======== Same RVs ===========
				//Split according to dependency in rvs
				if ( mpdfs ( i )->dimension() ==dim )
				{
					// no need for conditioning or marginalization
					for ( int j=0;j<Ns; j++ )   // Smp is Array<> => for cycle
					{
						lw_src ( j ) =mpdfs ( i )->_epdf().evallog ( Smp ( j ) );
					}
				}
				else
				{
					// compute likelihood of marginal on the conditional variable
					if ( mpdfs ( i )->dimensionc() >0 )
					{
						// Make marginal on rvc_i
						epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
						//compute vector of lw_src
						for ( int k=0;k<Ns;k++ )
						{
							// Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
							lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
						}
						delete tmp_marg;

// 					sprintf ( str,"marg%d",niter );
// 					*dbg << Name ( str ) << lw_src;

					}
					// Compute likelihood of the missing variable
					if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) )
					{
						///////////////
						// There are variales unknown to mpdfs(i) : rvzs
						mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
						// Compute likelihood
						vec lw_dbg=lw_src;
						for ( int k= 0; k<Ns; k++ )
						{
							lw_src ( k ) += log (
							                    tmp_cond->evallogcond (
							                        zdls ( i )->pushdown ( Smp ( k ) ),
							                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
							if ( !std::isfinite ( lw_src ( k ) ) )
							{
								lw_src ( k ) = -1e16; cout << "!";
							}
						}
						delete tmp_cond;
					}
					// Compute likelihood of the partial source
					for ( int k= 0; k<Ns; k++ )
					{
						mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
						lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
					}

				}
//			it_assert_debug(std::isfinite(sum(lw_src)),"bad");
				lW.set_row ( i, lw_src ); // do not divide by mix
			}
			lw = lognorm_merge ( lW ); //merge

			//Importance weighting
			if ( !fix_smp )
				lw -=  lw_mix; // hoping that it is not numerically sensitive...
			w = exp ( lw-max ( lw ) );

			//renormalize
			double sumw =sum ( w );
			if ( std::isfinite ( sumw ) )
			{
				w = w/sumw;
			}
			else
			{
				it_file itf ( "merg_err.it" );
				itf << Name ( "w" ) << w;
			}

			if ( DBG )
			{
				sprintf ( str,"lW%d",niter );
				*dbg << Name ( str ) << lW;
				sprintf ( str,"w%d",niter );
				*dbg << Name ( str ) << w;
				sprintf ( str,"lw_m%d",niter );
				*dbg << Name ( str ) << lw_mix;
			}
			// ==== stopping rule ===
			niter++;
			converged = ( niter>40 );
		}
		delete Mpred;
//		cout << endl;

	}

}
