root/library/tests/square_mat_stress.cpp @ 468

Revision 468, 5.4 kB (checked in by vbarta, 15 years ago)

added supplementary random test data to agenda

Line 
1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
3#include "base/user_info.h"
4#include "square_mat_point.h"
5#include "UnitTest++.h"
6#include "TestReporterStdout.h"
7#include <iostream>
8#include <iomanip>
9#include <stdlib.h>
10#include <string.h>
11
12using std::cout;
13using std::cerr;
14using std::endl;
15
16using bdm::UIFile;
17using bdm::UI;
18
19const char *agenda_file_name = "agenda.cfg";
20double epsilon = 0.00001;
21bool fast = false;
22
23namespace bdm {
24UIREGISTER(square_mat_point);
25}
26
27namespace UnitTest
28{
29
30// can't include mat_checks.h because CheckClose is different in this file
31extern bool AreClose(const itpp::vec &expected, const itpp::vec &actual,
32                     double tolerance);
33
34extern bool AreClose(const itpp::mat &expected, const itpp::mat &actual,
35                     double tolerance);
36
37void CheckClose(TestResults &results, const itpp::mat &expected,
38                const itpp::mat &actual, double tolerance,
39                TestDetails const& details) {
40    if (!AreClose(expected, actual, tolerance)) { 
41        MemoryOutStream stream;
42        stream << "failed at " << expected.rows()
43               << " x " << expected.cols();
44
45        results.OnTestFailure(details, stream.GetText());
46    }
47}
48
49}
50
51typedef void (*FTestMatrix)(int, square_mat_point *);
52
53template<typename TMatrix>
54void test_matrix(int index, square_mat_point *point) {
55    Real_Timer tt;
56       
57    cout << "agenda[" << index << "]:" << endl;
58    mat A = point->get_matrix();
59    int sz = A.rows();
60    CHECK_EQUAL(A.cols(), sz);
61
62    tt.tic();
63    TMatrix sqmat(A);
64    double elapsed = tt.toc();
65    cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;
66
67    tt.tic();
68    mat res = sqmat.to_mat();
69    elapsed = tt.toc();
70
71    if (!fast) {
72        CHECK_CLOSE(A, res, epsilon);
73    }
74
75    cout << "to_mat: " << elapsed << " s" << endl;
76
77    vec v = point->get_vector();
78    double w = point->get_scalar();
79    TMatrix sqmat2 = sqmat;
80       
81    tt.tic();
82    sqmat2.opupdt(v, w);
83    elapsed = tt.toc();
84
85    if (!fast) {
86        mat expA = A + w * outer_product(v, v);
87        CHECK_CLOSE(expA, sqmat2.to_mat(), epsilon);
88    }
89
90    cout << "opupdt: " << elapsed << " s" << endl;
91
92    TMatrix invmat(sz);
93
94    tt.tic();
95    sqmat.inv(invmat);
96    elapsed = tt.toc();
97
98    mat invA;
99    if (!fast) {
100        invA = inv(A);
101        CHECK_CLOSE(invA, invmat.to_mat(), epsilon);
102    }
103
104    cout << "inv: " << elapsed << " s" << endl;
105
106    tt.tic();
107    double ld = sqmat.logdet();
108    elapsed = tt.toc();
109
110    if (!fast) {
111        double d = det(A);
112        CHECK_CLOSE(log(d), ld, epsilon);
113    }
114
115    cout << "logdet: " << elapsed << " s" << endl;
116
117    tt.tic();
118    double q = sqmat.qform(ones(sz));
119    elapsed = tt.toc();
120
121    if (!fast) {
122        CHECK_CLOSE(sumsum(A), q, epsilon);
123    }
124
125    cout << "qform(1): " << elapsed << " s" << endl;
126
127    tt.tic();
128    q = sqmat.qform(v);
129    elapsed = tt.toc();
130
131    if (!fast) {
132        double r = (A * v) * v;
133        CHECK_CLOSE(r, q, epsilon);
134    }
135
136    cout << "qform(v): " << elapsed << " s" << endl;
137
138    tt.tic();
139    q = sqmat.invqform(v);
140    elapsed = tt.toc();
141
142    if (!fast) {
143        double r = (invA * v) * v;
144        CHECK_CLOSE(r, q, epsilon);
145    }
146
147    cout << "invqform: " << elapsed << " s" << endl;
148
149    TMatrix twice = sqmat;
150
151    tt.tic();
152    twice += sqmat;
153    elapsed = tt.toc();
154
155    if (!fast) {
156        res = 2 * A;
157        CHECK_CLOSE(res, twice.to_mat(), epsilon);
158    }
159
160    cout << "+=: " << elapsed << " s" << endl;
161
162    sqmat2 = sqmat;
163
164    tt.tic();
165    sqmat2.mult_sym(A);
166    elapsed = tt.toc();
167
168    if (!fast) {
169        res = (A * A) * A.T();
170        CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);
171    }
172
173    cout << "^2: " << elapsed << " s" << endl;
174}
175
176void test_agenda(FTestMatrix test) {
177    UIFile fag(agenda_file_name);
178    Array<square_mat_point *> mag;
179    UI::get(mag, fag, "agenda");
180    int sz = mag.size();
181    CHECK(sz > 0);
182    for (int i = 0; i < sz; ++i) {
183        test(i, mag(i));
184    }
185
186    for (int i = 0; i < sz; ++i) {
187        square_mat_point *p = mag(i);
188        mag(i) = 0;
189        delete p;
190    }
191}
192
193SUITE(ldmat) {
194    TEST(agenda) {
195        test_agenda(test_matrix<ldmat>);
196    }
197}
198
199SUITE(fsqmat) {
200    TEST(agenda) {
201        test_agenda(test_matrix<fsqmat>);
202    }
203}
204
205SUITE(chmat) {
206    TEST(agenda) {
207        test_agenda(test_matrix<chmat>);
208    }
209}
210
211int main(int argc, char const *argv[]) {
212    bool unknown = false;
213    int update_next = 0; // 1 suite, 2 epsilon, 3 agenda file
214    const char *suite = "ldmat";
215    const char **param = argv + 1;
216    while (*param && !unknown) {
217        if (update_next) {
218            if (update_next == 1) {
219                suite = *param;
220            } else if (update_next == 2) {
221                double eps = atof(*param);
222                if (eps > 0) {
223                    epsilon = eps;
224                } else {
225                    cerr << "invalid epsilon value ignored" << endl;
226                }
227            } else {
228                agenda_file_name = *param;
229            }
230
231            update_next = 0;
232        } else {
233            if (!strcmp(*param, "-a")) {
234                update_next = 3;
235            } else if (!strcmp(*param, "-c")) {
236                update_next = 1;
237            } else if (!strcmp(*param, "-e")) {
238                update_next = 2;
239            } else if (!strcmp(*param, "-f")) {
240                fast = true;
241            } else {
242                unknown = true;
243            }
244        }
245
246        ++param;
247    }
248
249    if (unknown || update_next) {
250        cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -a agenda_input.cfg ] [ -c class ]" << endl;
251    } else {
252        UnitTest::TestReporterStdout reporter;
253        UnitTest::TestRunner runner(reporter);
254        return runner.RunTestsIf(UnitTest::Test::GetTestList(),
255            suite,
256            UnitTest::True(),
257            0);
258    }
259}
Note: See TracBrowser for help on using the browser.