root/library/bdm/math/square_mat.cpp @ 495

Revision 495, 9.4 kB (checked in by vbarta, 15 years ago)

moved square matrices to namespace bdm

  • Property svn:eol-style set to native
Line 
1
2#include "square_mat.h"
3
4namespace bdm
5{
6
7using namespace itpp;
8
9using std::endl;
10
11void fsqmat::opupdt ( const vec &v, double w ) {
12        M += outer_product ( v, v * w );
13};
14mat fsqmat::to_mat() const {
15        return M;
16};
17void fsqmat::mult_sym ( const mat &C ) {
18        M = C * M * C.T();
19};
20void fsqmat::mult_sym_t ( const mat &C ) {
21        M = C.T() * M * C;
22};
23void fsqmat::mult_sym ( const mat &C, fsqmat &U ) const {
24        U.M = ( C * ( M * C.T() ) );
25};
26void fsqmat::mult_sym_t ( const mat &C, fsqmat &U ) const {
27        U.M = ( C.T() * ( M * C ) );
28};
29void fsqmat::inv ( fsqmat &Inv ) const {
30        mat IM = itpp::inv ( M );
31        Inv = IM;
32};
33void fsqmat::clear() {
34        M.clear();
35};
36fsqmat::fsqmat ( const mat &M0 ) : sqmat ( M0.cols() ) {
37        it_assert_debug ( ( M0.cols() == M0.rows() ), "M0 must be square" );
38        M = M0;
39};
40
41//fsqmat::fsqmat() {};
42
43fsqmat::fsqmat ( const int dim0 ) : sqmat ( dim0 ), M ( dim0, dim0 ) {};
44
45std::ostream &operator<< ( std::ostream &os, const fsqmat &ld ) {
46        os << ld.M << endl;
47        return os;
48}
49
50
51ldmat::ldmat ( const mat &exL, const vec &exD ) : sqmat ( exD.length() ) {
52        D = exD;
53        L = exL;
54}
55
56ldmat::ldmat() : sqmat ( 0 ) {}
57
58ldmat::ldmat ( const int dim0 ) : sqmat ( dim0 ), D ( dim0 ), L ( dim0, dim0 ) {}
59
60ldmat::ldmat ( const vec D0 ) : sqmat ( D0.length() ) {
61        D = D0;
62        L = eye ( dim );
63}
64
65ldmat::ldmat ( const mat &V ) : sqmat ( V.cols() ) {
66//TODO check if correct!! Based on heuristic observation of lu()
67
68        it_assert_debug ( dim == V.rows(), "ldmat::ldmat matrix V is not square!" );
69
70        // L and D will be allocated by ldform()
71
72        //Chol is unstable
73        this->ldform ( chol ( V ), ones ( dim ) );
74//      this->ldform(ul(V),ones(dim));
75}
76
77void ldmat::opupdt ( const vec &v,  double w ) {
78        int dim = D.length();
79        double kr;
80        vec r = v;
81        //beware! it is potentionally dangerous, if ITpp change _behaviour of _data()!
82        double *Lraw = L._data();
83        double *Draw = D._data();
84        double *rraw = r._data();
85
86        it_assert_debug ( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." );
87
88        for ( int i = dim - 1; i >= 0; i-- ) {
89                dydr ( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim );
90        }
91}
92
93std::ostream &operator<< ( std::ostream &os, const ldmat &ld ) {
94        os << "L:" << ld.L << endl;
95        os << "D:" << ld.D << endl;
96        return os;
97}
98
99mat ldmat::to_mat() const {
100        int dim = D.length();
101        mat V ( dim, dim );
102        double sum;
103        int r, c, cc;
104
105        for ( r = 0; r < dim; r++ ) { //row cycle
106                for ( c = r; c < dim; c++ ) {
107                        //column cycle, using symmetricity => c=r!
108                        sum = 0.0;
109                        for ( cc = c; cc < dim; cc++ ) { //cycle over the remaining part of the vector
110                                sum += L ( cc, r ) * D ( cc ) * L ( cc, c );
111                                //here L(cc,r) = L(r,cc)';
112                        }
113                        V ( r, c ) = sum;
114                        // symmetricity
115                        if ( r != c ) {
116                                V ( c, r ) = sum;
117                        };
118                }
119        }
120        mat V2 = L.transpose() * diag ( D ) * L;
121        return V2;
122}
123
124
125void ldmat::add ( const ldmat &ld2, double w ) {
126        int dim = D.length();
127
128        it_assert_debug ( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs;" );
129
130        //Fixme can be done more efficiently either via dydr or ldform
131        for ( int r = 0; r < dim; r++ ) {
132                // Add columns of ld2.L' (i.e. rows of ld2.L) as dyads weighted by ld2.D
133                this->opupdt ( ld2.L.get_row ( r ), w*ld2.D ( r ) );
134        }
135}
136
137void ldmat::clear() {
138        L.clear();
139        for ( int i = 0; i < L.cols(); i++ ) {
140                L ( i, i ) = 1;
141        };
142        D.clear();
143}
144
145void ldmat::inv ( ldmat &Inv ) const {
146        Inv.clear();   //Inv = zero in LD
147        mat U = ltuinv ( L );
148
149        Inv.ldform ( U.transpose(), 1.0 / D );
150}
151
152void ldmat::mult_sym ( const mat &C ) {
153        mat A = L * C.T();
154        this->ldform ( A, D );
155}
156
157void ldmat::mult_sym_t ( const mat &C ) {
158        mat A = L * C;
159        this->ldform ( A, D );
160}
161
162void ldmat::mult_sym ( const mat &C, ldmat &U ) const {
163        mat A = L * C.T(); //could be done more efficiently using BLAS
164        U.ldform ( A, D );
165}
166
167void ldmat::mult_sym_t ( const mat &C, ldmat &U ) const {
168        mat A = L * C;
169        /*      vec nD=zeros(U.rows());
170                nD.replace_mid(0, D); //I case that D < nD*/
171        U.ldform ( A, D );
172}
173
174
175double ldmat::logdet() const {
176        double ldet = 0.0;
177        int i;
178// sum logarithms of diagobal elements
179        for ( i = 0; i < D.length(); i++ ) {
180                ldet += log ( D ( i ) );
181        };
182        return ldet;
183}
184
185double ldmat::qform ( const vec &v ) const {
186        double x = 0.0, sum;
187        int i, j;
188
189        for ( i = 0; i < D.length(); i++ ) { //rows of L
190                sum = 0.0;
191                for ( j = 0; j <= i; j++ ) {
192                        sum += L ( i, j ) * v ( j );
193                }
194                x += D ( i ) * sum * sum;
195        };
196        return x;
197}
198
199double ldmat::invqform ( const vec &v ) const {
200        double x = 0.0;
201        int i;
202        vec pom ( v.length() );
203
204        backward_substitution ( L.T(), v, pom );
205
206        for ( i = 0; i < D.length(); i++ ) { //rows of L
207                x += pom ( i ) * pom ( i ) / D ( i );
208        };
209        return x;
210}
211
212ldmat& ldmat::operator *= ( double x ) {
213        D *= x;
214        return *this;
215}
216
217vec ldmat::sqrt_mult ( const vec &x ) const {
218        int i, j;
219        vec res ( dim );
220        //double sum;
221        for ( i = 0; i < dim; i++ ) {//for each element of result
222                res ( i ) = 0.0;
223                for ( j = i; j < dim; j++ ) {//sum D(j)*L(:,i).*x
224                        res ( i ) += sqrt ( D ( j ) ) * L ( j, i ) * x ( j );
225                }
226        }
227//      vec res2 = L.transpose()*diag( sqrt( D ) )*x;
228        return res;
229}
230
231void ldmat::ldform ( const mat &A, const vec &D0 ) {
232        int m = A.rows();
233        int n = A.cols();
234        int mn = ( m < n ) ? m : n ;
235
236//      it_assert_debug( A.cols()==dim,"ldmat::ldform A is not compatible" );
237        it_assert_debug ( D0.length() == A.rows(), "ldmat::ldform Vector D must have the length as row count of A" );
238
239        L = concat_vertical ( zeros ( n, n ), diag ( sqrt ( D0 ) ) * A );
240        D = zeros ( n + m );
241
242        //unnecessary big L and D will be made smaller at the end of file
243        vec w = zeros ( n + m );
244
245        double sum, beta, pom;
246
247        int cc = 0;
248        int i = n; // indexovani o 1 niz, nez v matlabu
249        int ii, j, jj;
250        while ( ( i > n - mn - cc ) && ( i > 0 ) ) {
251                i--;
252                sum = 0.0;
253
254                int last_v = m + i - n + cc + 1;
255
256                vec v = zeros ( last_v + 1 ); //prepare v
257                for ( ii = n - cc - 1; ii < m + i + 1; ii++ ) {
258                        sum += L ( ii, i ) * L ( ii, i );
259                        v ( ii - n + cc + 1 ) = L ( ii, i ); //assign v
260                }
261
262                if ( L ( m + i, i ) == 0 )
263                        beta = sqrt ( sum );
264                else
265                        beta = L ( m + i, i ) + sign ( L ( m + i, i ) ) * sqrt ( sum );
266
267                if ( std::fabs ( beta ) < eps ) {
268                        cc++;
269                        L.set_row ( n - cc, L.get_row ( m + i ) );
270                        L.set_row ( m + i, zeros ( L.cols() ) );
271                        D ( m + i ) = 0;
272                        L ( m + i, i ) = 1;
273                        L.set_submatrix ( n - cc, m + i - 1, i, i, 0 );
274                        continue;
275                }
276
277                sum -= v ( last_v ) * v ( last_v );
278                sum /= beta * beta;
279                sum++;
280
281                v /= beta;
282                v ( last_v ) = 1;
283
284                pom = -2.0 / sum;
285                // echo to venca
286
287                for ( j = i; j >= 0; j-- ) {
288                        double w_elem = 0;
289                        for ( ii = n -  cc; ii <= m + i + 1; ii++ )
290                                w_elem += v ( ii - n + cc ) * L ( ii - 1, j );
291                        w ( j ) = w_elem * pom;
292                }
293
294                for ( ii = n - cc - 1; ii <= m + i; ii++ )
295                        for ( jj = 0; jj < i; jj++ )
296                                L ( ii, jj ) += v ( ii - n + cc + 1 ) * w ( jj );
297
298                for ( ii = n - cc - 1; ii < m + i; ii++ )
299                        L ( ii, i ) = 0;
300
301                L ( m + i, i ) += w ( i );
302                D ( m + i ) = L ( m + i, i ) * L ( m + i, i );
303
304                for ( ii = 0; ii <= i; ii++ )
305                        L ( m + i, ii ) /= L ( m + i, i );
306        }
307
308        if ( i >= 0 )
309                for ( ii = 0; ii < i; ii++ ) {
310                        jj = D.length() - 1 - n + ii;
311                        D ( jj ) = 0;
312                        L.set_row ( jj, zeros ( L.cols() ) ); //TODO: set_row accepts Num_T
313                        L ( jj, jj ) = 1;
314                }
315
316        L.del_rows ( 0, m - 1 );
317        D.del ( 0, m - 1 );
318
319        dim = L.rows();
320}
321
322//////// Auxiliary Functions
323
324mat ltuinv ( const mat &L ) {
325        int dim = L.cols();
326        mat Il = eye ( dim );
327        int i, j, k, m;
328        double s;
329
330//Fixme blind transcription of ltuinv.m
331        for ( k = 1; k < ( dim ); k++ ) {
332                for ( i = 0; i < ( dim - k ); i++ ) {
333                        j = i + k; //change in .m 1+1=2, here 0+0+1=1
334                        s = L ( j, i );
335                        for ( m = i + 1; m < ( j ); m++ ) {
336                                s += L ( m, i ) * Il ( j, m );
337                        }
338                        Il ( j, i ) = -s;
339                }
340        }
341
342        return Il;
343}
344
345void dydr ( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx )
346/********************************************************************
347
348   dydr = dyadic reduction, performs transformation of sum of
349          2 dyads r*Dr*r'+ f*Df*f' so that the element of r pointed
350          by R is zeroed. This version allows Dr to be NEGATIVE. Hence the name negdydr or dydr_withneg.
351
352   Parameters :
353     r ... pointer to reduced dyad
354     f ... pointer to reducing dyad
355     Dr .. pointer to the weight of reduced dyad
356     Df .. pointer to the weight of reducing dyad
357     R ... pointer to the element of r, which is to be reduced to
358           zero; the corresponding element of f is assumed to be 1.
359     jl .. lower index of the range within which the dyads are
360           modified
361     ju .. upper index of the range within which the dyads are
362           modified
363     kr .. pointer to the coefficient used in the transformation of r
364           rnew = r + kr*f
365     m  .. number of rows of modified matrix (part of which is r)
366  Remark : Constant mzero means machine zero and should be modified
367           according to the precision of particular machine
368
369                                                 V. Peterka 17-7-89
370
371  Added:
372     mx .. number of rows of modified matrix (part of which is f)  -PN
373
374********************************************************************/
375{
376        int j, jm;
377        double kD, r0;
378        double mzero = 2.2e-16;
379        double threshold = 1e-4;
380
381        if ( fabs ( *Dr ) < mzero ) *Dr = 0;
382        r0 = *R;
383        *R = 0.0;
384        kD = *Df;
385        *kr = r0 * *Dr;
386        *Df = kD + r0 * ( *kr );
387        if ( *Df > mzero ) {
388                kD /= *Df;
389                *kr /= *Df;
390        } else {
391                kD = 1.0;
392                *kr = 0.0;
393                if ( *Df < -threshold ) {
394                        it_warning ( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." );
395                }
396                *Df = 0.0;
397        }
398        *Dr *= kD;
399        jm = mx * jl;
400        for ( j = m * jl; j < m*jh; j += m ) {
401                r[j] -=  r0 * f[jm];
402                f[jm] += *kr * r[j];
403                jm += mx;
404        }
405}
406
407}
Note: See TracBrowser for help on using the browser.