root/library/tests/stresssuite/square_mat_stress.cpp

Revision 1064, 6.2 kB (checked in by mido, 14 years ago)

astyle applied all over the library

Line 
1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
3#include "base/user_info.h"
4#include "../square_mat_point.h"
5#include "UnitTest++.h"
6#include "TestReporterStdout.h"
7#include <iostream>
8#include <iomanip>
9#include <stdlib.h>
10#include <string.h>
11
12using std::cout;
13using std::cerr;
14using std::endl;
15
16using bdm::fsqmat;
17using bdm::chmat;
18using bdm::ldmat;
19using bdm::shared_ptr;
20using bdm::UIFile;
21using bdm::UI;
22
23const char *agenda_file_name = "agenda.cfg";
24double epsilon = 0.00001;
25bool fast = false;
26
27namespace UnitTest {
28
29// can't include mat_checks.h because CheckClose is different in this file
30extern bool AreClose ( const itpp::vec &expected, const itpp::vec &actual,
31                       double tolerance );
32
33extern bool AreClose ( const itpp::mat &expected, const itpp::mat &actual,
34                       double tolerance );
35
36void CheckClose ( TestResults &results, const itpp::mat &expected,
37                  const itpp::mat &actual, double tolerance,
38                  TestDetails const& details ) {
39    if ( !AreClose ( expected, actual, tolerance ) ) {
40        MemoryOutStream stream;
41        stream << "failed at " << expected.rows()
42               << " x " << expected.cols();
43
44        results.OnTestFailure ( details, stream.GetText() );
45    }
46}
47
48}
49
50typedef void ( *FTestMatrix ) ( int, square_mat_point * );
51
52template<typename TMatrix>
53void test_matrix ( int index, square_mat_point *point ) {
54    Real_Timer tt;
55
56    cout << "agenda[" << index << "]:" << endl;
57    mat A = point->get_matrix();
58    int sz = A.rows();
59    CHECK_EQUAL ( A.cols(), sz );
60
61    tt.tic();
62    TMatrix sq_mat ( A );
63    double elapsed = tt.toc();
64    cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;
65
66    tt.tic();
67    mat res = sq_mat.to_mat();
68    elapsed = tt.toc();
69
70    if ( !fast ) {
71        CHECK_CLOSE ( A, res, epsilon );
72    }
73
74    cout << "to_mat: " << elapsed << " s" << endl;
75
76    vec v = point->get_vector();
77    double w = point->get_scalar();
78    TMatrix sq_mat2 = sq_mat;
79
80    tt.tic();
81    sq_mat2.opupdt ( v, w );
82    elapsed = tt.toc();
83
84    if ( !fast ) {
85        mat expA = A + w * outer_product ( v, v );
86        CHECK_CLOSE ( expA, sq_mat2.to_mat(), epsilon );
87    }
88
89    cout << "opupdt: " << elapsed << " s" << endl;
90
91    TMatrix invmat ( sz );
92
93    tt.tic();
94    sq_mat.inv ( invmat );
95    elapsed = tt.toc();
96
97    mat invA;
98    if ( !fast ) {
99        invA = inv ( A );
100        CHECK_CLOSE ( invA, invmat.to_mat(), epsilon );
101    }
102
103    cout << "inv: " << elapsed << " s" << endl;
104
105    tt.tic();
106    double ld = sq_mat.logdet();
107    elapsed = tt.toc();
108
109    if ( !fast ) {
110        double d = det ( A );
111        CHECK_CLOSE ( log ( d ), ld, epsilon );
112    }
113
114    cout << "logdet: " << elapsed << " s" << endl;
115
116    tt.tic();
117    double q = sq_mat.qform ( ones ( sz ) );
118    elapsed = tt.toc();
119
120    if ( !fast ) {
121        CHECK_CLOSE ( sumsum ( A ), q, epsilon );
122    }
123
124    cout << "qform(1): " << elapsed << " s" << endl;
125
126    tt.tic();
127    q = sq_mat.qform ( v );
128    elapsed = tt.toc();
129
130    if ( !fast ) {
131        double r = ( A * v ) * v;
132        CHECK_CLOSE ( r, q, epsilon );
133    }
134
135    cout << "qform(v): " << elapsed << " s" << endl;
136
137    tt.tic();
138    q = sq_mat.invqform ( v );
139    elapsed = tt.toc();
140
141    if ( !fast ) {
142        double r = ( invA * v ) * v;
143        CHECK_CLOSE ( r, q, epsilon );
144    }
145
146    cout << "invqform: " << elapsed << " s" << endl;
147
148    TMatrix twice = sq_mat;
149
150    tt.tic();
151    twice += sq_mat;
152    elapsed = tt.toc();
153
154    if ( !fast ) {
155        res = 2 * A;
156        CHECK_CLOSE ( res, twice.to_mat(), epsilon );
157    }
158
159    cout << "+=: " << elapsed << " s" << endl;
160
161    sq_mat2 = sq_mat;
162
163    tt.tic();
164    sq_mat2.mult_sym ( A );
165    elapsed = tt.toc();
166
167    if ( !fast ) {
168        res = ( A * A ) * A.T();
169        CHECK_CLOSE ( res, sq_mat2.to_mat(), epsilon );
170    }
171
172    cout << "^2: " << elapsed << " s" << endl;
173}
174
175void test_agenda ( FTestMatrix test ) {
176    UIFile fag ( agenda_file_name );
177    Array<shared_ptr<square_mat_point> > mag;
178    UI::get ( mag, fag, "agenda", UI::compulsory );
179    int sz = mag.size();
180    CHECK ( sz > 0 );
181    for ( int i = 0; i < sz; ++i ) {
182        test ( i, mag ( i ).get() );
183    }
184}
185
186SUITE ( ldmat ) {
187    TEST ( agenda ) {
188        test_agenda ( test_matrix<ldmat> );
189    }
190}
191
192SUITE ( fsqmat ) {
193    TEST ( agenda ) {
194        test_agenda ( test_matrix<fsqmat> );
195    }
196}
197
198SUITE ( chmat ) {
199    TEST ( agenda ) {
200        test_agenda ( test_matrix<chmat> );
201    }
202}
203
204int main ( int argc, char const *argv[] ) {
205    bool unknown = false;
206    int update_next = 0; // 1 suite, 2 epsilon, 3 agenda file
207    const char *suite = "ldmat";
208    const char **param = argv + 1;
209    while ( *param && !unknown ) {
210        if ( update_next ) {
211            if ( update_next == 1 ) {
212                suite = *param;
213            } else if ( update_next == 2 ) {
214                double eps = atof ( *param );
215                if ( eps > 0 ) {
216                    epsilon = eps;
217                } else {
218                    cerr << "invalid epsilon value ignored" << endl;
219                }
220            } else {
221                agenda_file_name = *param;
222            }
223
224            update_next = 0;
225        } else {
226            if ( !strcmp ( *param, "-a" ) ) {
227                update_next = 3;
228            } else if ( !strcmp ( *param, "-c" ) ) {
229                update_next = 1;
230            } else if ( !strcmp ( *param, "-e" ) ) {
231                update_next = 2;
232            } else if ( !strcmp ( *param, "-f" ) ) {
233                fast = true;
234            } else {
235                unknown = true;
236            }
237        }
238
239        ++param;
240    }
241
242    if ( unknown || update_next ) {
243        cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -a agenda_input.cfg ] [ -c class ]" << endl;
244    } else {
245        UnitTest::TestReporterStdout reporter;
246        UnitTest::TestRunner runner ( reporter );
247        return runner.RunTestsIf ( UnitTest::Test::GetTestList(),
248                                   suite,
249                                   UnitTest::True(),
250                                   0 );
251    }
252}
Note: See TracBrowser for help on using the browser.