#include "shared_ptr.h"
#include "stat/exp_family.h"
#include "stat/emix.h"
#include "mat_checks.h"
#include "UnitTest++.h"
#include "test_util.h"

const double epsilon = 0.00001;

using namespace bdm;

TEST ( test_emix_old ) {
	RV x ( "{x }" );
	RV y ( "{y }" );
	RV xy = concat ( x, y );

	shared_ptr<enorm<ldmat> > E1 = new enorm<ldmat>();
	E1->set_rv ( xy );
	E1->set_parameters ( "1.00054 1.0455" , mat ( "0.740142 -0.259015; -0.259015 1.0302" ) );

	shared_ptr<enorm<ldmat> > E2 = new enorm<ldmat>();
	E2->set_rv ( xy );
	E2->set_parameters ( "-1.2 -0.1" , mat ( "1 0.4; 0.4 0.5" ) );

	Array<shared_ptr<epdf> > A1 ( 1 );
	A1 ( 0 ) = E1;

	emix M1;
	M1.set_rv ( xy );
	M1.set_parameters ( vec ( "1" ), A1 );

	// test if ARX and emix with one ARX are the same
	shared_ptr<epdf> Mm = M1.marginal ( y );
	shared_ptr<epdf> Am = E1->marginal ( y );
	shared_ptr<mpdf> Mc = M1.condition ( y );
	shared_ptr<mpdf> Ac = E1->condition ( y );

	mlnorm<ldmat> *wacnd = dynamic_cast<mlnorm<ldmat> *>( Ac.get() );
	CHECK(wacnd);
	if ( wacnd ) {
		CHECK_CLOSE ( mat ( "-0.349953" ), wacnd->_A(), epsilon );
		CHECK_CLOSE ( vec ( "1.39564" ), wacnd->_mu_const(), epsilon );
		CHECK_CLOSE ( mat ( "0.939557" ), wacnd->_R(), epsilon );
	}

	double same = -1.46433;
	CHECK_CLOSE ( same, Mm->evallog ( vec_1 ( 0.0 ) ), epsilon );
	CHECK_CLOSE ( same, Am->evallog ( vec_1 ( 0.0 ) ), epsilon );
	CHECK_CLOSE ( 0.145974, Mc->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon );
	CHECK_CLOSE ( -1.92433, Ac->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon );

	// mixture with two components
	Array<shared_ptr<epdf> > A2 ( 2 );
	A2 ( 0 ) = E1;
	A2 ( 1 ) = E2;

	emix M2;
	M2.set_rv ( xy );
	M2.set_parameters ( vec ( "1" ), A2 );


	// mixture normalization
	CHECK_CLOSE ( 1.0, normcoef ( &M2, vec ( "-3 3 " ), vec ( "-3 3 " ) ), 0.1 );

	int N = 3;
	vec ll2 ( N );
	mat Smp = M2.sample_m ( N );
	vec ll = M2.evallog_m ( Smp );

	vec Emu = sum ( Smp, 2 ) / N;
	CHECK_CLOSE ( vec ( "1.00054 1.0455" ), Emu, 1.0 );

	mat Er = ( Smp * Smp.transpose() ) / N - outer_product ( Emu, Emu );
	CHECK_CLOSE ( mat ( "0.740142 -0.259015; -0.259015 1.0302" ), Er, 2.0 );

	shared_ptr<epdf> Mg = M2.marginal ( y );
	CHECK ( Mg.get() );
	shared_ptr<mpdf> Cn = M2.condition ( x );
	CHECK ( Cn.get() );

	// marginal mean
	CHECK_CLOSE ( vec ( "1.0" ), Mg->mean(), 0.1 );
}
