root/library/tests/square_mat_stress.cpp @ 457

Revision 456, 3.2 kB (checked in by vbarta, 15 years ago)

custom test location for harness tests (extended UnitTest?++), configurable tolerance - all tests pass (most of the time)

Line 
1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
3#include "UnitTest++.h"
4#include "TestReporterStdout.h"
5#include <iostream>
6#include <iomanip>
7#include <stdlib.h>
8#include <string.h>
9
10using std::cout;
11using std::cerr;
12using std::endl;
13
14double epsilon = 0.00001;
15
16bool fast = false;
17
18namespace UnitTest
19{
20
21// can't include mat_checks.h because CheckClose is different in this file
22extern bool AreClose(const itpp::vec &expected, const itpp::vec &actual,
23                     double tolerance);
24
25extern bool AreClose(const itpp::mat &expected, const itpp::mat &actual,
26                     double tolerance);
27
28void CheckClose(TestResults &results, const itpp::mat &expected,
29                const itpp::mat &actual, double tolerance,
30                TestDetails const& details) {
31    if (!AreClose(expected, actual, tolerance)) { 
32        MemoryOutStream stream;
33        stream << "failed at " << expected.rows()
34               << " x " << expected.cols();
35
36        results.OnTestFailure(details, stream.GetText());
37    }
38}
39
40}
41
42template<typename TMatrix>
43void test_until_overflow() {
44    Real_Timer tt;
45    int sz = 7;
46    while (true) {
47        mat A0 = randu(sz, sz);
48        mat A = A0 * A0.T();
49       
50        tt.tic();
51        TMatrix sqmat(A);
52        double elapsed = tt.toc();
53        cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;
54
55        tt.tic();
56        mat res = sqmat.to_mat();
57        elapsed = tt.toc();
58
59        if (!fast) {
60            CHECK_CLOSE(A, res, epsilon);
61        }
62
63        cout << "to_mat: " << elapsed << " s" << endl;
64
65        vec v = randu(sz);
66        double w = randu();
67        TMatrix sqmat2 = sqmat;
68       
69        tt.tic();
70        sqmat2.opupdt(v, w);
71        elapsed = tt.toc();
72
73        if (!fast) {
74            mat expA = A + w * outer_product(v, v);
75            CHECK_CLOSE(expA, sqmat2.to_mat(), epsilon);
76        }
77
78        cout << "opupdt: " << elapsed << " s" << endl;
79
80        TMatrix invmat(sz);
81
82        tt.tic();
83        sqmat.inv(invmat);
84        elapsed = tt.toc();
85
86        if (!fast) {
87            mat invA = inv(A);
88            CHECK_CLOSE(invA, invmat.to_mat(), epsilon);
89        }
90
91        cout << "inv: " << elapsed << " s" << endl;
92
93        sz *= 7;
94    }
95}
96
97SUITE(ldmat) {
98    TEST(cycle) {
99        test_until_overflow<ldmat>();
100    }
101}
102
103SUITE(fsqmat) {
104    TEST(cycle) {
105        test_until_overflow<fsqmat>();
106    }
107}
108
109SUITE(chmat) {
110    TEST(cycle) {
111        test_until_overflow<chmat>();
112    }
113}
114
115int main(int argc, char const *argv[]) {
116    bool unknown = false;
117    int update_next = 0; // 1 suite, 2 epsilon
118    const char *suite = "ldmat";
119    const char **param = argv + 1;
120    while (*param && !unknown) {
121        if (update_next) {
122            if (update_next == 1) {
123                suite = *param;
124            } else {
125                double eps = atof(*param);
126                if (eps > 0) {
127                    epsilon = eps;
128                } else {
129                    cerr << "invalid epsilon value ignored" << endl;
130                }
131            }
132
133            update_next = 0;
134        } else {
135            if (!strcmp(*param, "-c")) {
136                update_next = 1;
137            } else if (!strcmp(*param, "-e")) {
138                update_next = 2;
139            } else if (!strcmp(*param, "-f")) {
140                fast = true;
141            } else {
142                unknown = true;
143            }
144        }
145
146        ++param;
147    }
148
149    if (unknown || update_next) {
150        cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -c class ]" << endl;
151    } else {
152        UnitTest::TestReporterStdout reporter;
153        UnitTest::TestRunner runner(reporter);
154        return runner.RunTestsIf(UnitTest::Test::GetTestList(),
155            suite,
156            UnitTest::True(),
157            0);
158    }
159}
Note: See TracBrowser for help on using the browser.