root/library/tests/square_mat_stress.cpp @ 480

Revision 480, 5.3 kB (checked in by vbarta, 15 years ago)

fixed tests for new UI::get & UI::build

Line 
1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
3#include "base/user_info.h"
4#include "square_mat_point.h"
5#include "UnitTest++.h"
6#include "TestReporterStdout.h"
7#include <iostream>
8#include <iomanip>
9#include <stdlib.h>
10#include <string.h>
11
12using std::cout;
13using std::cerr;
14using std::endl;
15
16using bdm::UIFile;
17using bdm::UI;
18
19const char *agenda_file_name = "agenda.cfg";
20double epsilon = 0.00001;
21bool fast = false;
22
23namespace bdm {
24UIREGISTER ( square_mat_point );
25}
26
27namespace UnitTest {
28
29// can't include mat_checks.h because CheckClose is different in this file
30extern bool AreClose ( const itpp::vec &expected, const itpp::vec &actual,
31                       double tolerance );
32
33extern bool AreClose ( const itpp::mat &expected, const itpp::mat &actual,
34                       double tolerance );
35
36void CheckClose ( TestResults &results, const itpp::mat &expected,
37                  const itpp::mat &actual, double tolerance,
38                  TestDetails const& details ) {
39        if ( !AreClose ( expected, actual, tolerance ) ) {
40                MemoryOutStream stream;
41                stream << "failed at " << expected.rows()
42                << " x " << expected.cols();
43
44                results.OnTestFailure ( details, stream.GetText() );
45        }
46}
47
48}
49
50typedef void ( *FTestMatrix ) ( int, square_mat_point * );
51
52template<typename TMatrix>
53void test_matrix ( int index, square_mat_point *point ) {
54        Real_Timer tt;
55
56        cout << "agenda[" << index << "]:" << endl;
57        mat A = point->get_matrix();
58        int sz = A.rows();
59        CHECK_EQUAL ( A.cols(), sz );
60
61        tt.tic();
62        TMatrix sqmat ( A );
63        double elapsed = tt.toc();
64        cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;
65
66        tt.tic();
67        mat res = sqmat.to_mat();
68        elapsed = tt.toc();
69
70        if ( !fast ) {
71                CHECK_CLOSE ( A, res, epsilon );
72        }
73
74        cout << "to_mat: " << elapsed << " s" << endl;
75
76        vec v = point->get_vector();
77        double w = point->get_scalar();
78        TMatrix sqmat2 = sqmat;
79
80        tt.tic();
81        sqmat2.opupdt ( v, w );
82        elapsed = tt.toc();
83
84        if ( !fast ) {
85                mat expA = A + w * outer_product ( v, v );
86                CHECK_CLOSE ( expA, sqmat2.to_mat(), epsilon );
87        }
88
89        cout << "opupdt: " << elapsed << " s" << endl;
90
91        TMatrix invmat ( sz );
92
93        tt.tic();
94        sqmat.inv ( invmat );
95        elapsed = tt.toc();
96
97        mat invA;
98        if ( !fast ) {
99                invA = inv ( A );
100                CHECK_CLOSE ( invA, invmat.to_mat(), epsilon );
101        }
102
103        cout << "inv: " << elapsed << " s" << endl;
104
105        tt.tic();
106        double ld = sqmat.logdet();
107        elapsed = tt.toc();
108
109        if ( !fast ) {
110                double d = det ( A );
111                CHECK_CLOSE ( log ( d ), ld, epsilon );
112        }
113
114        cout << "logdet: " << elapsed << " s" << endl;
115
116        tt.tic();
117        double q = sqmat.qform ( ones ( sz ) );
118        elapsed = tt.toc();
119
120        if ( !fast ) {
121                CHECK_CLOSE ( sumsum ( A ), q, epsilon );
122        }
123
124        cout << "qform(1): " << elapsed << " s" << endl;
125
126        tt.tic();
127        q = sqmat.qform ( v );
128        elapsed = tt.toc();
129
130        if ( !fast ) {
131                double r = ( A * v ) * v;
132                CHECK_CLOSE ( r, q, epsilon );
133        }
134
135        cout << "qform(v): " << elapsed << " s" << endl;
136
137        tt.tic();
138        q = sqmat.invqform ( v );
139        elapsed = tt.toc();
140
141        if ( !fast ) {
142                double r = ( invA * v ) * v;
143                CHECK_CLOSE ( r, q, epsilon );
144        }
145
146        cout << "invqform: " << elapsed << " s" << endl;
147
148        TMatrix twice = sqmat;
149
150        tt.tic();
151        twice += sqmat;
152        elapsed = tt.toc();
153
154        if ( !fast ) {
155                res = 2 * A;
156                CHECK_CLOSE ( res, twice.to_mat(), epsilon );
157        }
158
159        cout << "+=: " << elapsed << " s" << endl;
160
161        sqmat2 = sqmat;
162
163        tt.tic();
164        sqmat2.mult_sym ( A );
165        elapsed = tt.toc();
166
167        if ( !fast ) {
168                res = ( A * A ) * A.T();
169                CHECK_CLOSE ( res, sqmat2.to_mat(), epsilon );
170        }
171
172        cout << "^2: " << elapsed << " s" << endl;
173}
174
175void test_agenda ( FTestMatrix test ) {
176        UIFile fag ( agenda_file_name );
177        Array<square_mat_point *> mag;
178        UI::get ( mag, fag, "agenda", UI::compulsory );
179        int sz = mag.size();
180        CHECK ( sz > 0 );
181        for ( int i = 0; i < sz; ++i ) {
182                test ( i, mag ( i ) );
183        }
184
185        for ( int i = 0; i < sz; ++i ) {
186                square_mat_point *p = mag ( i );
187                mag ( i ) = 0;
188                delete p;
189        }
190}
191
192SUITE ( ldmat ) {
193        TEST ( agenda ) {
194                test_agenda ( test_matrix<ldmat> );
195        }
196}
197
198SUITE ( fsqmat ) {
199        TEST ( agenda ) {
200                test_agenda ( test_matrix<fsqmat> );
201        }
202}
203
204SUITE ( chmat ) {
205        TEST ( agenda ) {
206                test_agenda ( test_matrix<chmat> );
207        }
208}
209
210int main ( int argc, char const *argv[] ) {
211        bool unknown = false;
212        int update_next = 0; // 1 suite, 2 epsilon, 3 agenda file
213        const char *suite = "ldmat";
214        const char **param = argv + 1;
215        while ( *param && !unknown ) {
216                if ( update_next ) {
217                        if ( update_next == 1 ) {
218                                suite = *param;
219                        } else if ( update_next == 2 ) {
220                                double eps = atof ( *param );
221                                if ( eps > 0 ) {
222                                        epsilon = eps;
223                                } else {
224                                        cerr << "invalid epsilon value ignored" << endl;
225                                }
226                        } else {
227                                agenda_file_name = *param;
228                        }
229
230                        update_next = 0;
231                } else {
232                        if ( !strcmp ( *param, "-a" ) ) {
233                                update_next = 3;
234                        } else if ( !strcmp ( *param, "-c" ) ) {
235                                update_next = 1;
236                        } else if ( !strcmp ( *param, "-e" ) ) {
237                                update_next = 2;
238                        } else if ( !strcmp ( *param, "-f" ) ) {
239                                fast = true;
240                        } else {
241                                unknown = true;
242                        }
243                }
244
245                ++param;
246        }
247
248        if ( unknown || update_next ) {
249                cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -a agenda_input.cfg ] [ -c class ]" << endl;
250        } else {
251                UnitTest::TestReporterStdout reporter;
252                UnitTest::TestRunner runner ( reporter );
253                return runner.RunTestsIf ( UnitTest::Test::GetTestList(),
254                                           suite,
255                                           UnitTest::True(),
256                                           0 );
257        }
258}
Note: See TracBrowser for help on using the browser.