Changeset 477 for library/bdm/math/square_mat.cpp
- Timestamp:
- 08/05/09 14:40:03 (15 years ago)
- Files:
-
- 1 modified
Legend:
- Unmodified
- Added
- Removed
-
library/bdm/math/square_mat.cpp
r433 r477 6 6 using std::endl; 7 7 8 void fsqmat::opupdt ( const vec &v, double w ) {M+=outer_product ( v,v*w );}; 9 mat fsqmat::to_mat() const {return M;}; 10 void fsqmat::mult_sym ( const mat &C) {M=C *M*C.T();}; 11 void fsqmat::mult_sym_t ( const mat &C) {M=C.T() *M*C;}; 12 void fsqmat::mult_sym ( const mat &C, fsqmat &U) const { U.M = ( C *(M*C.T()) );}; 13 void fsqmat::mult_sym_t ( const mat &C, fsqmat &U) const { U.M = ( C.T() *(M*C) );}; 14 void fsqmat::inv ( fsqmat &Inv ) const {mat IM = itpp::inv ( M ); Inv=IM;}; 15 void fsqmat::clear() {M.clear();}; 16 fsqmat::fsqmat ( const mat &M0 ) : sqmat(M0.cols()) 17 { 18 it_assert_debug ( ( M0.cols() ==M0.rows() ),"M0 must be square" ); 19 M=M0; 8 void fsqmat::opupdt ( const vec &v, double w ) { 9 M += outer_product ( v, v * w ); 10 }; 11 mat fsqmat::to_mat() const { 12 return M; 13 }; 14 void fsqmat::mult_sym ( const mat &C ) { 15 M = C * M * C.T(); 16 }; 17 void fsqmat::mult_sym_t ( const mat &C ) { 18 M = C.T() * M * C; 19 }; 20 void fsqmat::mult_sym ( const mat &C, fsqmat &U ) const { 21 U.M = ( C * ( M * C.T() ) ); 22 }; 23 void fsqmat::mult_sym_t ( const mat &C, fsqmat &U ) const { 24 U.M = ( C.T() * ( M * C ) ); 25 }; 26 void fsqmat::inv ( fsqmat &Inv ) const { 27 mat IM = itpp::inv ( M ); 28 Inv = IM; 29 }; 30 void fsqmat::clear() { 31 M.clear(); 32 }; 33 fsqmat::fsqmat ( const mat &M0 ) : sqmat ( M0.cols() ) { 34 it_assert_debug ( ( M0.cols() == M0.rows() ), "M0 must be square" ); 35 M = M0; 20 36 }; 21 37 22 38 //fsqmat::fsqmat() {}; 23 39 24 fsqmat::fsqmat (const int dim0): sqmat(dim0), M(dim0,dim0) {};40 fsqmat::fsqmat ( const int dim0 ) : sqmat ( dim0 ), M ( dim0, dim0 ) {}; 25 41 26 42 std::ostream &operator<< ( std::ostream &os, const fsqmat &ld ) { … … 30 46 31 47 32 ldmat::ldmat ( const mat &exL, const vec &exD ) : sqmat(exD.length()) {48 ldmat::ldmat ( const mat &exL, const vec &exD ) : sqmat ( exD.length() ) { 33 49 D = exD; 34 50 L = exL; 35 51 } 36 52 37 ldmat::ldmat() : sqmat(0) {}38 39 ldmat::ldmat (const int dim0): sqmat(dim0), D(dim0),L(dim0,dim0) {}40 41 ldmat::ldmat (const vec D0):sqmat(D0.length()) {53 ldmat::ldmat() : sqmat ( 0 ) {} 54 55 ldmat::ldmat ( const int dim0 ) : sqmat ( dim0 ), D ( dim0 ), L ( dim0, dim0 ) {} 56 57 ldmat::ldmat ( const vec D0 ) : sqmat ( D0.length() ) { 42 58 D = D0; 43 L = eye (dim);44 } 45 46 ldmat::ldmat ( const mat &V ):sqmat(V.cols()) {59 L = eye ( dim ); 60 } 61 62 ldmat::ldmat ( const mat &V ) : sqmat ( V.cols() ) { 47 63 //TODO check if correct!! Based on heuristic observation of lu() 48 64 49 it_assert_debug ( dim == V.rows(),"ldmat::ldmat matrix V is not square!" );50 65 it_assert_debug ( dim == V.rows(), "ldmat::ldmat matrix V is not square!" ); 66 51 67 // L and D will be allocated by ldform() 52 68 53 69 //Chol is unstable 54 this->ldform (chol(V),ones(dim));70 this->ldform ( chol ( V ), ones ( dim ) ); 55 71 // this->ldform(ul(V),ones(dim)); 56 72 } 57 73 58 void ldmat::opupdt ( const vec &v, double w ) {74 void ldmat::opupdt ( const vec &v, double w ) { 59 75 int dim = D.length(); 60 76 double kr; … … 65 81 double *rraw = r._data(); 66 82 67 it_assert_debug ( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." );83 it_assert_debug ( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." ); 68 84 69 85 for ( int i = dim - 1; i >= 0; i-- ) { 70 dydr ( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim );86 dydr ( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim ); 71 87 } 72 88 } … … 80 96 mat ldmat::to_mat() const { 81 97 int dim = D.length(); 82 mat V ( dim, dim );98 mat V ( dim, dim ); 83 99 double sum; 84 100 int r, c, cc; 85 101 86 for ( r = 0; r < dim;r++ ) { //row cycle87 for ( c = r; c < dim;c++ ) {102 for ( r = 0; r < dim; r++ ) { //row cycle 103 for ( c = r; c < dim; c++ ) { 88 104 //column cycle, using symmetricity => c=r! 89 105 sum = 0.0; 90 for ( cc = c; cc < dim;cc++ ) { //cycle over the remaining part of the vector91 sum += L ( cc, r ) * D( cc ) * L( cc, c );106 for ( cc = c; cc < dim; cc++ ) { //cycle over the remaining part of the vector 107 sum += L ( cc, r ) * D ( cc ) * L ( cc, c ); 92 108 //here L(cc,r) = L(r,cc)'; 93 109 } 94 V ( r, c ) = sum;110 V ( r, c ) = sum; 95 111 // symmetricity 96 if ( r != c ) {V( c, r ) = sum;}; 97 } 98 } 99 mat V2 = L.transpose()*diag( D )*L; 112 if ( r != c ) { 113 V ( c, r ) = sum; 114 }; 115 } 116 } 117 mat V2 = L.transpose() * diag ( D ) * L; 100 118 return V2; 101 119 } 102 120 103 121 104 void ldmat::add ( const ldmat &ld2, double w ) {122 void ldmat::add ( const ldmat &ld2, double w ) { 105 123 int dim = D.length(); 106 124 107 it_assert_debug ( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs;" );125 it_assert_debug ( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs;" ); 108 126 109 127 //Fixme can be done more efficiently either via dydr or ldform 110 128 for ( int r = 0; r < dim; r++ ) { 111 129 // Add columns of ld2.L' (i.e. rows of ld2.L) as dyads weighted by ld2.D 112 this->opupdt( ld2.L.get_row( r ), w*ld2.D( r ) ); 113 } 114 } 115 116 void ldmat::clear(){L.clear(); for ( int i=0;i<L.cols();i++ ){L( i,i )=1;}; D.clear();} 117 118 void ldmat::inv( ldmat &Inv ) const { 130 this->opupdt ( ld2.L.get_row ( r ), w*ld2.D ( r ) ); 131 } 132 } 133 134 void ldmat::clear() { 135 L.clear(); 136 for ( int i = 0; i < L.cols(); i++ ) { 137 L ( i, i ) = 1; 138 }; 139 D.clear(); 140 } 141 142 void ldmat::inv ( ldmat &Inv ) const { 119 143 Inv.clear(); //Inv = zero in LD 120 mat U = ltuinv ( L );121 122 Inv.ldform ( U.transpose(), 1.0 / D );123 } 124 125 void ldmat::mult_sym ( const mat &C) {126 mat A = L *C.T();127 this->ldform (A,D);128 } 129 130 void ldmat::mult_sym_t ( const mat &C) {131 mat A = L *C;132 this->ldform (A,D);133 } 134 135 void ldmat::mult_sym ( const mat &C, ldmat &U) const {136 mat A =L*C.T(); //could be done more efficiently using BLAS137 U.ldform (A,D);138 } 139 140 void ldmat::mult_sym_t ( const mat &C, ldmat &U) const {141 mat A =L*C;142 /* vec nD=zeros(U.rows());143 nD.replace_mid(0, D); //I case that D < nD*/144 U.ldform (A,D);144 mat U = ltuinv ( L ); 145 146 Inv.ldform ( U.transpose(), 1.0 / D ); 147 } 148 149 void ldmat::mult_sym ( const mat &C ) { 150 mat A = L * C.T(); 151 this->ldform ( A, D ); 152 } 153 154 void ldmat::mult_sym_t ( const mat &C ) { 155 mat A = L * C; 156 this->ldform ( A, D ); 157 } 158 159 void ldmat::mult_sym ( const mat &C, ldmat &U ) const { 160 mat A = L * C.T(); //could be done more efficiently using BLAS 161 U.ldform ( A, D ); 162 } 163 164 void ldmat::mult_sym_t ( const mat &C, ldmat &U ) const { 165 mat A = L * C; 166 /* vec nD=zeros(U.rows()); 167 nD.replace_mid(0, D); //I case that D < nD*/ 168 U.ldform ( A, D ); 145 169 } 146 170 … … 150 174 int i; 151 175 // sum logarithms of diagobal elements 152 for ( i=0; i<D.length(); i++ ){ldet+=log( D( i ) );}; 176 for ( i = 0; i < D.length(); i++ ) { 177 ldet += log ( D ( i ) ); 178 }; 153 179 return ldet; 154 180 } 155 181 156 double ldmat::qform ( const vec &v ) const {182 double ldmat::qform ( const vec &v ) const { 157 183 double x = 0.0, sum; 158 int i, j;159 160 for ( i =0; i<D.length(); i++ ) { //rows of L184 int i, j; 185 186 for ( i = 0; i < D.length(); i++ ) { //rows of L 161 187 sum = 0.0; 162 for ( j=0; j<=i; j++ ){sum+=L( i,j )*v( j );} 163 x +=D( i )*sum*sum; 188 for ( j = 0; j <= i; j++ ) { 189 sum += L ( i, j ) * v ( j ); 190 } 191 x += D ( i ) * sum * sum; 164 192 }; 165 193 return x; 166 194 } 167 195 168 double ldmat::invqform ( const vec &v ) const {196 double ldmat::invqform ( const vec &v ) const { 169 197 double x = 0.0; 170 198 int i; 171 vec pom (v.length());172 173 backward_substitution (L.T(),v,pom);174 175 for ( i =0; i<D.length(); i++ ) { //rows of L176 x += pom(i)*pom(i)/D(i);199 vec pom ( v.length() ); 200 201 backward_substitution ( L.T(), v, pom ); 202 203 for ( i = 0; i < D.length(); i++ ) { //rows of L 204 x += pom ( i ) * pom ( i ) / D ( i ); 177 205 }; 178 206 return x; … … 180 208 181 209 ldmat& ldmat::operator *= ( double x ) { 182 D *=x;210 D *= x; 183 211 return *this; 184 212 } 185 213 186 vec ldmat::sqrt_mult ( const vec &x ) const {187 int i, j;188 vec res ( dim );214 vec ldmat::sqrt_mult ( const vec &x ) const { 215 int i, j; 216 vec res ( dim ); 189 217 //double sum; 190 for ( i =0;i<dim;i++ ) {//for each element of result191 res ( i ) = 0.0;192 for ( j =i;j<dim;j++ ) {//sum D(j)*L(:,i).*x193 res ( i ) += sqrt( D( j ) )*L( j,i )*x( j );218 for ( i = 0; i < dim; i++ ) {//for each element of result 219 res ( i ) = 0.0; 220 for ( j = i; j < dim; j++ ) {//sum D(j)*L(:,i).*x 221 res ( i ) += sqrt ( D ( j ) ) * L ( j, i ) * x ( j ); 194 222 } 195 223 } … … 198 226 } 199 227 200 void ldmat::ldform(const mat &A,const vec &D0 ) 201 { 228 void ldmat::ldform ( const mat &A, const vec &D0 ) { 202 229 int m = A.rows(); 203 230 int n = A.cols(); 204 int mn = ( m<n) ? m :n ;231 int mn = ( m < n ) ? m : n ; 205 232 206 233 // it_assert_debug( A.cols()==dim,"ldmat::ldform A is not compatible" ); 207 it_assert_debug ( D0.length()==A.rows(),"ldmat::ldform Vector D must have the length as row count of A" );208 209 L =concat_vertical( zeros( n,n ), diag( sqrt( D0 ) )*A );210 D =zeros( n+m );211 212 //unnecessary big L and D will be made smaller at the end of file 213 vec w =zeros( n+m );214 234 it_assert_debug ( D0.length() == A.rows(), "ldmat::ldform Vector D must have the length as row count of A" ); 235 236 L = concat_vertical ( zeros ( n, n ), diag ( sqrt ( D0 ) ) * A ); 237 D = zeros ( n + m ); 238 239 //unnecessary big L and D will be made smaller at the end of file 240 vec w = zeros ( n + m ); 241 215 242 double sum, beta, pom; 216 243 217 int cc=0; 218 int i=n; // indexovani o 1 niz, nez v matlabu 219 int ii,j,jj; 220 while ( (i>n-mn-cc) && (i>0) ) 221 { 244 int cc = 0; 245 int i = n; // indexovani o 1 niz, nez v matlabu 246 int ii, j, jj; 247 while ( ( i > n - mn - cc ) && ( i > 0 ) ) { 222 248 i--; 223 249 sum = 0.0; 224 250 225 int last_v = m+i-n+cc+1; 226 227 vec v = zeros( last_v + 1 ); //prepare v 228 for ( ii=n-cc-1;ii<m+i+1;ii++ ) 229 { 230 sum+= L( ii,i )*L( ii,i ); 231 v( ii-n+cc+1 )=L( ii,i ); //assign v 232 } 233 234 if ( L( m+i,i )==0 ) 235 beta = sqrt( sum ); 251 int last_v = m + i - n + cc + 1; 252 253 vec v = zeros ( last_v + 1 ); //prepare v 254 for ( ii = n - cc - 1; ii < m + i + 1; ii++ ) { 255 sum += L ( ii, i ) * L ( ii, i ); 256 v ( ii - n + cc + 1 ) = L ( ii, i ); //assign v 257 } 258 259 if ( L ( m + i, i ) == 0 ) 260 beta = sqrt ( sum ); 236 261 else 237 beta = L( m+i,i )+sign( L( m+i,i ) )*sqrt( sum ); 238 239 if ( std::fabs( beta )<eps ) 240 { 262 beta = L ( m + i, i ) + sign ( L ( m + i, i ) ) * sqrt ( sum ); 263 264 if ( std::fabs ( beta ) < eps ) { 241 265 cc++; 242 L.set_row( n-cc, L.get_row( m+i ) ); 243 L.set_row( m+i,zeros(L.cols()) ); 244 D( m+i )=0; L( m+i,i )=1; 245 L.set_submatrix( n-cc,m+i-1,i,i,0 ); 266 L.set_row ( n - cc, L.get_row ( m + i ) ); 267 L.set_row ( m + i, zeros ( L.cols() ) ); 268 D ( m + i ) = 0; 269 L ( m + i, i ) = 1; 270 L.set_submatrix ( n - cc, m + i - 1, i, i, 0 ); 246 271 continue; 247 272 } 248 273 249 sum -=v(last_v)*v(last_v);250 sum /=beta*beta;274 sum -= v ( last_v ) * v ( last_v ); 275 sum /= beta * beta; 251 276 sum++; 252 277 253 v/=beta; 254 v(last_v)=1; 255 256 pom=-2.0/sum; 257 // echo to venca 258 259 for ( j=i;j>=0;j-- ) 260 { 261 double w_elem = 0; 262 for ( ii=n- cc;ii<=m+i+1;ii++ ) 263 w_elem+= v( ii-n+cc )*L( ii-1,j ); 264 w(j)=w_elem*pom; 265 } 266 267 for ( ii=n-cc-1;ii<=m+i;ii++ ) 268 for ( jj=0;jj<i;jj++ ) 269 L( ii,jj )+= v( ii-n+cc+1)*w( jj ); 270 271 for ( ii=n-cc-1;ii<m+i;ii++ ) 272 L( ii,i )= 0; 273 274 L( m+i,i )+=w( i ); 275 D( m+i )=L( m+i,i )*L( m+i,i ); 276 277 for ( ii=0;ii<=i;ii++ ) 278 L( m+i,ii )/=L( m+i,i ); 279 } 280 281 if ( i>=0 ) 282 for ( ii=0;ii<i;ii++ ) 283 { 284 jj = D.length()-1-n+ii; 285 D(jj) = 0; 286 L.set_row(jj,zeros(L.cols())); //TODO: set_row accepts Num_T 287 L(jj,jj)=1; 288 } 289 290 L.del_rows(0,m-1); 291 D.del(0,m-1); 292 278 v /= beta; 279 v ( last_v ) = 1; 280 281 pom = -2.0 / sum; 282 // echo to venca 283 284 for ( j = i; j >= 0; j-- ) { 285 double w_elem = 0; 286 for ( ii = n - cc; ii <= m + i + 1; ii++ ) 287 w_elem += v ( ii - n + cc ) * L ( ii - 1, j ); 288 w ( j ) = w_elem * pom; 289 } 290 291 for ( ii = n - cc - 1; ii <= m + i; ii++ ) 292 for ( jj = 0; jj < i; jj++ ) 293 L ( ii, jj ) += v ( ii - n + cc + 1 ) * w ( jj ); 294 295 for ( ii = n - cc - 1; ii < m + i; ii++ ) 296 L ( ii, i ) = 0; 297 298 L ( m + i, i ) += w ( i ); 299 D ( m + i ) = L ( m + i, i ) * L ( m + i, i ); 300 301 for ( ii = 0; ii <= i; ii++ ) 302 L ( m + i, ii ) /= L ( m + i, i ); 303 } 304 305 if ( i >= 0 ) 306 for ( ii = 0; ii < i; ii++ ) { 307 jj = D.length() - 1 - n + ii; 308 D ( jj ) = 0; 309 L.set_row ( jj, zeros ( L.cols() ) ); //TODO: set_row accepts Num_T 310 L ( jj, jj ) = 1; 311 } 312 313 L.del_rows ( 0, m - 1 ); 314 D.del ( 0, m - 1 ); 315 293 316 dim = L.rows(); 294 317 } … … 296 319 //////// Auxiliary Functions 297 320 298 mat ltuinv ( const mat &L ) {321 mat ltuinv ( const mat &L ) { 299 322 int dim = L.cols(); 300 mat Il = eye ( dim );323 mat Il = eye ( dim ); 301 324 int i, j, k, m; 302 325 double s; 303 326 304 327 //Fixme blind transcription of ltuinv.m 305 for ( k = 1; k < ( dim ); k++ ) {306 for ( i = 0; i < ( dim - k ); i++ ) {328 for ( k = 1; k < ( dim ); k++ ) { 329 for ( i = 0; i < ( dim - k ); i++ ) { 307 330 j = i + k; //change in .m 1+1=2, here 0+0+1=1 308 s = L ( j, i );331 s = L ( j, i ); 309 332 for ( m = i + 1; m < ( j ); m++ ) { 310 s += L ( m, i ) * Il( j, m );333 s += L ( m, i ) * Il ( j, m ); 311 334 } 312 Il ( j, i ) = -s;335 Il ( j, i ) = -s; 313 336 } 314 337 } … … 317 340 } 318 341 319 void dydr ( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx )342 void dydr ( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx ) 320 343 /******************************************************************** 321 344 … … 353 376 double threshold = 1e-4; 354 377 355 if ( fabs ( *Dr ) < mzero ) *Dr = 0;378 if ( fabs ( *Dr ) < mzero ) *Dr = 0; 356 379 r0 = *R; 357 380 *R = 0.0; … … 366 389 *kr = 0.0; 367 390 if ( *Df < -threshold ) { 368 it_warning ( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." );391 it_warning ( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." ); 369 392 } 370 393 *Df = 0.0;