#include "base/bdmbase.h"
#include "base/user_info.h"
#include "stat/emix.h"
#include "itpp_ext.h"
#include "mpdf_harness.h"
#include "mat_checks.h"
#include "UnitTest++.h"

using namespace bdm;

static void check_mean(mmix &mMix, const vec &mu0, int nsamples, const vec &mean, double tolerance);

static void check_covariance(mmix &mMix, const vec &mu0, int nsamples, const mat &R, double tolerance);

TEST ( test_mepdf ) {
	mpdf_harness::test_config ( "mepdf.cfg" );
}

TEST ( test_mgamma ) {
	mpdf_harness::test_config ( "mgamma.cfg" );
}

TEST ( test_mlnorm ) {
	mpdf_harness::test_config ( "mlnorm.cfg" );
}

TEST ( test_mprod ) {
	mpdf_harness::test_config ( "mprod.cfg" );
}

// not using mpdf_harness because mmix isn't configurable (yet?)
TEST ( test_mmix ) {
	RV x ( "{mmixx }", "2" );
	RV y ( "{mmixy }", "2" );
	int N = 10000; //number of samples
	vec mu0 ( "1.5 1.7" );
	mat V0 ( "1.2 0.3; 0.3 5" );
	ldmat R = ldmat ( V0 );

	shared_ptr<enorm<ldmat> > eN = new enorm<ldmat>();
	eN->set_parameters ( mu0, R );

	shared_ptr<mgamma> mG = new mgamma();
	double k = 10.0;
	mG->set_parameters ( k, mu0 );

	mmix mMix;
	Array<shared_ptr<mpdf> > mComs ( 2 );

	// mmix::set_parameters requires the first mpdf to be named
	mG->set_rv(x);
	mG->set_rvc(y);
	mComs ( 0 ) = mG;

	eN->set_mu ( vec_2 ( 0.0, 0.0 ) );
	shared_ptr<mepdf> mEnorm = new mepdf ( eN );
	mComs ( 1 ) = mEnorm;

	mMix.set_parameters ( vec_2 ( 0.5, 0.5 ), mComs );

	double tolerance = 0.1;

	vec tmu = 0.5 * eN->mean() + 0.5 * mu0;
	check_mean ( mMix, mu0, N, tmu, tolerance );

	mat observedR ( "1.27572 0.778247; 0.778247 3.33129" );
	check_covariance( mMix, mu0, N, observedR, tolerance);
}

static void check_mean(mmix &mMix, const vec &mu0, int nsamples, const vec &mean, double tolerance) {
	int tc = 0;
	Array<vec> actual(CurrentContext::max_trial_count);
	do {
		mat smp = mMix.samplecond_m ( mu0, nsamples );
		vec emu = smp * ones ( nsamples ) / nsamples ;
		actual( tc ) = emu;
		++tc;
	} while ( ( tc < CurrentContext::max_trial_count ) &&
		  !UnitTest::AreClose ( mean, actual( tc - 1 ), tolerance ) );
	if ( ( tc == CurrentContext::max_trial_count ) &&
	     ( !UnitTest::AreClose ( mean, actual( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) {
		UnitTest::MemoryOutStream stream;
		UnitTest::TestDetails details(*UnitTest::CurrentTest::Details(), __LINE__);
		stream << "Expected " << mean << " +/- " << tolerance << " but was " << actual;

		UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() );
	}
}

static void check_covariance(mmix &mMix, const vec &mu0, int nsamples, const mat &R, double tolerance) {
	int tc = 0;
	Array<mat> actual(CurrentContext::max_trial_count);
	do {
		mat smp = mMix.samplecond_m ( mu0, nsamples );
		vec emu = smp * ones ( nsamples ) / nsamples ;
		mat er = ( smp * smp.T() ) / nsamples - outer_product ( emu, emu );
		actual( tc ) = er;
		++tc;
	} while ( ( tc < CurrentContext::max_trial_count ) &&
		  !UnitTest::AreClose ( R, actual( tc - 1 ), tolerance ) );
	if ( ( tc == CurrentContext::max_trial_count ) &&
	     ( !UnitTest::AreClose ( R, actual( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) {
		UnitTest::MemoryOutStream stream;
		UnitTest::TestDetails details(*UnitTest::CurrentTest::Details(), __LINE__);
		stream << "Expected " << R << " +/- " << tolerance << " but was " << actual;

		UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() );
       }
}
