#include "../bdm/math/square_mat.h"
#include "../bdm/math/chmat.h"
#include "mat_checks.h"
#include "UnitTest++.h"
#include <math.h>

const double epsilon = 0.00001;

bool fast = false;

namespace UnitTest
{

inline void CheckClose(TestResults &results, const itpp::mat &expected,
		       const itpp::mat &actual, double tolerance,
		       TestDetails const& details) {
    if (!AreClose(expected, actual, tolerance)) { 
        MemoryOutStream stream;
        stream << "Expected " << expected << " +/- " << tolerance << " but was " << actual;

        results.OnTestFailure(details, stream.GetText());
    }
}

}

template<typename TMatrix>
void test_square_matrix_minimum(double epsilon) {
    int sz = 3;
    mat A0 = randu(sz, sz);
    mat A = A0 * A0.T();

    TMatrix sqmat(A);
    CHECK_EQUAL(sz, sqmat.rows());
    CHECK_EQUAL(sz, sqmat.cols());

    mat res = sqmat.to_mat();
    CHECK_CLOSE(A, res, epsilon);

    vec v = randu(sz);
    double w = randu();
    TMatrix sqmat2 = sqmat;	
    sqmat2.opupdt(v, w);

    res = A + w * outer_product(v, v);
    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);

    TMatrix invmat(sz);
    sqmat.inv(invmat);
    mat invA = inv(A);
    CHECK_CLOSE(invA, invmat.to_mat(), epsilon);

    double d = det(A);
    CHECK_CLOSE(log(d), sqmat.logdet(), epsilon);

    double q = sqmat.qform(ones(sz));
    CHECK_CLOSE(sumsum(A), q, epsilon);

    q = sqmat.qform(v);
    double r = (A * v) * v;
    CHECK_CLOSE(r, q, epsilon);

    sqmat2 = sqmat;
    sqmat2.clear();
    CHECK_EQUAL(0, sqmat2.qform(ones(sz)));
}

template<typename TMatrix>
void test_square_matrix(double epsilon) {
    test_square_matrix_minimum<TMatrix>(epsilon);

    int sz = 3;
    mat A0 = randu(sz, sz);
    mat A = A0 * A0.T();
	
    TMatrix sqmat(A);
    TMatrix twice = sqmat;
    twice += sqmat;
    mat res(2 * A);
    CHECK_CLOSE(res, twice.to_mat(), epsilon);

    twice = sqmat;
    twice *= 2;
    CHECK_CLOSE(res, twice.to_mat(), epsilon);

    TMatrix sqmat2 = sqmat;
    mat B = randu(sz, sz);
    sqmat2.mult_sym(B);
    res = (B * A) * B.T();
    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);

    sqmat2 = sqmat;
    sqmat2.mult_sym_t(B);
    res = (B.T() * A) * B;
    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);
}

TEST(test_ldmat) {
    test_square_matrix<ldmat>(epsilon);
}

TEST(test_fsqmat) {
    test_square_matrix<fsqmat>(epsilon);
}

TEST(test_chmat) {
    test_square_matrix<chmat>(epsilon);
}
