#include "../bdm/math/square_mat.h"
#include "../bdm/math/chmat.h"
#include "itpp_ext.h"
#include "../mat_checks.h"
#include "UnitTest++.h"
#include <cmath>

//! Absolute tolerance used by the CHECK_CLOSE comparisons in the tests below.
const double epsilon = 0.00001;

using namespace itpp;

using bdm::fsqmat;
using bdm::chmat;
using bdm::ldmat;

/*!
 * \brief Exercise the common square-matrix interface of \c TMatrix against
 *        dense itpp reference computations.
 *
 * Covered operations: construction from a \c mat, rows()/cols(), to_mat(),
 * opupdt(), inv(), logdet(), qform(), invqform(), clear(), operator+=,
 * operator*=, mult_sym(), both forms of mult_sym_t(), and the permuting
 * constructor TMatrix ( T, perm ).
 *
 * \tparam TMatrix matrix type under test (instantiated below with ldmat,
 *         fsqmat and chmat). It must be constructible from a \c mat and
 *         from a dimension \c int.
 * \param epsilon absolute tolerance for the element-wise comparisons.
 */
template<class TMatrix>
void test_square_matrix ( double epsilon ) {
	const int sz = 3;

	// A0 * A0' alone is only positive SEMI-definite and can be arbitrarily
	// ill-conditioned for a random A0. Adding the identity makes every
	// eigenvalue >= 1, so the factorised representations can be built and
	// inv(A) / log(det(A)) stay well inside the absolute tolerance.
	mat A0 = randu ( sz, sz );
	mat A = A0 * A0.T() + eye ( sz );

	// ----------- SIZES ---------
	TMatrix sq_mat ( A );
	CHECK_EQUAL ( sz, sq_mat.rows() );
	CHECK_EQUAL ( sz, sq_mat.cols() );

	// ----------- FULL MAT ---------
	mat res = sq_mat.to_mat();
	CHECK_CLOSE ( A, res, epsilon );

	// ----------- OUTER PRODUCT UPDATE ---------
	// opupdt ( v, w ) must yield A + w * v * v'.
	vec v = randu ( sz );
	double w = randu();
	TMatrix sq_mat2 = sq_mat;
	sq_mat2.opupdt ( v, w );
	res = A + w * outer_product ( v, v );
	CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );

	// ----------- INVERSION ---------
	TMatrix invmat ( sz );
	sq_mat.inv ( invmat );
	mat invA = inv ( A );
	CHECK_CLOSE ( invA, invmat.to_mat(), epsilon );

	// ----------- DETERMINANT ---------
	double d = det ( A );
	CHECK_CLOSE ( std::log ( d ), sq_mat.logdet(), epsilon );

	// ----------- QUADRATIC FORM ---------
	// With x = ones, x' A x is the sum of all entries of A.
	double q = sq_mat.qform ( ones ( sz ) );
	CHECK_CLOSE ( sumsum ( A ), q, epsilon );

	// itpp's vec * vec is the inner product, so ( A v ) * v == v' A v.
	q = sq_mat.qform ( v );
	double r = ( A * v ) * v;
	CHECK_CLOSE ( r, q, epsilon );

	q = sq_mat.invqform ( v );
	r = ( invA * v ) * v;
	CHECK_CLOSE ( r, q, epsilon );

	// After clear() the quadratic form must be exactly zero.
	sq_mat2 = sq_mat;
	sq_mat2.clear();
	CHECK_EQUAL ( 0.0, sq_mat2.qform ( ones ( sz ) ) );

	// ----------- + operator ---------
	TMatrix twice = sq_mat;
	twice += sq_mat;
	res = 2 * A;
	CHECK_CLOSE ( res, twice.to_mat(), epsilon );

	// ----------- * operator ---------
	// Scaling by 2 must match the same 2 * A reference as above.
	twice = sq_mat;
	twice *= 2;
	CHECK_CLOSE ( res, twice.to_mat(), epsilon );

	// ----------- MULTIPLICATION ---------
	// mult_sym ( B ): in place, A -> B A B'.
	sq_mat2 = sq_mat;
	mat B = randu ( sz, sz );
	sq_mat2.mult_sym ( B );
	res = ( B * A ) * B.T();
	CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );

	// mult_sym_t ( C, out ): out = C' A C. C is non-square here, so the
	// result is ( sz-1 ) x ( sz-1 ).
	mat C = randu ( sz, sz - 1 );
	TMatrix CAC ( sz - 1 );
	sq_mat.mult_sym_t ( C, CAC );
	res = ( C.T() * A ) * C;
	CHECK_CLOSE ( res, CAC.to_mat(), epsilon );

	// mult_sym_t ( B ): in place, A -> B' A B.
	sq_mat2 = sq_mat;
	sq_mat2.mult_sym_t ( B );
	res = ( B.T() * A ) * B;
	CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );

	// ----------- PERMUTATION ---------
	// TMatrix ( T, perm ) must equal M with both rows and columns reordered
	// by the same permutation.
	mat M1 = randu ( sz, sz );
	mat M = M1 * M1.T() + eye ( sz ); // positive definite, as for A above
	vec perm_v_rand = randu ( sz );
	// Sorting random keys yields a random permutation of the indices 0..sz-1.
	ivec perm_v_ids = sort_index ( perm_v_rand );
	mat Mperm_c = M.get_cols ( perm_v_ids );
	mat Mperm = Mperm_c.get_rows ( perm_v_ids );
	TMatrix T ( M );
	TMatrix Tperm ( T, perm_v_ids );
	CHECK_CLOSE ( Tperm.to_mat(), Mperm, epsilon );
}

// TMatrix occurs only inside the body of test_square_matrix, so it cannot be
// deduced from the call arguments and must be given explicitly.
TEST ( ldmat_test ) {
	test_square_matrix<ldmat> ( epsilon );
}

TEST ( fsqmat_test ) {
	test_square_matrix<fsqmat> ( epsilon );
}

TEST ( chmat_test ) {
	test_square_matrix<chmat> ( epsilon );
}