#include "../bdm/math/square_mat.h"
#include "../bdm/math/chmat.h"
#include "base/user_info.h"
#include "../square_mat_point.h"
#include "UnitTest++.h"
#include "TestReporterStdout.h"
#include <iostream>
#include <iomanip>
#include <stdlib.h>
#include <string.h>

using std::cout;
using std::cerr;
using std::endl;

using bdm::fsqmat;
using bdm::chmat;
using bdm::ldmat;
using bdm::shared_ptr;
using bdm::UIFile;
using bdm::UI;

const char *agenda_file_name = "agenda.cfg";
double epsilon = 0.00001;
bool fast = false;

namespace UnitTest {

// can't include mat_checks.h because CheckClose is different in this file
extern bool AreClose ( const itpp::vec &expected, const itpp::vec &actual,
                       double tolerance );

extern bool AreClose ( const itpp::mat &expected, const itpp::mat &actual,
                       double tolerance );

void CheckClose ( TestResults &results, const itpp::mat &expected,
                  const itpp::mat &actual, double tolerance,
                  TestDetails const& details ) {
	if ( !AreClose ( expected, actual, tolerance ) ) {
		MemoryOutStream stream;
		stream << "failed at " << expected.rows()
		<< " x " << expected.cols();

		results.OnTestFailure ( details, stream.GetText() );
	}
}

}

typedef void ( *FTestMatrix ) ( int, square_mat_point * );

template<typename TMatrix>
void test_matrix ( int index, square_mat_point *point ) {
	Real_Timer tt;

	cout << "agenda[" << index << "]:" << endl;
	mat A = point->get_matrix();
	int sz = A.rows();
	CHECK_EQUAL ( A.cols(), sz );

	tt.tic();
	TMatrix sq_mat ( A );
	double elapsed = tt.toc();
	cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;

	tt.tic();
	mat res = sq_mat.to_mat();
	elapsed = tt.toc();

	if ( !fast ) {
		CHECK_CLOSE ( A, res, epsilon );
	}

	cout << "to_mat: " << elapsed << " s" << endl;

	vec v = point->get_vector();
	double w = point->get_scalar();
	TMatrix sq_mat2 = sq_mat;

	tt.tic();
	sq_mat2.opupdt ( v, w );
	elapsed = tt.toc();

	if ( !fast ) {
		mat expA = A + w * outer_product ( v, v );
		CHECK_CLOSE ( expA, sq_mat2.to_mat(), epsilon );
	}

	cout << "opupdt: " << elapsed << " s" << endl;

	TMatrix invmat ( sz );

	tt.tic();
	sq_mat.inv ( invmat );
	elapsed = tt.toc();

	mat invA;
	if ( !fast ) {
		invA = inv ( A );
		CHECK_CLOSE ( invA, invmat.to_mat(), epsilon );
	}

	cout << "inv: " << elapsed << " s" << endl;

	tt.tic();
	double ld = sq_mat.logdet();
	elapsed = tt.toc();

	if ( !fast ) {
		double d = det ( A );
		CHECK_CLOSE ( log ( d ), ld, epsilon );
	}

	cout << "logdet: " << elapsed << " s" << endl;

	tt.tic();
	double q = sq_mat.qform ( ones ( sz ) );
	elapsed = tt.toc();

	if ( !fast ) {
		CHECK_CLOSE ( sumsum ( A ), q, epsilon );
	}

	cout << "qform(1): " << elapsed << " s" << endl;

	tt.tic();
	q = sq_mat.qform ( v );
	elapsed = tt.toc();

	if ( !fast ) {
		double r = ( A * v ) * v;
		CHECK_CLOSE ( r, q, epsilon );
	}

	cout << "qform(v): " << elapsed << " s" << endl;

	tt.tic();
	q = sq_mat.invqform ( v );
	elapsed = tt.toc();

	if ( !fast ) {
		double r = ( invA * v ) * v;
		CHECK_CLOSE ( r, q, epsilon );
	}

	cout << "invqform: " << elapsed << " s" << endl;

	TMatrix twice = sq_mat;

	tt.tic();
	twice += sq_mat;
	elapsed = tt.toc();

	if ( !fast ) {
		res = 2 * A;
		CHECK_CLOSE ( res, twice.to_mat(), epsilon );
	}

	cout << "+=: " << elapsed << " s" << endl;

	sq_mat2 = sq_mat;

	tt.tic();
	sq_mat2.mult_sym ( A );
	elapsed = tt.toc();

	if ( !fast ) {
		res = ( A * A ) * A.T();
		CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );
	}

	cout << "^2: " << elapsed << " s" << endl;
}

void test_agenda ( FTestMatrix test ) {
	UIFile fag ( agenda_file_name );
	Array<shared_ptr<square_mat_point> > mag;
	UI::get ( mag, fag, "agenda", UI::compulsory );
	int sz = mag.size();
	CHECK ( sz > 0 );
	for ( int i = 0; i < sz; ++i ) {
		test ( i, mag ( i ).get() );
	}
}

SUITE ( ldmat ) {
	TEST ( agenda ) {
		test_agenda ( test_matrix<ldmat> );
	}
}

SUITE ( fsqmat ) {
	TEST ( agenda ) {
		test_agenda ( test_matrix<fsqmat> );
	}
}

SUITE ( chmat ) {
	TEST ( agenda ) {
		test_agenda ( test_matrix<chmat> );
	}
}

int main ( int argc, char const *argv[] ) {
	bool unknown = false;
	int update_next = 0; // 1 suite, 2 epsilon, 3 agenda file
	const char *suite = "ldmat";
	const char **param = argv + 1;
	while ( *param && !unknown ) {
		if ( update_next ) {
			if ( update_next == 1 ) {
				suite = *param;
			} else if ( update_next == 2 ) {
				double eps = atof ( *param );
				if ( eps > 0 ) {
					epsilon = eps;
				} else {
					cerr << "invalid epsilon value ignored" << endl;
				}
			} else {
				agenda_file_name = *param;
			}

			update_next = 0;
		} else {
			if ( !strcmp ( *param, "-a" ) ) {
				update_next = 3;
			} else if ( !strcmp ( *param, "-c" ) ) {
				update_next = 1;
			} else if ( !strcmp ( *param, "-e" ) ) {
				update_next = 2;
			} else if ( !strcmp ( *param, "-f" ) ) {
				fast = true;
			} else {
				unknown = true;
			}
		}

		++param;
	}

	if ( unknown || update_next ) {
		cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -a agenda_input.cfg ] [ -c class ]" << endl;
	} else {
		UnitTest::TestReporterStdout reporter;
		UnitTest::TestRunner runner ( reporter );
		return runner.RunTestsIf ( UnitTest::Test::GetTestList(),
		                           suite,
		                           UnitTest::True(),
		                           0 );
	}
}
