root/library/bdm/math/square_mat.cpp @ 384

Revision 384, 8.8 kB (checked in by mido, 15 years ago)

possibly broken?

  • Property svn:eol-style set to native
Line 
1
2#include "square_mat.h"
3
4using namespace itpp;
5
6using std::endl;
7
8void fsqmat::opupdt ( const vec &v, double w ) {M+=outer_product ( v,v*w );};
9mat fsqmat::to_mat() const {return M;};
10void fsqmat::mult_sym ( const mat &C) {M=C *M*C.T();};
11void fsqmat::mult_sym_t ( const mat &C) {M=C.T() *M*C;};
12void fsqmat::mult_sym ( const mat &C, fsqmat &U) const { U.M = ( C *(M*C.T()) );};
13void fsqmat::mult_sym_t ( const mat &C, fsqmat &U) const { U.M = ( C.T() *(M*C) );};
14void fsqmat::inv ( fsqmat &Inv ) {mat IM = itpp::inv ( M ); Inv=IM;};
15void fsqmat::clear() {M.clear();};
16fsqmat::fsqmat ( const mat &M0 ) : sqmat(M0.cols())
17{
18        it_assert_debug ( ( M0.cols() ==M0.rows() ),"M0 must be square" );
19        M=M0;
20};
21
22//fsqmat::fsqmat() {};
23
24fsqmat::fsqmat(const int dim0): sqmat(dim0), M(dim0,dim0) {};
25
26std::ostream &operator<< ( std::ostream &os, const fsqmat &ld ) {
27        os << ld.M << endl;
28        return os;
29}
30
31
32ldmat::ldmat( const mat &exL, const vec &exD ) : sqmat(exD.length()) {
33        D = exD;
34        L = exL;
35}
36
37ldmat::ldmat() :sqmat(0) {}
38
39ldmat::ldmat(const int dim0): sqmat(dim0), D(dim0),L(dim0,dim0) {}
40
41ldmat::ldmat(const vec D0):sqmat(D0.length()) {
42        D = D0;
43        L = eye(dim);
44}
45
46ldmat::ldmat( const mat &V ):sqmat(V.cols()) {
47//TODO check if correct!! Based on heuristic observation of lu()
48
49        it_assert_debug( dim == V.rows(),"ldmat::ldmat matrix V is not square!" );
50               
51        // L and D will be allocated by ldform()
52       
53        //Chol is unstable
54        this->ldform(chol(V),ones(dim));
55//      this->ldform(ul(V),ones(dim));
56}
57
58void ldmat::opupdt( const vec &v,  double w ) {
59        int dim = D.length();
60        double kr;
61        vec r = v;
62        //beware! it is potentionally dangerous, if ITpp change _behaviour of _data()!
63        double *Lraw = L._data();
64        double *Draw = D._data();
65        double *rraw = r._data();
66
67        it_assert_debug( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." );
68
69        for ( int i = dim - 1; i >= 0; i-- ) {
70                dydr( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim );
71        }
72}
73
74std::ostream &operator<< ( std::ostream &os, const ldmat &ld ) {
75        os << "L:" << ld.L << endl;
76        os << "D:" << ld.D << endl;
77        return os;
78}
79
80mat ldmat::to_mat() const {
81        int dim = D.length();
82        mat V( dim, dim );
83        double sum;
84        int r, c, cc;
85
86        for ( r = 0;r < dim;r++ ) { //row cycle
87                for ( c = r;c < dim;c++ ) {
88                        //column cycle, using symmetricity => c=r!
89                        sum = 0.0;
90                        for ( cc = c;cc < dim;cc++ ) { //cycle over the remaining part of the vector
91                                sum += L( cc, r ) * D( cc ) * L( cc, c );
92                                //here L(cc,r) = L(r,cc)';
93                        }
94                        V( r, c ) = sum;
95                        // symmetricity
96                        if ( r != c ) {V( c, r ) = sum;};
97                }
98        }
99        mat V2 = L.transpose()*diag( D )*L;
100        return V2;
101}
102
103
104void ldmat::add( const ldmat &ld2, double w ) {
105        int dim = D.length();
106
107        it_assert_debug( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs;" );
108
109        //Fixme can be done more efficiently either via dydr or ldform
110        for ( int r = 0; r < dim; r++ ) {
111                // Add columns of ld2.L' (i.e. rows of ld2.L) as dyads weighted by ld2.D
112                this->opupdt( ld2.L.get_row( r ), w*ld2.D( r ) );
113        }
114}
115
116void ldmat::clear(){L.clear(); for ( int i=0;i<L.cols();i++ ){L( i,i )=1;}; D.clear();}
117
118void ldmat::inv( ldmat &Inv ) const {
119        Inv.clear();   //Inv = zero in LD
120        mat U = ltuinv( L );
121
122        Inv.ldform( U.transpose(), 1.0 / D );
123}
124
125void ldmat::mult_sym( const mat &C) {
126        mat A = L*C.T();
127        this->ldform(A,D);
128}
129
130void ldmat::mult_sym_t( const mat &C) {
131        mat A = L*C;
132        this->ldform(A,D);
133}
134
135void ldmat::mult_sym( const mat &C, ldmat &U) const {
136        mat A=L*C.T(); //could be done more efficiently using BLAS
137        U.ldform(A,D);
138}
139
140void ldmat::mult_sym_t( const mat &C, ldmat &U) const {
141        mat A=L*C;
142/*      vec nD=zeros(U.rows());
143        nD.replace_mid(0, D); //I case that D < nD*/
144        U.ldform(A,D);
145}
146
147
148double ldmat::logdet() const {
149        double ldet = 0.0;
150        int i;
151// sum logarithms of diagobal elements
152        for ( i=0; i<D.length(); i++ ){ldet+=log( D( i ) );};
153        return ldet;
154}
155
156double ldmat::qform( const vec &v ) const {
157        double x = 0.0, sum;
158        int i,j;
159
160        for ( i=0; i<D.length(); i++ ) { //rows of L
161                sum = 0.0;
162                for ( j=0; j<=i; j++ ){sum+=L( i,j )*v( j );}
163                x +=D( i )*sum*sum;
164        };
165        return x;
166}
167
168double ldmat::invqform( const vec &v ) const {
169        double x = 0.0;
170        int i;
171        vec pom(v.length());
172       
173        backward_substitution(L.T(),v,pom);
174       
175        for ( i=0; i<D.length(); i++ ) { //rows of L
176                x +=pom(i)*pom(i)/D(i);
177        };
178        return x;
179}
180
181ldmat& ldmat::operator *= ( double x ) {
182        D*=x;
183        return *this;
184}
185
186vec ldmat::sqrt_mult( const vec &x ) const {
187        int i,j;
188        vec res( dim );
189        //double sum;
190        for ( i=0;i<dim;i++ ) {//for each element of result
191                res( i ) = 0.0;
192                for ( j=i;j<dim;j++ ) {//sum D(j)*L(:,i).*x
193                        res( i ) += sqrt( D( j ) )*L( j,i )*x( j );
194                }
195        }
196//      vec res2 = L.transpose()*diag( sqrt( D ) )*x;
197        return res;
198}
199
200void ldmat::ldform(const mat &A,const vec &D0 )
201{
202        int m = A.rows();
203        int n = A.cols();
204        int mn = (m<n) ? m :n ;
205
206//      it_assert_debug( A.cols()==dim,"ldmat::ldform A is not compatible" );
207        it_assert_debug( D0.length()==A.rows(),"ldmat::ldform Vector D must have the length as row count of A" );
208
209        L=concat_vertical( zeros( n,n ), diag( sqrt( D0 ) )*A );
210        D=zeros( n+m );
211
212        //unnecessary big L and D will be made smaller at the end of file       
213        vec w=zeros( n+m );
214   
215        double sum, beta, pom;
216
217        int cc=0;
218        int i=n; // indexovani o 1 niz, nez v matlabu
219        int ii,j,jj;
220        while ( (i>n-mn-cc) && (i>0) ) 
221        {
222                i--;
223                sum = 0.0;
224
225                int last_v = m+i-n+cc+1;
226       
227                vec v = zeros( last_v + 1 ); //prepare v
228                for ( ii=n-cc-1;ii<m+i+1;ii++ ) 
229                {
230                        sum+= L( ii,i )*L( ii,i );
231                        v( ii-n+cc+1 )=L( ii,i ); //assign v
232                }
233
234                if ( L( m+i,i )==0 ) 
235                        beta = sqrt( sum );             
236                else
237                        beta = L( m+i,i )+sign( L( m+i,i ) )*sqrt( sum );               
238     
239                if ( std::fabs( beta )<eps )
240                {
241                        cc++;
242                        L.set_row( n-cc, L.get_row( m+i ) );
243                        L.set_row( m+i,zeros(L.cols()) );
244                        D( m+i )=0; L( m+i,i )=1;
245                        L.set_submatrix( n-cc,m+i-1,i,i,0 );
246                        continue;
247                }
248
249                sum-=v(last_v)*v(last_v);
250                sum/=beta*beta;
251                sum++;
252
253                v/=beta;
254                v(last_v)=1;
255
256                pom=-2.0/sum;
257                // echo to venca   
258
259                for ( j=i;j>=0;j-- ) 
260                {
261                        double w_elem = 0;                     
262                        for ( ii=n-     cc;ii<=m+i+1;ii++ ) 
263                                w_elem+= v( ii-n+cc )*L( ii-1,j );                     
264                        w(j)=w_elem*pom;
265                }
266
267                for ( ii=n-cc-1;ii<=m+i;ii++ ) 
268                        for ( jj=0;jj<i;jj++ ) 
269                                L( ii,jj )+= v( ii-n+cc+1)*w( jj );
270
271                for ( ii=n-cc-1;ii<m+i;ii++ )
272                        L( ii,i )= 0;
273
274                L( m+i,i )+=w( i );
275                D( m+i )=L( m+i,i )*L( m+i,i );
276
277                for ( ii=0;ii<=i;ii++ ) 
278                        L( m+i,ii )/=L( m+i,i );               
279        }
280
281        if ( i>=0 )
282                for ( ii=0;ii<i;ii++ ) 
283                {
284                        jj = D.length()-1-n+ii;
285                        D(jj) = 0;
286                        L.set_row(jj,zeros(L.cols())); //TODO: set_row accepts Num_T
287                        L(jj,jj)=1;
288                }
289
290        L.del_rows(0,m-1);
291        D.del(0,m-1);
292       
293        dim = L.rows();
294}
295
296//////// Auxiliary Functions
297
298mat ltuinv( const mat &L ) {
299        int dim = L.cols();
300        mat Il = eye( dim );
301        int i, j, k, m;
302        double s;
303
304//Fixme blind transcription of ltuinv.m
305        for ( k = 1; k < ( dim );k++ ) {
306                for ( i = 0; i < ( dim - k );i++ ) {
307                        j = i + k; //change in .m 1+1=2, here 0+0+1=1
308                        s = L( j, i );
309                        for ( m = i + 1; m < ( j ); m++ ) {
310                                s += L( m, i ) * Il( j, m );
311                        }
312                        Il( j, i ) = -s;
313                }
314        }
315
316        return Il;
317}
318
319void dydr( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx )
320/********************************************************************
321
322   dydr = dyadic reduction, performs transformation of sum of
323          2 dyads r*Dr*r'+ f*Df*f' so that the element of r pointed
324          by R is zeroed. This version allows Dr to be NEGATIVE. Hence the name negdydr or dydr_withneg.
325
326   Parameters :
327     r ... pointer to reduced dyad
328     f ... pointer to reducing dyad
329     Dr .. pointer to the weight of reduced dyad
330     Df .. pointer to the weight of reducing dyad
331     R ... pointer to the element of r, which is to be reduced to
332           zero; the corresponding element of f is assumed to be 1.
333     jl .. lower index of the range within which the dyads are
334           modified
335     ju .. upper index of the range within which the dyads are
336           modified
337     kr .. pointer to the coefficient used in the transformation of r
338           rnew = r + kr*f
339     m  .. number of rows of modified matrix (part of which is r)
340  Remark : Constant mzero means machine zero and should be modified
341           according to the precision of particular machine
342
343                                                 V. Peterka 17-7-89
344
345  Added:
346     mx .. number of rows of modified matrix (part of which is f)  -PN
347
348********************************************************************/
349{
350        int j, jm;
351        double kD, r0;
352        double mzero = 2.2e-16;
353        double threshold = 1e-4;
354
355        if ( fabs( *Dr ) < mzero ) *Dr = 0;
356        r0 = *R;
357        *R = 0.0;
358        kD = *Df;
359        *kr = r0 * *Dr;
360        *Df = kD + r0 * ( *kr );
361        if ( *Df > mzero ) {
362                kD /= *Df;
363                *kr /= *Df;
364        } else {
365                kD = 1.0;
366                *kr = 0.0;
367                if ( *Df < -threshold ) {
368                        it_warning( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." );
369                }
370                *Df = 0.0;
371        }
372        *Dr *= kD;
373        jm = mx * jl;
374        for ( j = m * jl; j < m*jh; j += m ) {
375                r[j] -=  r0 * f[jm];
376                f[jm] += *kr * r[j];
377                jm += mx;
378        }
379}
Note: See TracBrowser for help on using the browser.