root/library/tests/square_mat_stress.cpp @ 495

Revision 495, 5.3 kB (checked in by vbarta, 15 years ago)

moved square matrices to namespace bdm

Line 
1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
3#include "base/user_info.h"
4#include "square_mat_point.h"
5#include "UnitTest++.h"
6#include "TestReporterStdout.h"
7#include <iostream>
8#include <iomanip>
9#include <stdlib.h>
10#include <string.h>
11
12using std::cout;
13using std::cerr;
14using std::endl;
15
16using bdm::fsqmat;
17using bdm::chmat;
18using bdm::ldmat;
19using bdm::UIFile;
20using bdm::UI;
21
22const char *agenda_file_name = "agenda.cfg";
23double epsilon = 0.00001;
24bool fast = false;
25
26namespace UnitTest {
27
28// can't include mat_checks.h because CheckClose is different in this file
29extern bool AreClose ( const itpp::vec &expected, const itpp::vec &actual,
30                       double tolerance );
31
32extern bool AreClose ( const itpp::mat &expected, const itpp::mat &actual,
33                       double tolerance );
34
35void CheckClose ( TestResults &results, const itpp::mat &expected,
36                  const itpp::mat &actual, double tolerance,
37                  TestDetails const& details ) {
38        if ( !AreClose ( expected, actual, tolerance ) ) {
39                MemoryOutStream stream;
40                stream << "failed at " << expected.rows()
41                << " x " << expected.cols();
42
43                results.OnTestFailure ( details, stream.GetText() );
44        }
45}
46
47}
48
49typedef void ( *FTestMatrix ) ( int, square_mat_point * );
50
51template<typename TMatrix>
52void test_matrix ( int index, square_mat_point *point ) {
53        Real_Timer tt;
54
55        cout << "agenda[" << index << "]:" << endl;
56        mat A = point->get_matrix();
57        int sz = A.rows();
58        CHECK_EQUAL ( A.cols(), sz );
59
60        tt.tic();
61        TMatrix sqmat ( A );
62        double elapsed = tt.toc();
63        cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;
64
65        tt.tic();
66        mat res = sqmat.to_mat();
67        elapsed = tt.toc();
68
69        if ( !fast ) {
70                CHECK_CLOSE ( A, res, epsilon );
71        }
72
73        cout << "to_mat: " << elapsed << " s" << endl;
74
75        vec v = point->get_vector();
76        double w = point->get_scalar();
77        TMatrix sqmat2 = sqmat;
78
79        tt.tic();
80        sqmat2.opupdt ( v, w );
81        elapsed = tt.toc();
82
83        if ( !fast ) {
84                mat expA = A + w * outer_product ( v, v );
85                CHECK_CLOSE ( expA, sqmat2.to_mat(), epsilon );
86        }
87
88        cout << "opupdt: " << elapsed << " s" << endl;
89
90        TMatrix invmat ( sz );
91
92        tt.tic();
93        sqmat.inv ( invmat );
94        elapsed = tt.toc();
95
96        mat invA;
97        if ( !fast ) {
98                invA = inv ( A );
99                CHECK_CLOSE ( invA, invmat.to_mat(), epsilon );
100        }
101
102        cout << "inv: " << elapsed << " s" << endl;
103
104        tt.tic();
105        double ld = sqmat.logdet();
106        elapsed = tt.toc();
107
108        if ( !fast ) {
109                double d = det ( A );
110                CHECK_CLOSE ( log ( d ), ld, epsilon );
111        }
112
113        cout << "logdet: " << elapsed << " s" << endl;
114
115        tt.tic();
116        double q = sqmat.qform ( ones ( sz ) );
117        elapsed = tt.toc();
118
119        if ( !fast ) {
120                CHECK_CLOSE ( sumsum ( A ), q, epsilon );
121        }
122
123        cout << "qform(1): " << elapsed << " s" << endl;
124
125        tt.tic();
126        q = sqmat.qform ( v );
127        elapsed = tt.toc();
128
129        if ( !fast ) {
130                double r = ( A * v ) * v;
131                CHECK_CLOSE ( r, q, epsilon );
132        }
133
134        cout << "qform(v): " << elapsed << " s" << endl;
135
136        tt.tic();
137        q = sqmat.invqform ( v );
138        elapsed = tt.toc();
139
140        if ( !fast ) {
141                double r = ( invA * v ) * v;
142                CHECK_CLOSE ( r, q, epsilon );
143        }
144
145        cout << "invqform: " << elapsed << " s" << endl;
146
147        TMatrix twice = sqmat;
148
149        tt.tic();
150        twice += sqmat;
151        elapsed = tt.toc();
152
153        if ( !fast ) {
154                res = 2 * A;
155                CHECK_CLOSE ( res, twice.to_mat(), epsilon );
156        }
157
158        cout << "+=: " << elapsed << " s" << endl;
159
160        sqmat2 = sqmat;
161
162        tt.tic();
163        sqmat2.mult_sym ( A );
164        elapsed = tt.toc();
165
166        if ( !fast ) {
167                res = ( A * A ) * A.T();
168                CHECK_CLOSE ( res, sqmat2.to_mat(), epsilon );
169        }
170
171        cout << "^2: " << elapsed << " s" << endl;
172}
173
174void test_agenda ( FTestMatrix test ) {
175        UIFile fag ( agenda_file_name );
176        Array<square_mat_point *> mag;
177        UI::get ( mag, fag, "agenda", UI::compulsory );
178        int sz = mag.size();
179        CHECK ( sz > 0 );
180        for ( int i = 0; i < sz; ++i ) {
181                test ( i, mag ( i ) );
182        }
183
184        for ( int i = 0; i < sz; ++i ) {
185                square_mat_point *p = mag ( i );
186                mag ( i ) = 0;
187                delete p;
188        }
189}
190
191SUITE ( ldmat ) {
192        TEST ( agenda ) {
193                test_agenda ( test_matrix<ldmat> );
194        }
195}
196
197SUITE ( fsqmat ) {
198        TEST ( agenda ) {
199                test_agenda ( test_matrix<fsqmat> );
200        }
201}
202
203SUITE ( chmat ) {
204        TEST ( agenda ) {
205                test_agenda ( test_matrix<chmat> );
206        }
207}
208
209int main ( int argc, char const *argv[] ) {
210        bool unknown = false;
211        int update_next = 0; // 1 suite, 2 epsilon, 3 agenda file
212        const char *suite = "ldmat";
213        const char **param = argv + 1;
214        while ( *param && !unknown ) {
215                if ( update_next ) {
216                        if ( update_next == 1 ) {
217                                suite = *param;
218                        } else if ( update_next == 2 ) {
219                                double eps = atof ( *param );
220                                if ( eps > 0 ) {
221                                        epsilon = eps;
222                                } else {
223                                        cerr << "invalid epsilon value ignored" << endl;
224                                }
225                        } else {
226                                agenda_file_name = *param;
227                        }
228
229                        update_next = 0;
230                } else {
231                        if ( !strcmp ( *param, "-a" ) ) {
232                                update_next = 3;
233                        } else if ( !strcmp ( *param, "-c" ) ) {
234                                update_next = 1;
235                        } else if ( !strcmp ( *param, "-e" ) ) {
236                                update_next = 2;
237                        } else if ( !strcmp ( *param, "-f" ) ) {
238                                fast = true;
239                        } else {
240                                unknown = true;
241                        }
242                }
243
244                ++param;
245        }
246
247        if ( unknown || update_next ) {
248                cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -a agenda_input.cfg ] [ -c class ]" << endl;
249        } else {
250                UnitTest::TestReporterStdout reporter;
251                UnitTest::TestRunner runner ( reporter );
252                return runner.RunTestsIf ( UnitTest::Test::GetTestList(),
253                                           suite,
254                                           UnitTest::True(),
255                                           0 );
256        }
257}
Note: See TracBrowser for help on using the browser.