root/library/tests/square_mat_stress.cpp @ 468

Revision 468, 5.4 kB (checked in by vbarta, 15 years ago)

added supplementary random test data to agenda

RevLine 
[426]1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
[467]3#include "base/user_info.h"
[468]4#include "square_mat_point.h"
[426]5#include "UnitTest++.h"
6#include "TestReporterStdout.h"
7#include <iostream>
8#include <iomanip>
9#include <stdlib.h>
10#include <string.h>
11
12using std::cout;
13using std::cerr;
14using std::endl;
15
[467]16using bdm::UIFile;
17using bdm::UI;
18
19const char *agenda_file_name = "agenda.cfg";
[426]20double epsilon = 0.00001;
21bool fast = false;
22
[468]23namespace bdm {
24UIREGISTER(square_mat_point);
25}
26
[426]27namespace UnitTest
28{
29
[456]30// can't include mat_checks.h because CheckClose is different in this file
31extern bool AreClose(const itpp::vec &expected, const itpp::vec &actual,
32                     double tolerance);
33
34extern bool AreClose(const itpp::mat &expected, const itpp::mat &actual,
35                     double tolerance);
36
[426]37void CheckClose(TestResults &results, const itpp::mat &expected,
38                const itpp::mat &actual, double tolerance,
39                TestDetails const& details) {
40    if (!AreClose(expected, actual, tolerance)) { 
41        MemoryOutStream stream;
42        stream << "failed at " << expected.rows()
43               << " x " << expected.cols();
44
45        results.OnTestFailure(details, stream.GetText());
46    }
47}
48
49}
50
[468]51typedef void (*FTestMatrix)(int, square_mat_point *);
52
[426]53template<typename TMatrix>
[468]54void test_matrix(int index, square_mat_point *point) {
[426]55    Real_Timer tt;
56       
[468]57    cout << "agenda[" << index << "]:" << endl;
58    mat A = point->get_matrix();
[467]59    int sz = A.rows();
60    CHECK_EQUAL(A.cols(), sz);
[426]61
[467]62    tt.tic();
63    TMatrix sqmat(A);
64    double elapsed = tt.toc();
65    cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;
[426]66
[467]67    tt.tic();
68    mat res = sqmat.to_mat();
69    elapsed = tt.toc();
[426]70
[467]71    if (!fast) {
72        CHECK_CLOSE(A, res, epsilon);
73    }
[426]74
[467]75    cout << "to_mat: " << elapsed << " s" << endl;
76
[468]77    vec v = point->get_vector();
78    double w = point->get_scalar();
[467]79    TMatrix sqmat2 = sqmat;
[426]80       
[467]81    tt.tic();
82    sqmat2.opupdt(v, w);
83    elapsed = tt.toc();
[426]84
[467]85    if (!fast) {
86        mat expA = A + w * outer_product(v, v);
87        CHECK_CLOSE(expA, sqmat2.to_mat(), epsilon);
88    }
[426]89
[467]90    cout << "opupdt: " << elapsed << " s" << endl;
[426]91
[467]92    TMatrix invmat(sz);
[438]93
[467]94    tt.tic();
95    sqmat.inv(invmat);
96    elapsed = tt.toc();
[438]97
[468]98    mat invA;
[467]99    if (!fast) {
[468]100        invA = inv(A);
[467]101        CHECK_CLOSE(invA, invmat.to_mat(), epsilon);
102    }
[438]103
[467]104    cout << "inv: " << elapsed << " s" << endl;
[438]105
[467]106    tt.tic();
107    double ld = sqmat.logdet();
108    elapsed = tt.toc();
109
110    if (!fast) {
111        double d = det(A);
112        CHECK_CLOSE(log(d), ld, epsilon);
[426]113    }
[467]114
115    cout << "logdet: " << elapsed << " s" << endl;
[468]116
117    tt.tic();
118    double q = sqmat.qform(ones(sz));
119    elapsed = tt.toc();
120
121    if (!fast) {
122        CHECK_CLOSE(sumsum(A), q, epsilon);
123    }
124
125    cout << "qform(1): " << elapsed << " s" << endl;
126
127    tt.tic();
128    q = sqmat.qform(v);
129    elapsed = tt.toc();
130
131    if (!fast) {
132        double r = (A * v) * v;
133        CHECK_CLOSE(r, q, epsilon);
134    }
135
136    cout << "qform(v): " << elapsed << " s" << endl;
137
138    tt.tic();
139    q = sqmat.invqform(v);
140    elapsed = tt.toc();
141
142    if (!fast) {
143        double r = (invA * v) * v;
144        CHECK_CLOSE(r, q, epsilon);
145    }
146
147    cout << "invqform: " << elapsed << " s" << endl;
148
149    TMatrix twice = sqmat;
150
151    tt.tic();
152    twice += sqmat;
153    elapsed = tt.toc();
154
155    if (!fast) {
156        res = 2 * A;
157        CHECK_CLOSE(res, twice.to_mat(), epsilon);
158    }
159
160    cout << "+=: " << elapsed << " s" << endl;
161
162    sqmat2 = sqmat;
163
164    tt.tic();
165    sqmat2.mult_sym(A);
166    elapsed = tt.toc();
167
168    if (!fast) {
169        res = (A * A) * A.T();
170        CHECK_CLOSE(res, sqmat2.to_mat(), epsilon);
171    }
172
173    cout << "^2: " << elapsed << " s" << endl;
[426]174}
175
[468]176void test_agenda(FTestMatrix test) {
[467]177    UIFile fag(agenda_file_name);
[468]178    Array<square_mat_point *> mag;
[467]179    UI::get(mag, fag, "agenda");
180    int sz = mag.size();
181    CHECK(sz > 0);
182    for (int i = 0; i < sz; ++i) {
[468]183        test(i, mag(i));
[467]184    }
[468]185
186    for (int i = 0; i < sz; ++i) {
187        square_mat_point *p = mag(i);
188        mag(i) = 0;
189        delete p;
190    }
[467]191}
192
[426]193SUITE(ldmat) {
[468]194    TEST(agenda) {
195        test_agenda(test_matrix<ldmat>);
[426]196    }
197}
198
199SUITE(fsqmat) {
[468]200    TEST(agenda) {
201        test_agenda(test_matrix<fsqmat>);
[426]202    }
203}
204
205SUITE(chmat) {
[468]206    TEST(agenda) {
207        test_agenda(test_matrix<chmat>);
[426]208    }
209}
210
211int main(int argc, char const *argv[]) {
212    bool unknown = false;
[467]213    int update_next = 0; // 1 suite, 2 epsilon, 3 agenda file
[426]214    const char *suite = "ldmat";
215    const char **param = argv + 1;
216    while (*param && !unknown) {
217        if (update_next) {
218            if (update_next == 1) {
219                suite = *param;
[467]220            } else if (update_next == 2) {
[426]221                double eps = atof(*param);
222                if (eps > 0) {
223                    epsilon = eps;
224                } else {
225                    cerr << "invalid epsilon value ignored" << endl;
226                }
[467]227            } else {
228                agenda_file_name = *param;
[426]229            }
230
231            update_next = 0;
232        } else {
[468]233            if (!strcmp(*param, "-a")) {
234                update_next = 3;
235            } else if (!strcmp(*param, "-c")) {
[426]236                update_next = 1;
237            } else if (!strcmp(*param, "-e")) {
238                update_next = 2;
239            } else if (!strcmp(*param, "-f")) {
240                fast = true;
241            } else {
242                unknown = true;
243            }
244        }
245
246        ++param;
247    }
248
249    if (unknown || update_next) {
[468]250        cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -a agenda_input.cfg ] [ -c class ]" << endl;
[426]251    } else {
252        UnitTest::TestReporterStdout reporter;
253        UnitTest::TestRunner runner(reporter);
254        return runner.RunTestsIf(UnitTest::Test::GetTestList(),
255            suite,
256            UnitTest::True(),
257            0);
258    }
259}
Note: See TracBrowser for help on using the browser.