root/library/tests/square_mat_stress.cpp @ 480

Revision 480, 5.3 kB (checked in by vbarta, 15 years ago)

fixed tests for new UI::get & UI::build

RevLine 
[426]1#include "../bdm/math/square_mat.h"
2#include "../bdm/math/chmat.h"
[467]3#include "base/user_info.h"
[468]4#include "square_mat_point.h"
[426]5#include "UnitTest++.h"
6#include "TestReporterStdout.h"
7#include <iostream>
8#include <iomanip>
9#include <stdlib.h>
10#include <string.h>
11
12using std::cout;
13using std::cerr;
14using std::endl;
15
[467]16using bdm::UIFile;
17using bdm::UI;
18
19const char *agenda_file_name = "agenda.cfg";
[426]20double epsilon = 0.00001;
21bool fast = false;
22
[468]23namespace bdm {
[477]24UIREGISTER ( square_mat_point );
[468]25}
26
[477]27namespace UnitTest {
[426]28
[456]29// can't include mat_checks.h because CheckClose is different in this file
[477]30extern bool AreClose ( const itpp::vec &expected, const itpp::vec &actual,
31                       double tolerance );
[456]32
[477]33extern bool AreClose ( const itpp::mat &expected, const itpp::mat &actual,
34                       double tolerance );
[456]35
[477]36void CheckClose ( TestResults &results, const itpp::mat &expected,
37                  const itpp::mat &actual, double tolerance,
38                  TestDetails const& details ) {
39        if ( !AreClose ( expected, actual, tolerance ) ) {
40                MemoryOutStream stream;
41                stream << "failed at " << expected.rows()
42                << " x " << expected.cols();
[426]43
[477]44                results.OnTestFailure ( details, stream.GetText() );
45        }
[426]46}
47
48}
49
[477]50typedef void ( *FTestMatrix ) ( int, square_mat_point * );
[468]51
[426]52template<typename TMatrix>
[477]53void test_matrix ( int index, square_mat_point *point ) {
54        Real_Timer tt;
[426]55
[477]56        cout << "agenda[" << index << "]:" << endl;
57        mat A = point->get_matrix();
58        int sz = A.rows();
59        CHECK_EQUAL ( A.cols(), sz );
[426]60
[477]61        tt.tic();
62        TMatrix sqmat ( A );
63        double elapsed = tt.toc();
64        cout << "ctor(" << sz << " x " << sz << "): " << elapsed << " s" << endl;
[426]65
[477]66        tt.tic();
67        mat res = sqmat.to_mat();
68        elapsed = tt.toc();
[426]69
[477]70        if ( !fast ) {
71                CHECK_CLOSE ( A, res, epsilon );
72        }
[467]73
[477]74        cout << "to_mat: " << elapsed << " s" << endl;
[426]75
[477]76        vec v = point->get_vector();
77        double w = point->get_scalar();
78        TMatrix sqmat2 = sqmat;
[426]79
[477]80        tt.tic();
81        sqmat2.opupdt ( v, w );
82        elapsed = tt.toc();
[426]83
[477]84        if ( !fast ) {
85                mat expA = A + w * outer_product ( v, v );
86                CHECK_CLOSE ( expA, sqmat2.to_mat(), epsilon );
87        }
[438]88
[477]89        cout << "opupdt: " << elapsed << " s" << endl;
[438]90
[477]91        TMatrix invmat ( sz );
[438]92
[477]93        tt.tic();
94        sqmat.inv ( invmat );
95        elapsed = tt.toc();
[438]96
[477]97        mat invA;
98        if ( !fast ) {
99                invA = inv ( A );
100                CHECK_CLOSE ( invA, invmat.to_mat(), epsilon );
101        }
[467]102
[477]103        cout << "inv: " << elapsed << " s" << endl;
[467]104
[477]105        tt.tic();
106        double ld = sqmat.logdet();
107        elapsed = tt.toc();
[468]108
[477]109        if ( !fast ) {
110                double d = det ( A );
111                CHECK_CLOSE ( log ( d ), ld, epsilon );
112        }
[468]113
[477]114        cout << "logdet: " << elapsed << " s" << endl;
[468]115
[477]116        tt.tic();
117        double q = sqmat.qform ( ones ( sz ) );
118        elapsed = tt.toc();
[468]119
[477]120        if ( !fast ) {
121                CHECK_CLOSE ( sumsum ( A ), q, epsilon );
122        }
[468]123
[477]124        cout << "qform(1): " << elapsed << " s" << endl;
[468]125
[477]126        tt.tic();
127        q = sqmat.qform ( v );
128        elapsed = tt.toc();
[468]129
[477]130        if ( !fast ) {
131                double r = ( A * v ) * v;
132                CHECK_CLOSE ( r, q, epsilon );
133        }
[468]134
[477]135        cout << "qform(v): " << elapsed << " s" << endl;
[468]136
[477]137        tt.tic();
138        q = sqmat.invqform ( v );
139        elapsed = tt.toc();
[468]140
[477]141        if ( !fast ) {
142                double r = ( invA * v ) * v;
143                CHECK_CLOSE ( r, q, epsilon );
144        }
[468]145
[477]146        cout << "invqform: " << elapsed << " s" << endl;
[468]147
[477]148        TMatrix twice = sqmat;
[468]149
[477]150        tt.tic();
151        twice += sqmat;
152        elapsed = tt.toc();
[468]153
[477]154        if ( !fast ) {
155                res = 2 * A;
156                CHECK_CLOSE ( res, twice.to_mat(), epsilon );
157        }
[468]158
[477]159        cout << "+=: " << elapsed << " s" << endl;
[468]160
[477]161        sqmat2 = sqmat;
[468]162
[477]163        tt.tic();
164        sqmat2.mult_sym ( A );
165        elapsed = tt.toc();
166
167        if ( !fast ) {
168                res = ( A * A ) * A.T();
169                CHECK_CLOSE ( res, sqmat2.to_mat(), epsilon );
170        }
171
172        cout << "^2: " << elapsed << " s" << endl;
[426]173}
174
[477]175void test_agenda ( FTestMatrix test ) {
176        UIFile fag ( agenda_file_name );
177        Array<square_mat_point *> mag;
[480]178        UI::get ( mag, fag, "agenda", UI::compulsory );
[477]179        int sz = mag.size();
180        CHECK ( sz > 0 );
181        for ( int i = 0; i < sz; ++i ) {
182                test ( i, mag ( i ) );
183        }
[468]184
[477]185        for ( int i = 0; i < sz; ++i ) {
186                square_mat_point *p = mag ( i );
187                mag ( i ) = 0;
188                delete p;
189        }
[467]190}
191
[477]192SUITE ( ldmat ) {
193        TEST ( agenda ) {
194                test_agenda ( test_matrix<ldmat> );
195        }
[426]196}
197
[477]198SUITE ( fsqmat ) {
199        TEST ( agenda ) {
200                test_agenda ( test_matrix<fsqmat> );
201        }
[426]202}
203
[477]204SUITE ( chmat ) {
205        TEST ( agenda ) {
206                test_agenda ( test_matrix<chmat> );
207        }
[426]208}
209
[477]210int main ( int argc, char const *argv[] ) {
211        bool unknown = false;
212        int update_next = 0; // 1 suite, 2 epsilon, 3 agenda file
213        const char *suite = "ldmat";
214        const char **param = argv + 1;
215        while ( *param && !unknown ) {
216                if ( update_next ) {
217                        if ( update_next == 1 ) {
218                                suite = *param;
219                        } else if ( update_next == 2 ) {
220                                double eps = atof ( *param );
221                                if ( eps > 0 ) {
222                                        epsilon = eps;
223                                } else {
224                                        cerr << "invalid epsilon value ignored" << endl;
225                                }
226                        } else {
227                                agenda_file_name = *param;
228                        }
229
230                        update_next = 0;
[426]231                } else {
[477]232                        if ( !strcmp ( *param, "-a" ) ) {
233                                update_next = 3;
234                        } else if ( !strcmp ( *param, "-c" ) ) {
235                                update_next = 1;
236                        } else if ( !strcmp ( *param, "-e" ) ) {
237                                update_next = 2;
238                        } else if ( !strcmp ( *param, "-f" ) ) {
239                                fast = true;
240                        } else {
241                                unknown = true;
242                        }
[426]243                }
244
[477]245                ++param;
[426]246        }
247
[477]248        if ( unknown || update_next ) {
249                cerr << "usage: " << argv[0] << " [ -f ] [ -e epsilon ] [ -a agenda_input.cfg ] [ -c class ]" << endl;
250        } else {
251                UnitTest::TestReporterStdout reporter;
252                UnitTest::TestRunner runner ( reporter );
253                return runner.RunTestsIf ( UnitTest::Test::GetTestList(),
254                                           suite,
255                                           UnitTest::True(),
256                                           0 );
257        }
[426]258}
Note: See TracBrowser for help on using the browser.