#include "../bdm/math/square_mat.h"
#include "../bdm/math/chmat.h"
#include "itpp_ext.h"
#include "../mat_checks.h"
#include "UnitTest++.h"
#include <math.h>

const double epsilon = 0.00001;

using namespace itpp;

using bdm::fsqmat;
using bdm::chmat;
using bdm::ldmat;

template<typename TMatrix>
void test_square_matrix ( double epsilon ) {
    int sz = 3;
    mat A0 = randu ( sz, sz );
    mat A = A0 * A0.T();

    // ----------- SIZES ---------
    TMatrix sq_mat ( A );
    CHECK_EQUAL ( sz, sq_mat.rows() );
    CHECK_EQUAL ( sz, sq_mat.cols() );

    // ----------- FULL MAT ---------
    mat res = sq_mat.to_mat();
    CHECK_CLOSE ( A, res, epsilon );

    // ----------- OUTER PRODUCT UPDATE ---------
    vec v = randu ( sz );
    double w = randu();
    TMatrix sq_mat2 = sq_mat;
    sq_mat2.opupdt ( v, w );

    res = A + w * outer_product ( v, v );
    CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );

    // ----------- INVERSION ---------
    TMatrix invmat ( sz );
    sq_mat.inv ( invmat );
    mat invA = inv ( A );
    CHECK_CLOSE ( invA, invmat.to_mat(), epsilon );

    // ----------- DETERMINANT ---------
    double d = det ( A );
    CHECK_CLOSE ( log ( d ), sq_mat.logdet(), epsilon );

    // ----------- QUADRATIC FORM ---------
    double q = sq_mat.qform ( ones ( sz ) );
    CHECK_CLOSE ( sumsum ( A ), q, epsilon );

    q = sq_mat.qform ( v );
    double r = ( A * v ) * v;
    CHECK_CLOSE ( r, q, epsilon );

    q = sq_mat.invqform ( v );
    r = ( invA * v ) * v;
    CHECK_CLOSE ( r, q, epsilon );

    sq_mat2 = sq_mat;
    sq_mat2.clear();
    CHECK_EQUAL ( 0, sq_mat2.qform ( ones ( sz ) ) );

    // ----------- + operator ---------
    TMatrix twice = sq_mat;
    twice += sq_mat;
    res = 2 * A;
    CHECK_CLOSE ( res, twice.to_mat(), epsilon );

    // ----------- * operator ---------
    twice = sq_mat;
    twice *= 2;
    CHECK_CLOSE ( res, twice.to_mat(), epsilon );

    // ----------- MULTIPLICATION ---------
    sq_mat2 = sq_mat;
    mat B = randu ( sz, sz );
    sq_mat2.mult_sym ( B );
    res = ( B * A ) * B.T();
    CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );

    mat C = randu ( sz, sz - 1 );
    TMatrix CAC ( sz - 1 );
    sq_mat.mult_sym_t ( C, CAC );
    res = ( C.T() * A ) * C;
    CHECK_CLOSE ( res, CAC.to_mat(), epsilon );

    sq_mat2 = sq_mat;
    sq_mat2.mult_sym_t ( B );
    res = ( B.T() * A ) * B;
    CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );

    // ----------- PERMUTATION ---------
    mat M1 = randu (sz,sz);
    mat M = M1*M1.T();
    vec perm_v_rand = randu(sz);
    ivec perm_v_ids  = sort_index(perm_v_rand);

    mat Mperm_c=M.get_cols(perm_v_ids);
    mat Mperm=Mperm_c.get_rows(perm_v_ids);

    TMatrix T(M);
    TMatrix Tperm(T,perm_v_ids);

    CHECK_CLOSE(Tperm.to_mat(), Mperm, epsilon);
}

TEST ( ldmat_test ) {
    test_square_matrix<ldmat> ( epsilon );
}

TEST ( fsqmat_test ) {
    test_square_matrix<fsqmat> ( epsilon );
}

TEST ( chmat_test ) {
    test_square_matrix<chmat> ( epsilon );
}
