#include "../bdm/math/square_mat.h"
#include "../bdm/math/chmat.h"
#include "UnitTest++.h"
#include "TestReporterStdout.h"
#include <iostream>
#include <iomanip>
#include <stdlib.h>
#include <string.h>

using std::cout;
using std::cerr;
using std::endl;

double epsilon = 0.00001;

bool fast = false;

namespace UnitTest
{

// can't include mat_checks.h because CheckClose is different in this file
extern bool AreClose(const itpp::vec &expected, const itpp::vec &actual,
		     double tolerance);

extern bool AreClose(const itpp::mat &expected, const itpp::mat &actual,
		     double tolerance);

void CheckClose(TestResults &results, const itpp::mat &expected,
		const itpp::mat &actual, double tolerance,
		TestDetails const& details) {
    if (!AreClose(expected, actual, tolerance)) { 
        MemoryOutStream stream;
        stream << "failed at " << expected.rows()
	       << " x " << expected.cols();

        results.OnTestFailure(details, stream.GetText());
    }
}

}

template<typename TMatrix>
void test_until_overflow() {
    Real_Timer tt;
    int sz = 7;
    while (true) {
	mat A0 = randu(sz, sz);
	mat A = A0 * A0.T();
	
	tt.tic();
	TMatrix sqmat(A);
	double elapsed = tt.toc();
	cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;

	tt.tic();
	mat res = sqmat.to_mat();
	elapsed = tt.toc();

	if (!fast) {
	    CHECK_CLOSE(A, res, epsilon);
	}

	cout << "to_mat: " << elapsed << " s" << endl;

	vec v = randu(sz);
	double w = randu();
	TMatrix sqmat2 = sqmat;
	
	tt.tic();
	sqmat2.opupdt(v, w);
	elapsed = tt.toc();

	if (!fast) {
	    mat expA = A + w * outer_product(v, v);
	    CHECK_CLOSE(expA, sqmat2.to_mat(), epsilon);
	}

	cout << "opupdt: " << elapsed << " s" << endl;

	TMatrix invmat(sz);

	tt.tic();
	sqmat.inv(invmat);
	elapsed = tt.toc();

	if (!fast) {
	    mat invA = inv(A);
	    CHECK_CLOSE(invA, invmat.to_mat(), epsilon);
	}

	cout << "inv: " << elapsed << " s" << endl;

	sz *= 7;
    }
}

SUITE(ldmat) {
    TEST(cycle) {
        test_until_overflow<ldmat>();
    }
}

SUITE(fsqmat) {
    TEST(cycle) {
        test_until_overflow<fsqmat>();
    }
}

SUITE(chmat) {
    TEST(cycle) {
        test_until_overflow<chmat>();
    }
}

int main(int argc, char const *argv[]) {
    bool unknown = false;
    int update_next = 0; // 1 suite, 2 epsilon
    const char *suite = "ldmat";
    const char **param = argv + 1;
    while (*param && !unknown) {
        if (update_next) {
	    if (update_next == 1) {
	        suite = *param;
	    } else {
	        double eps = atof(*param);
		if (eps > 0) {
		    epsilon = eps;
		} else {
		    cerr << "invalid epsilon value ignored" << endl;
		}
	    }

	    update_next = 0;
	} else {
	    if (!strcmp(*param, "-c")) {
	        update_next = 1;
	    } else if (!strcmp(*param, "-e")) {
	        update_next = 2;
	    } else if (!strcmp(*param, "-f")) {
		fast = true;
	    } else {
	        unknown = true;
	    }
	}

	++param;
    }

    if (unknown || update_next) {
        cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -c class ]" << endl;
    } else {
	UnitTest::TestReporterStdout reporter;
	UnitTest::TestRunner runner(reporter);
	return runner.RunTestsIf(UnitTest::Test::GetTestList(),
	    suite,
	    UnitTest::True(),
	    0);
    }
}
