root/library/bdm/math/square_mat.cpp @ 565

Revision 565, 9.3 kB (checked in by vbarta, 15 years ago)

using own error macros (basically copied from IT++, but never aborting)

  • Property svn:eol-style set to native
Line 
1
2#include "square_mat.h"
3
4namespace bdm
5{
6
7using namespace itpp;
8
9using std::endl;
10
11void fsqmat::opupdt ( const vec &v, double w ) {
12        M += outer_product ( v, v * w );
13};
14mat fsqmat::to_mat() const {
15        return M;
16};
17void fsqmat::mult_sym ( const mat &C ) {
18        M = C * M * C.T();
19};
20void fsqmat::mult_sym_t ( const mat &C ) {
21        M = C.T() * M * C;
22};
23void fsqmat::mult_sym ( const mat &C, fsqmat &U ) const {
24        U.M = ( C * ( M * C.T() ) );
25};
26void fsqmat::mult_sym_t ( const mat &C, fsqmat &U ) const {
27        U.M = ( C.T() * ( M * C ) );
28};
29void fsqmat::inv ( fsqmat &Inv ) const {
30        mat IM = itpp::inv ( M );
31        Inv = IM;
32};
33void fsqmat::clear() {
34        M.clear();
35};
36fsqmat::fsqmat ( const mat &M0 ) : sqmat ( M0.cols() ) {
37        bdm_assert_debug ( ( M0.cols() == M0.rows() ), "M0 must be square" );
38        M = M0;
39};
40
41//fsqmat::fsqmat() {};
42
43fsqmat::fsqmat ( const int dim0 ) : sqmat ( dim0 ), M ( dim0, dim0 ) {};
44
45std::ostream &operator<< ( std::ostream &os, const fsqmat &ld ) {
46        os << ld.M << endl;
47        return os;
48}
49
50
51ldmat::ldmat ( const mat &exL, const vec &exD ) : sqmat ( exD.length() ) {
52        D = exD;
53        L = exL;
54}
55
56ldmat::ldmat() : sqmat ( 0 ) {}
57
58ldmat::ldmat ( const int dim0 ) : sqmat ( dim0 ), D ( dim0 ), L ( dim0, dim0 ) {}
59
60ldmat::ldmat ( const vec D0 ) : sqmat ( D0.length() ) {
61        D = D0;
62        L = eye ( dim );
63}
64
65ldmat::ldmat ( const mat &V ) : sqmat ( V.cols() ) {
66
67        bdm_assert_debug ( dim == V.rows(), "ldmat::ldmat matrix V is not square!" );
68
69        // L and D will be allocated by ldform()
70        //Chol is unstable
71        this->ldform ( chol ( V ), ones ( dim ) );
72}
73
74void ldmat::opupdt ( const vec &v,  double w ) {
75        int dim = D.length();
76        double kr;
77        vec r = v;
78        //beware! it is potentionally dangerous, if ITpp change _behaviour of _data()!
79        double *Lraw = L._data();
80        double *Draw = D._data();
81        double *rraw = r._data();
82
83        bdm_assert_debug ( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." );
84
85        for ( int i = dim - 1; i >= 0; i-- ) {
86                dydr ( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim );
87        }
88}
89
90std::ostream &operator<< ( std::ostream &os, const ldmat &ld ) {
91        os << "L:" << ld.L << endl;
92        os << "D:" << ld.D << endl;
93        return os;
94}
95
96mat ldmat::to_mat() const {
97        int dim = D.length();
98        mat V ( dim, dim );
99        double sum;
100        int r, c, cc;
101
102        for ( r = 0; r < dim; r++ ) { //row cycle
103                for ( c = r; c < dim; c++ ) {
104                        //column cycle, using symmetricity => c=r!
105                        sum = 0.0;
106                        for ( cc = c; cc < dim; cc++ ) { //cycle over the remaining part of the vector
107                                sum += L ( cc, r ) * D ( cc ) * L ( cc, c );
108                                //here L(cc,r) = L(r,cc)';
109                        }
110                        V ( r, c ) = sum;
111                        // symmetricity
112                        if ( r != c ) {
113                                V ( c, r ) = sum;
114                        };
115                }
116        }
117        mat V2 = L.transpose() * diag ( D ) * L;
118        return V2;
119}
120
121
122void ldmat::add ( const ldmat &ld2, double w ) {
123        int dim = D.length();
124
125        bdm_assert_debug ( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs" );
126
127        //Fixme can be done more efficiently either via dydr or ldform
128        for ( int r = 0; r < dim; r++ ) {
129                // Add columns of ld2.L' (i.e. rows of ld2.L) as dyads weighted by ld2.D
130                this->opupdt ( ld2.L.get_row ( r ), w*ld2.D ( r ) );
131        }
132}
133
134void ldmat::clear() {
135        L.clear();
136        for ( int i = 0; i < L.cols(); i++ ) {
137                L ( i, i ) = 1;
138        };
139        D.clear();
140}
141
142void ldmat::inv ( ldmat &Inv ) const {
143        Inv.clear();   //Inv = zero in LD
144        mat U = ltuinv ( L );
145
146        Inv.ldform ( U.transpose(), 1.0 / D );
147}
148
149void ldmat::mult_sym ( const mat &C ) {
150        mat A = L * C.T();
151        this->ldform ( A, D );
152}
153
154void ldmat::mult_sym_t ( const mat &C ) {
155        mat A = L * C;
156        this->ldform ( A, D );
157}
158
159void ldmat::mult_sym ( const mat &C, ldmat &U ) const {
160        mat A = L * C.T(); //could be done more efficiently using BLAS
161        U.ldform ( A, D );
162}
163
164void ldmat::mult_sym_t ( const mat &C, ldmat &U ) const {
165        mat A = L * C;
166        /*      vec nD=zeros(U.rows());
167                nD.replace_mid(0, D); //I case that D < nD*/
168        U.ldform ( A, D );
169}
170
171
172double ldmat::logdet() const {
173        double ldet = 0.0;
174        int i;
175// sum logarithms of diagobal elements
176        for ( i = 0; i < D.length(); i++ ) {
177                ldet += log ( D ( i ) );
178        };
179        return ldet;
180}
181
182double ldmat::qform ( const vec &v ) const {
183        double x = 0.0, sum;
184        int i, j;
185
186        for ( i = 0; i < D.length(); i++ ) { //rows of L
187                sum = 0.0;
188                for ( j = 0; j <= i; j++ ) {
189                        sum += L ( i, j ) * v ( j );
190                }
191                x += D ( i ) * sum * sum;
192        };
193        return x;
194}
195
196double ldmat::invqform ( const vec &v ) const {
197        double x = 0.0;
198        int i;
199        vec pom ( v.length() );
200
201        backward_substitution ( L.T(), v, pom );
202
203        for ( i = 0; i < D.length(); i++ ) { //rows of L
204                x += pom ( i ) * pom ( i ) / D ( i );
205        };
206        return x;
207}
208
209ldmat& ldmat::operator *= ( double x ) {
210        D *= x;
211        return *this;
212}
213
214vec ldmat::sqrt_mult ( const vec &x ) const {
215        int i, j;
216        vec res ( dim );
217        //double sum;
218        for ( i = 0; i < dim; i++ ) {//for each element of result
219                res ( i ) = 0.0;
220                for ( j = i; j < dim; j++ ) {//sum D(j)*L(:,i).*x
221                        res ( i ) += sqrt ( D ( j ) ) * L ( j, i ) * x ( j );
222                }
223        }
224//      vec res2 = L.transpose()*diag( sqrt( D ) )*x;
225        return res;
226}
227
228void ldmat::ldform ( const mat &A, const vec &D0 ) {
229        int m = A.rows();
230        int n = A.cols();
231        int mn = ( m < n ) ? m : n ;
232
233        bdm_assert_debug ( D0.length() == A.rows(), "ldmat::ldform Vector D must have the length as row count of A" );
234
235        L = concat_vertical ( zeros ( n, n ), diag ( sqrt ( D0 ) ) * A );
236        D = zeros ( n + m );
237
238        //unnecessary big L and D will be made smaller at the end of file
239        vec w = zeros ( n + m );
240
241        double sum, beta, pom;
242
243        int cc = 0;
244        int i = n; // indexovani o 1 niz, nez v matlabu
245        int ii, j, jj;
246        while ( ( i > n - mn - cc ) && ( i > 0 ) ) {
247                i--;
248                sum = 0.0;
249
250                int last_v = m + i - n + cc + 1;
251
252                vec v = zeros ( last_v + 1 ); //prepare v
253                for ( ii = n - cc - 1; ii < m + i + 1; ii++ ) {
254                        sum += L ( ii, i ) * L ( ii, i );
255                        v ( ii - n + cc + 1 ) = L ( ii, i ); //assign v
256                }
257
258                if ( L ( m + i, i ) == 0 )
259                        beta = sqrt ( sum );
260                else
261                        beta = L ( m + i, i ) + sign ( L ( m + i, i ) ) * sqrt ( sum );
262
263                if ( std::fabs ( beta ) < eps ) {
264                        cc++;
265                        L.set_row ( n - cc, L.get_row ( m + i ) );
266                        L.set_row ( m + i, zeros ( L.cols() ) );
267                        D ( m + i ) = 0;
268                        L ( m + i, i ) = 1;
269                        L.set_submatrix ( n - cc, m + i - 1, i, i, 0 );
270                        continue;
271                }
272
273                sum -= v ( last_v ) * v ( last_v );
274                sum /= beta * beta;
275                sum++;
276
277                v /= beta;
278                v ( last_v ) = 1;
279
280                pom = -2.0 / sum;
281                // echo to venca
282
283                for ( j = i; j >= 0; j-- ) {
284                        double w_elem = 0;
285                        for ( ii = n -  cc; ii <= m + i + 1; ii++ )
286                                w_elem += v ( ii - n + cc ) * L ( ii - 1, j );
287                        w ( j ) = w_elem * pom;
288                }
289
290                for ( ii = n - cc - 1; ii <= m + i; ii++ )
291                        for ( jj = 0; jj < i; jj++ )
292                                L ( ii, jj ) += v ( ii - n + cc + 1 ) * w ( jj );
293
294                for ( ii = n - cc - 1; ii < m + i; ii++ )
295                        L ( ii, i ) = 0;
296
297                L ( m + i, i ) += w ( i );
298                D ( m + i ) = L ( m + i, i ) * L ( m + i, i );
299
300                for ( ii = 0; ii <= i; ii++ )
301                        L ( m + i, ii ) /= L ( m + i, i );
302        }
303
304        if ( i >= 0 )
305                for ( ii = 0; ii < i; ii++ ) {
306                        jj = D.length() - 1 - n + ii;
307                        D ( jj ) = 0;
308                        L.set_row ( jj, zeros ( L.cols() ) ); //TODO: set_row accepts Num_T
309                        L ( jj, jj ) = 1;
310                }
311
312        L.del_rows ( 0, m - 1 );
313        D.del ( 0, m - 1 );
314
315        dim = L.rows();
316}
317
318//////// Auxiliary Functions
319
320mat ltuinv ( const mat &L ) {
321        int dim = L.cols();
322        mat Il = eye ( dim );
323        int i, j, k, m;
324        double s;
325
326//Fixme blind transcription of ltuinv.m
327        for ( k = 1; k < ( dim ); k++ ) {
328                for ( i = 0; i < ( dim - k ); i++ ) {
329                        j = i + k; //change in .m 1+1=2, here 0+0+1=1
330                        s = L ( j, i );
331                        for ( m = i + 1; m < ( j ); m++ ) {
332                                s += L ( m, i ) * Il ( j, m );
333                        }
334                        Il ( j, i ) = -s;
335                }
336        }
337
338        return Il;
339}
340
341void dydr ( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx )
342/********************************************************************
343
344   dydr = dyadic reduction, performs transformation of sum of
345          2 dyads r*Dr*r'+ f*Df*f' so that the element of r pointed
346          by R is zeroed. This version allows Dr to be NEGATIVE. Hence the name negdydr or dydr_withneg.
347
348   Parameters :
349     r ... pointer to reduced dyad
350     f ... pointer to reducing dyad
351     Dr .. pointer to the weight of reduced dyad
352     Df .. pointer to the weight of reducing dyad
353     R ... pointer to the element of r, which is to be reduced to
354           zero; the corresponding element of f is assumed to be 1.
355     jl .. lower index of the range within which the dyads are
356           modified
357     ju .. upper index of the range within which the dyads are
358           modified
359     kr .. pointer to the coefficient used in the transformation of r
360           rnew = r + kr*f
361     m  .. number of rows of modified matrix (part of which is r)
362  Remark : Constant mzero means machine zero and should be modified
363           according to the precision of particular machine
364
365                                                 V. Peterka 17-7-89
366
367  Added:
368     mx .. number of rows of modified matrix (part of which is f)  -PN
369
370********************************************************************/
371{
372        int j, jm;
373        double kD, r0;
374        double mzero = 2.2e-16;
375        double threshold = 1e-4;
376
377        if ( fabs ( *Dr ) < mzero ) *Dr = 0;
378        r0 = *R;
379        *R = 0.0;
380        kD = *Df;
381        *kr = r0 * *Dr;
382        *Df = kD + r0 * ( *kr );
383        if ( *Df > mzero ) {
384                kD /= *Df;
385                *kr /= *Df;
386        } else {
387                kD = 1.0;
388                *kr = 0.0;
389                if ( *Df < -threshold ) {
390                        bdm_warning ( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." );
391                }
392                *Df = 0.0;
393        }
394        *Dr *= kD;
395        jm = mx * jl;
396        for ( j = m * jl; j < m*jh; j += m ) {
397                r[j] -=  r0 * f[jm];
398                f[jm] += *kr * r[j];
399                jm += mx;
400        }
401}
402
403}
Note: See TracBrowser for help on using the browser.