#include "stat/exp_family.h"
#include "stat/emix.h"
#include "../mat_checks.h"
#include "UnitTest++.h"
#include "../test_util.h"
#include "../pdf_harness.h"


const double epsilon = 0.00001;

using namespace bdm;

static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance );

static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance );

TEST ( emix_test ) {
	pdf_harness::test_config ( "emix.cfg" );
}

TEST ( emix_1_test ) {
	RV x ( "{emixx }" );
	RV y ( "{emixy }" );
	RV xy = concat ( x, y );
	vec mu0 ( "1.00054 1.0455" );

	enorm_ldmat_ptr E1;
	E1->set_rv ( xy );
	E1->set_parameters ( mu0 , mat ( "0.740142 -0.259015; -0.259015 1.0302" ) );

	enorm_ldmat_ptr E2;
	E2->set_rv ( xy );
	E2->set_parameters ( "-1.2 -0.1" , mat ( "1 0.4; 0.4 0.5" ) );

	epdf_array A1 ( 1 );
	A1 ( 0 ) = E1;

	emix M1;
	M1.set_rv ( xy );
	M1._Coms() = A1;
	M1._w() = vec_1(1.0);
	M1.validate();

	// test if ARX and emix with one ARX are the same
	epdf_ptr Mm = M1.marginal ( y );
	epdf_ptr Am = E1->marginal ( y );
	pdf_ptr Mc = M1.condition ( y );
	pdf_ptr Ac = E1->condition ( y );

	mlnorm<ldmat> *wacnd = dynamic_cast<mlnorm<ldmat> *> ( Ac.get() );
	CHECK ( wacnd );
	if ( wacnd ) {
		CHECK_CLOSE ( mat ( "-0.349953" ), wacnd->_A(), epsilon );
		CHECK_CLOSE ( vec ( "1.39564" ), wacnd->_mu_const(), epsilon );
		CHECK_CLOSE ( mat ( "0.939557" ), wacnd->_R(), epsilon );
	}

	double same = -1.46433;
	CHECK_CLOSE ( same, Mm->evallog ( vec_1 ( 0.0 ) ), epsilon );
	CHECK_CLOSE ( same, Am->evallog ( vec_1 ( 0.0 ) ), epsilon );
	CHECK_CLOSE ( 0.145974, Mc->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon );
	CHECK_CLOSE ( -1.92433, Ac->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon );

	// mixture with two components
	epdf_array A2 ( 2 );
	A2 ( 0 ) = E1;
	A2 ( 1 ) = E2;

	emix M2;
	M2.set_rv ( xy );
	M2._Coms() = A2;
	M2._w() = vec_2(.5,.5);
	M2.validate();

	// mixture normalization
	CHECK_CLOSE ( 1.0, normcoef ( &M2, vec ( "-3 3 " ), vec ( "-3 3 " ) ), 0.1 );

	int N = 6;
	mat Smp = M2.sample_mat ( N );

	vec exp_ll ( "-5.0 -2.53563 -2.62171 -5.0 -2.53563 -2.62171" );
	vec ll = M2.evallog_mat ( Smp );
	CHECK_CLOSE ( exp_ll, ll, 5.0 );

	check_mean ( M2, N, 0.5*mu0+0.5*vec("-1.2 -0.1"), 1.0 );

	mat observedR ( "0.740142 -0.259015; -0.259015 1.0302" );
	check_covariance ( M2, N, observedR, 2.0 );

	epdf_ptr Mg = M2.marginal ( y );
	CHECK ( Mg.get() );
	pdf_ptr Cn = M2.condition ( x );
	CHECK ( Cn.get() );

	// marginal mean
	CHECK_CLOSE ( vec ( "0.5" ), Mg->mean(), 0.1 );
}


static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance ) {
	int tc = 0;
	Array<vec> actual ( CurrentContext::max_trial_count );
	do {
		mat smp = distrib_obj.sample_mat ( nsamples );
		vec emu = sum ( smp, 2 ) / nsamples;
		actual ( tc ) = emu;
		++tc;
	} while ( ( tc < CurrentContext::max_trial_count ) &&
	          !UnitTest::AreClose ( mean, actual ( tc - 1 ), tolerance ) );
	if ( ( tc == CurrentContext::max_trial_count ) &&
	        ( !UnitTest::AreClose ( mean, actual ( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) {
		UnitTest::MemoryOutStream stream;
		UnitTest::TestDetails details ( *UnitTest::CurrentTest::Details(), __LINE__ );
		stream << "Expected " << mean << " +/- " << tolerance << " but was " << actual;

		UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() );
	}
}

static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance ) {
	int tc = 0;
	Array<mat> actual ( CurrentContext::max_trial_count );
	do {
		mat smp = distrib_obj.sample_mat ( nsamples );
		vec emu = sum ( smp, 2 ) / nsamples;
		mat er = ( smp * smp.T() ) / nsamples - outer_product ( emu, emu );
		actual ( tc ) = er;
		++tc;
	} while ( ( tc < CurrentContext::max_trial_count ) &&
	          !UnitTest::AreClose ( R, actual ( tc - 1 ), tolerance ) );
	if ( ( tc == CurrentContext::max_trial_count ) &&
	        ( !UnitTest::AreClose ( R, actual ( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) {
		UnitTest::MemoryOutStream stream;
		UnitTest::TestDetails details ( *UnitTest::CurrentTest::Details(), __LINE__ );
		stream << "Expected " << R << " +/- " << tolerance << " but was " << actual;

		UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() );
	}
}
