#include "../bdm/math/square_mat.h"
#include "../bdm/math/chmat.h"
#include "mat_checks.h"
#include "UnitTest++.h"
#include <math.h>

const double epsilon = 0.00001;

template<typename TMatrix>
void test_square_matrix(double epsilon) {
    int sz = 3;
    mat A0 = randu(sz, sz);
    mat A = A0 * A0.T();
	
    TMatrix sqmat(A);
    CHECK_EQUAL(sz, sqmat.rows());
    CHECK_EQUAL(sz, sqmat.cols());

    mat res = sqmat.to_mat();
    CHECK_CLOSE(A, res, epsilon);

    vec v = randu(sz);
    double w = randu();
    TMatrix sqmat2 = sqmat;	
    sqmat2.opupdt(v, w);

    res = A + w * outer_product(v, v);
    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);

    TMatrix invmat(sz);
    sqmat.inv(invmat);
    mat invA = inv(A);
    CHECK_CLOSE(invA, invmat.to_mat(), epsilon);

    double d = det(A);
    CHECK_CLOSE(log(d), sqmat.logdet(), epsilon);

    double q = sqmat.qform(ones(sz));
    CHECK_CLOSE(sumsum(A), q, epsilon);

    q = sqmat.qform(v);
    double r = (A * v) * v;
    CHECK_CLOSE(r, q, epsilon);

    q = sqmat.invqform(v);
    r = (invA * v) * v;
    CHECK_CLOSE(r, q, epsilon);

    sqmat2 = sqmat;
    sqmat2.clear();
    CHECK_EQUAL(0, sqmat2.qform(ones(sz)));

    TMatrix twice = sqmat;
    twice += sqmat;
    res = 2 * A;
    CHECK_CLOSE(res, twice.to_mat(), epsilon);

    twice = sqmat;
    twice *= 2;
    CHECK_CLOSE(res, twice.to_mat(), epsilon);

    sqmat2 = sqmat;
    mat B = randu(sz, sz);
    sqmat2.mult_sym(B);
    res = (B * A) * B.T();
    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);

    mat C = randu(sz, sz-1);
	 TMatrix CAC(sz-1);
    sqmat.mult_sym_t(C,CAC);
    res = (C.T() * A) * C;
    CHECK_CLOSE(res, CAC.to_mat(), epsilon);

    sqmat2 = sqmat;
    sqmat2.mult_sym_t(B);
    res = (B.T() * A) * B;
    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);
}

TEST(test_ldmat) {
    test_square_matrix<ldmat>(epsilon);
}

TEST(test_fsqmat) {
    test_square_matrix<fsqmat>(epsilon);
}

TEST(test_chmat) {
    test_square_matrix<chmat>(epsilon);
}
