root/library/bdm/math/square_mat.cpp @ 477

Revision 477, 9.4 kB (checked in by mido, 15 years ago)

panove, vite, jak jsem peclivej na upravu kodu.. snad se vam bude libit:) konfigurace je v souboru /system/astylerc

  • Property svn:eol-style set to native
Line 
1
2#include "square_mat.h"
3
4using namespace itpp;
5
6using std::endl;
7
8void fsqmat::opupdt ( const vec &v, double w ) {
9        M += outer_product ( v, v * w );
10};
11mat fsqmat::to_mat() const {
12        return M;
13};
14void fsqmat::mult_sym ( const mat &C ) {
15        M = C * M * C.T();
16};
17void fsqmat::mult_sym_t ( const mat &C ) {
18        M = C.T() * M * C;
19};
20void fsqmat::mult_sym ( const mat &C, fsqmat &U ) const {
21        U.M = ( C * ( M * C.T() ) );
22};
23void fsqmat::mult_sym_t ( const mat &C, fsqmat &U ) const {
24        U.M = ( C.T() * ( M * C ) );
25};
26void fsqmat::inv ( fsqmat &Inv ) const {
27        mat IM = itpp::inv ( M );
28        Inv = IM;
29};
30void fsqmat::clear() {
31        M.clear();
32};
33fsqmat::fsqmat ( const mat &M0 ) : sqmat ( M0.cols() ) {
34        it_assert_debug ( ( M0.cols() == M0.rows() ), "M0 must be square" );
35        M = M0;
36};
37
38//fsqmat::fsqmat() {};
39
40fsqmat::fsqmat ( const int dim0 ) : sqmat ( dim0 ), M ( dim0, dim0 ) {};
41
42std::ostream &operator<< ( std::ostream &os, const fsqmat &ld ) {
43        os << ld.M << endl;
44        return os;
45}
46
47
48ldmat::ldmat ( const mat &exL, const vec &exD ) : sqmat ( exD.length() ) {
49        D = exD;
50        L = exL;
51}
52
53ldmat::ldmat() : sqmat ( 0 ) {}
54
55ldmat::ldmat ( const int dim0 ) : sqmat ( dim0 ), D ( dim0 ), L ( dim0, dim0 ) {}
56
57ldmat::ldmat ( const vec D0 ) : sqmat ( D0.length() ) {
58        D = D0;
59        L = eye ( dim );
60}
61
62ldmat::ldmat ( const mat &V ) : sqmat ( V.cols() ) {
63//TODO check if correct!! Based on heuristic observation of lu()
64
65        it_assert_debug ( dim == V.rows(), "ldmat::ldmat matrix V is not square!" );
66
67        // L and D will be allocated by ldform()
68
69        //Chol is unstable
70        this->ldform ( chol ( V ), ones ( dim ) );
71//      this->ldform(ul(V),ones(dim));
72}
73
74void ldmat::opupdt ( const vec &v,  double w ) {
75        int dim = D.length();
76        double kr;
77        vec r = v;
78        //beware! it is potentionally dangerous, if ITpp change _behaviour of _data()!
79        double *Lraw = L._data();
80        double *Draw = D._data();
81        double *rraw = r._data();
82
83        it_assert_debug ( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." );
84
85        for ( int i = dim - 1; i >= 0; i-- ) {
86                dydr ( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim );
87        }
88}
89
90std::ostream &operator<< ( std::ostream &os, const ldmat &ld ) {
91        os << "L:" << ld.L << endl;
92        os << "D:" << ld.D << endl;
93        return os;
94}
95
96mat ldmat::to_mat() const {
97        int dim = D.length();
98        mat V ( dim, dim );
99        double sum;
100        int r, c, cc;
101
102        for ( r = 0; r < dim; r++ ) { //row cycle
103                for ( c = r; c < dim; c++ ) {
104                        //column cycle, using symmetricity => c=r!
105                        sum = 0.0;
106                        for ( cc = c; cc < dim; cc++ ) { //cycle over the remaining part of the vector
107                                sum += L ( cc, r ) * D ( cc ) * L ( cc, c );
108                                //here L(cc,r) = L(r,cc)';
109                        }
110                        V ( r, c ) = sum;
111                        // symmetricity
112                        if ( r != c ) {
113                                V ( c, r ) = sum;
114                        };
115                }
116        }
117        mat V2 = L.transpose() * diag ( D ) * L;
118        return V2;
119}
120
121
122void ldmat::add ( const ldmat &ld2, double w ) {
123        int dim = D.length();
124
125        it_assert_debug ( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs;" );
126
127        //Fixme can be done more efficiently either via dydr or ldform
128        for ( int r = 0; r < dim; r++ ) {
129                // Add columns of ld2.L' (i.e. rows of ld2.L) as dyads weighted by ld2.D
130                this->opupdt ( ld2.L.get_row ( r ), w*ld2.D ( r ) );
131        }
132}
133
134void ldmat::clear() {
135        L.clear();
136        for ( int i = 0; i < L.cols(); i++ ) {
137                L ( i, i ) = 1;
138        };
139        D.clear();
140}
141
142void ldmat::inv ( ldmat &Inv ) const {
143        Inv.clear();   //Inv = zero in LD
144        mat U = ltuinv ( L );
145
146        Inv.ldform ( U.transpose(), 1.0 / D );
147}
148
149void ldmat::mult_sym ( const mat &C ) {
150        mat A = L * C.T();
151        this->ldform ( A, D );
152}
153
154void ldmat::mult_sym_t ( const mat &C ) {
155        mat A = L * C;
156        this->ldform ( A, D );
157}
158
159void ldmat::mult_sym ( const mat &C, ldmat &U ) const {
160        mat A = L * C.T(); //could be done more efficiently using BLAS
161        U.ldform ( A, D );
162}
163
164void ldmat::mult_sym_t ( const mat &C, ldmat &U ) const {
165        mat A = L * C;
166        /*      vec nD=zeros(U.rows());
167                nD.replace_mid(0, D); //I case that D < nD*/
168        U.ldform ( A, D );
169}
170
171
172double ldmat::logdet() const {
173        double ldet = 0.0;
174        int i;
175// sum logarithms of diagobal elements
176        for ( i = 0; i < D.length(); i++ ) {
177                ldet += log ( D ( i ) );
178        };
179        return ldet;
180}
181
182double ldmat::qform ( const vec &v ) const {
183        double x = 0.0, sum;
184        int i, j;
185
186        for ( i = 0; i < D.length(); i++ ) { //rows of L
187                sum = 0.0;
188                for ( j = 0; j <= i; j++ ) {
189                        sum += L ( i, j ) * v ( j );
190                }
191                x += D ( i ) * sum * sum;
192        };
193        return x;
194}
195
196double ldmat::invqform ( const vec &v ) const {
197        double x = 0.0;
198        int i;
199        vec pom ( v.length() );
200
201        backward_substitution ( L.T(), v, pom );
202
203        for ( i = 0; i < D.length(); i++ ) { //rows of L
204                x += pom ( i ) * pom ( i ) / D ( i );
205        };
206        return x;
207}
208
209ldmat& ldmat::operator *= ( double x ) {
210        D *= x;
211        return *this;
212}
213
214vec ldmat::sqrt_mult ( const vec &x ) const {
215        int i, j;
216        vec res ( dim );
217        //double sum;
218        for ( i = 0; i < dim; i++ ) {//for each element of result
219                res ( i ) = 0.0;
220                for ( j = i; j < dim; j++ ) {//sum D(j)*L(:,i).*x
221                        res ( i ) += sqrt ( D ( j ) ) * L ( j, i ) * x ( j );
222                }
223        }
224//      vec res2 = L.transpose()*diag( sqrt( D ) )*x;
225        return res;
226}
227
228void ldmat::ldform ( const mat &A, const vec &D0 ) {
229        int m = A.rows();
230        int n = A.cols();
231        int mn = ( m < n ) ? m : n ;
232
233//      it_assert_debug( A.cols()==dim,"ldmat::ldform A is not compatible" );
234        it_assert_debug ( D0.length() == A.rows(), "ldmat::ldform Vector D must have the length as row count of A" );
235
236        L = concat_vertical ( zeros ( n, n ), diag ( sqrt ( D0 ) ) * A );
237        D = zeros ( n + m );
238
239        //unnecessary big L and D will be made smaller at the end of file
240        vec w = zeros ( n + m );
241
242        double sum, beta, pom;
243
244        int cc = 0;
245        int i = n; // indexovani o 1 niz, nez v matlabu
246        int ii, j, jj;
247        while ( ( i > n - mn - cc ) && ( i > 0 ) ) {
248                i--;
249                sum = 0.0;
250
251                int last_v = m + i - n + cc + 1;
252
253                vec v = zeros ( last_v + 1 ); //prepare v
254                for ( ii = n - cc - 1; ii < m + i + 1; ii++ ) {
255                        sum += L ( ii, i ) * L ( ii, i );
256                        v ( ii - n + cc + 1 ) = L ( ii, i ); //assign v
257                }
258
259                if ( L ( m + i, i ) == 0 )
260                        beta = sqrt ( sum );
261                else
262                        beta = L ( m + i, i ) + sign ( L ( m + i, i ) ) * sqrt ( sum );
263
264                if ( std::fabs ( beta ) < eps ) {
265                        cc++;
266                        L.set_row ( n - cc, L.get_row ( m + i ) );
267                        L.set_row ( m + i, zeros ( L.cols() ) );
268                        D ( m + i ) = 0;
269                        L ( m + i, i ) = 1;
270                        L.set_submatrix ( n - cc, m + i - 1, i, i, 0 );
271                        continue;
272                }
273
274                sum -= v ( last_v ) * v ( last_v );
275                sum /= beta * beta;
276                sum++;
277
278                v /= beta;
279                v ( last_v ) = 1;
280
281                pom = -2.0 / sum;
282                // echo to venca
283
284                for ( j = i; j >= 0; j-- ) {
285                        double w_elem = 0;
286                        for ( ii = n -  cc; ii <= m + i + 1; ii++ )
287                                w_elem += v ( ii - n + cc ) * L ( ii - 1, j );
288                        w ( j ) = w_elem * pom;
289                }
290
291                for ( ii = n - cc - 1; ii <= m + i; ii++ )
292                        for ( jj = 0; jj < i; jj++ )
293                                L ( ii, jj ) += v ( ii - n + cc + 1 ) * w ( jj );
294
295                for ( ii = n - cc - 1; ii < m + i; ii++ )
296                        L ( ii, i ) = 0;
297
298                L ( m + i, i ) += w ( i );
299                D ( m + i ) = L ( m + i, i ) * L ( m + i, i );
300
301                for ( ii = 0; ii <= i; ii++ )
302                        L ( m + i, ii ) /= L ( m + i, i );
303        }
304
305        if ( i >= 0 )
306                for ( ii = 0; ii < i; ii++ ) {
307                        jj = D.length() - 1 - n + ii;
308                        D ( jj ) = 0;
309                        L.set_row ( jj, zeros ( L.cols() ) ); //TODO: set_row accepts Num_T
310                        L ( jj, jj ) = 1;
311                }
312
313        L.del_rows ( 0, m - 1 );
314        D.del ( 0, m - 1 );
315
316        dim = L.rows();
317}
318
319//////// Auxiliary Functions
320
321mat ltuinv ( const mat &L ) {
322        int dim = L.cols();
323        mat Il = eye ( dim );
324        int i, j, k, m;
325        double s;
326
327//Fixme blind transcription of ltuinv.m
328        for ( k = 1; k < ( dim ); k++ ) {
329                for ( i = 0; i < ( dim - k ); i++ ) {
330                        j = i + k; //change in .m 1+1=2, here 0+0+1=1
331                        s = L ( j, i );
332                        for ( m = i + 1; m < ( j ); m++ ) {
333                                s += L ( m, i ) * Il ( j, m );
334                        }
335                        Il ( j, i ) = -s;
336                }
337        }
338
339        return Il;
340}
341
342void dydr ( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx )
343/********************************************************************
344
345   dydr = dyadic reduction, performs transformation of sum of
346          2 dyads r*Dr*r'+ f*Df*f' so that the element of r pointed
347          by R is zeroed. This version allows Dr to be NEGATIVE. Hence the name negdydr or dydr_withneg.
348
349   Parameters :
350     r ... pointer to reduced dyad
351     f ... pointer to reducing dyad
352     Dr .. pointer to the weight of reduced dyad
353     Df .. pointer to the weight of reducing dyad
354     R ... pointer to the element of r, which is to be reduced to
355           zero; the corresponding element of f is assumed to be 1.
356     jl .. lower index of the range within which the dyads are
357           modified
358     ju .. upper index of the range within which the dyads are
359           modified
360     kr .. pointer to the coefficient used in the transformation of r
361           rnew = r + kr*f
362     m  .. number of rows of modified matrix (part of which is r)
363  Remark : Constant mzero means machine zero and should be modified
364           according to the precision of particular machine
365
366                                                 V. Peterka 17-7-89
367
368  Added:
369     mx .. number of rows of modified matrix (part of which is f)  -PN
370
371********************************************************************/
372{
373        int j, jm;
374        double kD, r0;
375        double mzero = 2.2e-16;
376        double threshold = 1e-4;
377
378        if ( fabs ( *Dr ) < mzero ) *Dr = 0;
379        r0 = *R;
380        *R = 0.0;
381        kD = *Df;
382        *kr = r0 * *Dr;
383        *Df = kD + r0 * ( *kr );
384        if ( *Df > mzero ) {
385                kD /= *Df;
386                *kr /= *Df;
387        } else {
388                kD = 1.0;
389                *kr = 0.0;
390                if ( *Df < -threshold ) {
391                        it_warning ( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." );
392                }
393                *Df = 0.0;
394        }
395        *Dr *= kD;
396        jm = mx * jl;
397        for ( j = m * jl; j < m*jh; j += m ) {
398                r[j] -=  r0 * f[jm];
399                f[jm] += *kr * r[j];
400                jm += mx;
401        }
402}
Note: See TracBrowser for help on using the browser.