root/library/tests/square_mat_test.cpp @ 438

Revision 438, 2.4 kB (checked in by vbarta, 15 years ago)

testing matrix inversion

Line 
1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
3#include "mat_checks.h"
4#include "UnitTest++.h"
5#include <math.h>
6
7const double epsilon = 0.00001;
8
9bool fast = false;
10
11namespace UnitTest
12{
13
14inline void CheckClose(TestResults &results, const itpp::mat &expected,
15                       const itpp::mat &actual, double tolerance,
16                       TestDetails const& details) {
17    if (!AreClose(expected, actual, tolerance)) { 
18        MemoryOutStream stream;
19        stream << "Expected " << expected << " +/- " << tolerance << " but was " << actual;
20
21        results.OnTestFailure(details, stream.GetText());
22    }
23}
24
25}
26
27template<typename TMatrix>
28void test_square_matrix_minimum(double epsilon) {
29    int sz = 3;
30    mat A0 = randu(sz, sz);
31    mat A = A0 * A0.T();
32
33    TMatrix sqmat(A);
34    CHECK_EQUAL(sz, sqmat.rows());
35    CHECK_EQUAL(sz, sqmat.cols());
36
37    mat res = sqmat.to_mat();
38    CHECK_CLOSE(A, res, epsilon);
39
40    vec v = randu(sz);
41    double w = randu();
42    TMatrix sqmat2 = sqmat;     
43    sqmat2.opupdt(v, w);
44
45    res = A + w * outer_product(v, v);
46    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);
47
48    TMatrix invmat(sz);
49    sqmat.inv(invmat);
50    mat invA = inv(A);
51    CHECK_CLOSE(invA, invmat.to_mat(), epsilon);
52
53    double d = det(A);
54    CHECK_CLOSE(log(d), sqmat.logdet(), epsilon);
55
56    double q = sqmat.qform(ones(sz));
57    CHECK_CLOSE(sumsum(A), q, epsilon);
58
59    q = sqmat.qform(v);
60    double r = (A * v) * v;
61    CHECK_CLOSE(r, q, epsilon);
62
63    sqmat2 = sqmat;
64    sqmat2.clear();
65    CHECK_EQUAL(0, sqmat2.qform(ones(sz)));
66}
67
68template<typename TMatrix>
69void test_square_matrix(double epsilon) {
70    test_square_matrix_minimum<TMatrix>(epsilon);
71
72    int sz = 3;
73    mat A0 = randu(sz, sz);
74    mat A = A0 * A0.T();
75       
76    TMatrix sqmat(A);
77    TMatrix twice = sqmat;
78    twice += sqmat;
79    mat res(2 * A);
80    CHECK_CLOSE(res, twice.to_mat(), epsilon);
81
82    twice = sqmat;
83    twice *= 2;
84    CHECK_CLOSE(res, twice.to_mat(), epsilon);
85
86    TMatrix sqmat2 = sqmat;
87    mat B = randu(sz, sz);
88    sqmat2.mult_sym(B);
89    res = (B * A) * B.T();
90    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);
91
92    sqmat2 = sqmat;
93    sqmat2.mult_sym_t(B);
94    res = (B.T() * A) * B;
95    CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);
96}
97
98TEST(test_ldmat) {
99    test_square_matrix<ldmat>(epsilon);
100}
101
102TEST(test_fsqmat) {
103    test_square_matrix<fsqmat>(epsilon);
104}
105
106TEST(test_chmat) {
107    test_square_matrix<chmat>(epsilon);
108}
Note: See TracBrowser for help on using the browser.