root/library/bdm/base/itpp/signal/fastica.cpp @ 1407

Revision 1407, 27.5 kB (checked in by sindj, 13 years ago)

Pridany soucasti IT++. signal pro signal processing a stat pro statistiku. JS

Line 
1/*!
2 * \file
3 * \brief Implementation of FastICA (Independent Component Analysis) for IT++
4 * \author Francois Cayre and Teddy Furon
5 *
6 * -------------------------------------------------------------------------
7 *
8 * Copyright (C) 1995-2010  (see AUTHORS file for a list of contributors)
9 *
10 * This file is part of IT++ - a C++ library of mathematical, signal
11 * processing, speech processing, and communications classes and functions.
12 *
13 * IT++ is free software: you can redistribute it and/or modify it under the
14 * terms of the GNU General Public License as published by the Free Software
15 * Foundation, either version 3 of the License, or (at your option) any
16 * later version.
17 *
18 * IT++ is distributed in the hope that it will be useful, but WITHOUT ANY
19 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
20 * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
21 * details.
22 *
23 * You should have received a copy of the GNU General Public License along
24 * with IT++.  If not, see <http://www.gnu.org/licenses/>.
25 *
26 * -------------------------------------------------------------------------
27 *
28 * This is IT++ implementation of the original Matlab package FastICA.
29 *
30 * This code is Copyright (C) 2004 by:
31 *   Francois CAYRE and Teddy FURON
32 *   TEMICS Project
33 *   INRIA/Rennes (IRISA)
34 *   Campus Universitaire de Beaulieu
35 *   35042 RENNES cedex FRANCE
36 *
37 * Email : firstname.lastname@irisa.fr
38 *
39 * Matlab package is Copyright (C) 1998 by:
40 *   Jarmo HURRI, Hugo GAVERT, Jaakko SARELA and Aapo HYVARINEN
41 *   Laboratory of Information and Computer Science
42 *   Helsinki University of Technology              *
43 *
44 * URL : http://www.cis.hut.fi/projects/ica/fastica/about.shtml
45 *
46 * If you use results given by this FastICA software in an article for
47 * a scientific journal, conference proceedings or similar, please
48 * include the following original reference in the bibliography :
49 *
50 *   A. Hyvarinen, Fast and Robust Fixed-Point Algorithms for
51 *   Independent Component Analysis, IEEE Transactions on Neural
52 *   Networks 10(3):626-634, 1999
53 *
54 * Differences with the original Matlab implementation:
55 * - no GUI
56 * - return something even in the case of a convergence problem
57 * - optimization of SVD decomposition (performed 2 times in Matlab,
58 *   only 1 time in IT++)
59 * - default approach is SYMM with non-linearity POW3
60 */
61
62#include <itpp/signal/fastica.h>
63#include <itpp/signal/sigfun.h>
64#include <itpp/signal/resampling.h>
65#include <itpp/base/algebra/eigen.h>
66#include <itpp/base/algebra/svd.h>
67#include <itpp/base/math/trig_hyp.h>
68#include <itpp/base/matfunc.h>
69#include <itpp/base/random.h>
70#include <itpp/base/sort.h>
71#include <itpp/base/specmat.h>
72#include <itpp/base/svec.h>
73#include <itpp/base/math/min_max.h>
74#include <itpp/stat/misc_stat.h>
75
76
77using namespace itpp;
78
79
80/*!
81  \brief Local functions for FastICA
82  @{
83*/
84static void selcol(const mat oldMatrix, const vec maskVector, mat & newMatrix);
85static void pcamat(const mat vectors, const int numOfIC, int firstEig, int lastEig, mat & Es, vec & Ds);
86static void remmean(mat inVectors, mat & outVectors, vec & meanValue);
87static void whitenv(const mat vectors, const mat E, const mat D, mat & newVectors, mat & whiteningMatrix, mat & dewhiteningMatrix);
88static mat orth(const mat A);
89static mat mpower(const mat A, const double y);
90static ivec getSamples(const int max, const double percentage);
91static vec sumcol(const mat A);
92static void fpica(const mat X, const mat whiteningMatrix, const mat dewhiteningMatrix, const int approach, const int numOfIC, const int g, const int finetune, const double a1, const double a2, double myy, const int stabilization, const double epsilon, const int maxNumIterations, const int maxFinetune, const int initState, mat guess, double sampleSize, mat & A, mat & W);
93/*! @} */
94
95namespace itpp
96{
97
98// Constructor, init default values
99Fast_ICA::Fast_ICA(mat ma_mixedSig)
100{
101
102  // Init default values
103  approach = FICA_APPROACH_SYMM;
104  g = FICA_NONLIN_POW3;
105  finetune = true;
106  a1 = 1.0;
107  a2 = 1.0;
108  mu = 1.0;
109  epsilon = 0.0001;
110  sampleSize = 1.0;
111  stabilization = false;
112  maxNumIterations = 100000;
113  maxFineTune = 100;
114  firstEig = 1;
115
116  mixedSig = ma_mixedSig;
117
118  lastEig = mixedSig.rows();
119  numOfIC = mixedSig.rows();
120  PCAonly = false;
121  initState = FICA_INIT_RAND;
122
123}
124
125// Call main function
126void Fast_ICA::separate(void)
127{
128
129  int Dim = numOfIC;
130
131  mat mixedSigC;
132  vec mixedMean;
133
134  mat guess;
135  if (initState == FICA_INIT_RAND)
136    guess = zeros(Dim, Dim);
137  else
138    guess = mat(initGuess);
139
140  VecPr = zeros(mixedSig.rows(), numOfIC);
141
142  icasig = zeros(numOfIC, mixedSig.cols());
143
144  remmean(mixedSig, mixedSigC, mixedMean);
145
146  pcamat(mixedSigC, numOfIC, firstEig, lastEig, E, D);
147
148  whitenv(mixedSigC, E, diag(D), whitesig, whiteningMatrix, dewhiteningMatrix);
149
150
151  ivec NcFirst = to_ivec(zeros(numOfIC));
152  vec NcVp = D;
153  for (int i = 0; i < NcFirst.size(); i++) {
154
155    NcFirst(i) = max_index(NcVp);
156    NcVp(NcFirst(i)) = 0.0;
157    VecPr.set_col(i, dewhiteningMatrix.get_col(i));
158
159  }
160
161  if (PCAonly == false) {
162
163    Dim = whitesig.rows();
164
165    if (numOfIC > Dim) numOfIC = Dim;
166
167    fpica(whitesig, whiteningMatrix, dewhiteningMatrix, approach, numOfIC, g, finetune, a1, a2, mu, stabilization, epsilon, maxNumIterations, maxFineTune, initState, guess, sampleSize, A, W);
168
169    icasig = W * mixedSig;
170
171  }
172
173  else { // PCA only : returns E as IcaSig
174    icasig = VecPr;
175  }
176}
177
178void Fast_ICA::set_approach(int in_approach) { approach = in_approach; if (approach == FICA_APPROACH_DEFL) finetune = true; }
179
180void Fast_ICA::set_nrof_independent_components(int in_nrIC) { numOfIC = in_nrIC; }
181
182void Fast_ICA::set_non_linearity(int in_g) { g = in_g; }
183
184void Fast_ICA::set_fine_tune(bool in_finetune) { finetune = in_finetune; }
185
186void Fast_ICA::set_a1(double fl_a1) { a1 = fl_a1; }
187
188void Fast_ICA::set_a2(double fl_a2) { a2 = fl_a2; }
189
190void Fast_ICA::set_mu(double fl_mu) { mu = fl_mu; }
191
192void Fast_ICA::set_epsilon(double fl_epsilon) { epsilon = fl_epsilon; }
193
194void Fast_ICA::set_sample_size(double fl_sampleSize) { sampleSize = fl_sampleSize; }
195
196void Fast_ICA::set_stabilization(bool in_stabilization) { stabilization = in_stabilization; }
197
198void Fast_ICA::set_max_num_iterations(int in_maxNumIterations) { maxNumIterations = in_maxNumIterations; }
199
200void Fast_ICA::set_max_fine_tune(int in_maxFineTune) { maxFineTune = in_maxFineTune; }
201
202void Fast_ICA::set_first_eig(int in_firstEig) { firstEig = in_firstEig; }
203
204void Fast_ICA::set_last_eig(int in_lastEig) { lastEig = in_lastEig; }
205
206void Fast_ICA::set_pca_only(bool in_PCAonly) { PCAonly = in_PCAonly; }
207
208void Fast_ICA::set_init_guess(mat ma_initGuess)
209{
210  initGuess = ma_initGuess;
211  initState = FICA_INIT_GUESS;
212}
213
214mat Fast_ICA::get_mixing_matrix() { if (PCAonly) { it_warning("No ICA performed."); return (zeros(1, 1));} else return A; }
215
216mat Fast_ICA::get_separating_matrix() { if (PCAonly) { it_warning("No ICA performed."); return(zeros(1, 1)); } else return W; }
217
218mat Fast_ICA::get_independent_components() { if (PCAonly) { it_warning("No ICA performed."); return(zeros(1, 1)); } else return icasig; }
219
220int Fast_ICA::get_nrof_independent_components() { return numOfIC; }
221
222mat Fast_ICA::get_principal_eigenvectors() { return VecPr; }
223
224mat Fast_ICA::get_whitening_matrix() { return whiteningMatrix; }
225
226mat Fast_ICA::get_dewhitening_matrix() { return dewhiteningMatrix; }
227
228mat Fast_ICA::get_white_sig() { return whitesig; }
229
230} // namespace itpp
231
232
233static void selcol(const mat oldMatrix, const vec maskVector, mat & newMatrix)
234{
235
236  int numTaken = 0;
237
238  for (int i = 0; i < size(maskVector); i++) if (maskVector(i) == 1) numTaken++;
239
240  newMatrix = zeros(oldMatrix.rows(), numTaken);
241
242  numTaken = 0;
243
244  for (int i = 0; i < size(maskVector); i++) {
245
246    if (maskVector(i) == 1) {
247
248      newMatrix.set_col(numTaken, oldMatrix.get_col(i));
249      numTaken++;
250
251    }
252  }
253
254  return;
255
256}
257
258static void pcamat(const mat vectors, const int numOfIC, int firstEig, int lastEig, mat & Es, vec & Ds)
259{
260
261  mat Et;
262  vec Dt;
263  cmat Ec;
264  cvec Dc;
265  double lowerLimitValue = 0.0,
266                           higherLimitValue = 0.0;
267
268  int oldDimension = vectors.rows();
269
270  mat covarianceMatrix = cov(transpose(vectors), 0);
271
272  eig_sym(covarianceMatrix, Dt, Et);
273
274  int maxLastEig = 0;
275
276  // Compute rank
277  for (int i = 0; i < Dt.length(); i++) if (Dt(i) > FICA_TOL) maxLastEig++;
278
279  // Force numOfIC components
280  if (maxLastEig > numOfIC) maxLastEig = numOfIC;
281
282  vec eigenvalues = zeros(size(Dt));
283  vec eigenvalues2 = zeros(size(Dt));
284
285  eigenvalues2 = Dt;
286
287  sort(eigenvalues2);
288
289  vec lowerColumns = zeros(size(Dt));
290
291  for (int i = 0; i < size(Dt); i++) eigenvalues(i) = eigenvalues2(size(Dt) - i - 1);
292
293  if (lastEig > maxLastEig) lastEig = maxLastEig;
294
295  if (lastEig < oldDimension) lowerLimitValue = (eigenvalues(lastEig - 1) + eigenvalues(lastEig)) / 2;
296  else lowerLimitValue = eigenvalues(oldDimension - 1) - 1;
297
298  for (int i = 0; i < size(Dt); i++) if (Dt(i) > lowerLimitValue) lowerColumns(i) = 1;
299
300  if (firstEig > 1) higherLimitValue = (eigenvalues(firstEig - 2) + eigenvalues(firstEig - 1)) / 2;
301  else higherLimitValue = eigenvalues(0) + 1;
302
303  vec higherColumns = zeros(size(Dt));
304  for (int i = 0; i < size(Dt); i++) if (Dt(i) < higherLimitValue) higherColumns(i) = 1;
305
306  vec selectedColumns = zeros(size(Dt));
307  for (int i = 0; i < size(Dt); i++) selectedColumns(i) = (lowerColumns(i) == 1 && higherColumns(i) == 1) ? 1 : 0;
308
309  selcol(Et, selectedColumns, Es);
310
311  int numTaken = 0;
312
313  for (int i = 0; i < size(selectedColumns); i++) if (selectedColumns(i) == 1) numTaken++;
314
315  Ds = zeros(numTaken);
316
317  numTaken = 0;
318
319  for (int i = 0; i < size(Dt); i++)
320    if (selectedColumns(i) == 1) {
321      Ds(numTaken) = Dt(i);
322      numTaken++;
323    }
324
325  return;
326
327}
328
329
330static void remmean(mat inVectors, mat & outVectors, vec & meanValue)
331{
332
333  outVectors = zeros(inVectors.rows(), inVectors.cols());
334  meanValue = zeros(inVectors.rows());
335
336  for (int i = 0; i < inVectors.rows(); i++) {
337
338    meanValue(i) = mean(inVectors.get_row(i));
339
340    for (int j = 0; j < inVectors.cols(); j++) outVectors(i, j) = inVectors(i, j) - meanValue(i);
341
342  }
343
344}
345
346static void whitenv(const mat vectors, const mat E, const mat D, mat & newVectors, mat & whiteningMatrix, mat & dewhiteningMatrix)
347{
348
349  whiteningMatrix = zeros(E.cols(), E.rows());
350  dewhiteningMatrix = zeros(E.rows(), E.cols());
351
352  for (int i = 0; i < D.cols(); i++) {
353    whiteningMatrix.set_row(i, std::pow(std::sqrt(D(i, i)), -1)*E.get_col(i));
354    dewhiteningMatrix.set_col(i, std::sqrt(D(i, i))*E.get_col(i));
355  }
356
357  newVectors = whiteningMatrix * vectors;
358
359  return;
360
361}
362
363static mat orth(const mat A)
364{
365
366  mat Q;
367  mat U, V;
368  vec S;
369  double eps = 2.2e-16;
370  double tol = 0.0;
371  int mmax = 0;
372  int r = 0;
373
374  svd(A, U, S, V);
375  if (A.rows() > A.cols()) {
376
377    U = U(0, U.rows() - 1, 0, A.cols() - 1);
378    S = S(0, A.cols() - 1);
379  }
380
381  mmax = (A.rows() > A.cols()) ? A.rows() : A.cols();
382
383  tol = mmax * eps * max(S);
384
385  for (int i = 0; i < size(S); i++) if (S(i) > tol) r++;
386
387  Q = U(0, U.rows() - 1, 0, r - 1);
388
389  return (Q);
390}
391
392static mat mpower(const mat A, const double y)
393{
394
395  mat T = zeros(A.rows(), A.cols());
396  mat dd = zeros(A.rows(), A.cols());
397  vec d = zeros(A.rows());
398  vec dOut = zeros(A.rows());
399
400  eig_sym(A, d, T);
401
402  dOut = pow(d, y);
403
404  diag(dOut, dd);
405
406  for (int i = 0; i < T.cols(); i++) T.set_col(i, T.get_col(i) / norm(T.get_col(i)));
407
408  return (T*dd*transpose(T));
409
410}
411
412static ivec getSamples(const int max, const double percentage)
413{
414
415  vec rd = randu(max);
416  sparse_vec sV;
417  ivec out;
418  int sZ = 0;
419
420  for (int i = 0; i < max; i++) if (rd(i) < percentage) { sV.add_elem(sZ, i); sZ++; }
421
422  out = to_ivec(full(sV));
423
424  return (out);
425
426}
427
428static vec sumcol(const mat A)
429{
430
431  vec out = zeros(A.cols());
432
433  for (int i = 0; i < A.cols(); i++) { out(i) = sum(A.get_col(i)); }
434
435  return (out);
436
437}
438
439static void fpica(const mat X, const mat whiteningMatrix, const mat dewhiteningMatrix, const int approach, const int numOfIC, const int g, const int finetune, const double a1, const double a2, double myy, const int stabilization, const double epsilon, const int maxNumIterations, const int maxFinetune, const int initState, mat guess, double sampleSize, mat & A, mat & W)
440{
441
442  int vectorSize = X.rows();
443  int numSamples = X.cols();
444  int gOrig = g;
445  int gFine = finetune + 1;
446  double myyOrig = myy;
447  double myyK = 0.01;
448  int failureLimit = 5;
449  int usedNlinearity = 0;
450  double stroke = 0.0;
451  int notFine = 1;
452  int loong = 0;
453  int initialStateMode = initState;
454  double minAbsCos = 0.0, minAbsCos2 = 0.0;
455
456  if (sampleSize * numSamples < 1000) sampleSize = (1000 / (double)numSamples < 1.0) ? 1000 / (double)numSamples : 1.0;
457
458  if (sampleSize != 1.0) gOrig += 2;
459  if (myy != 1.0) gOrig += 1;
460
461  int fineTuningEnabled = 1;
462
463  if (!finetune) {
464    if (myy != 1.0) gFine = gOrig;
465    else gFine = gOrig + 1;
466    fineTuningEnabled = 0;
467  }
468
469  int stabilizationEnabled = stabilization;
470
471  if (!stabilization && myy != 1.0) stabilizationEnabled = true;
472
473  usedNlinearity = gOrig;
474
475  if (initState == FICA_INIT_GUESS && guess.rows() != whiteningMatrix.cols()) {
476    initialStateMode = 0;
477
478  }
479  else if (guess.cols() < numOfIC) {
480
481    mat guess2 = randu(guess.rows(), numOfIC - guess.cols()) - 0.5;
482    guess = concat_horizontal(guess, guess2);
483  }
484  else if (guess.cols() > numOfIC) guess = guess(0, guess.rows() - 1, 0, numOfIC - 1);
485
486  if (approach == FICA_APPROACH_SYMM) {
487
488    usedNlinearity = gOrig;
489    stroke = 0;
490    notFine = 1;
491    loong = 0;
492
493    A = zeros(vectorSize, numOfIC);
494    mat B = zeros(vectorSize, numOfIC);
495
496    if (initialStateMode == 0) B = orth(randu(vectorSize, numOfIC) - 0.5);
497    else B = whiteningMatrix * guess;
498
499    mat BOld = zeros(B.rows(), B.cols());
500    mat BOld2 = zeros(B.rows(), B.cols());
501
502    for (int round = 0; round < maxNumIterations; round++) {
503
504      if (round == maxNumIterations - 1) {
505
506        // If there is a convergence problem,
507        // we still want ot return something.
508        // This is difference with original
509        // Matlab implementation.
510        A = dewhiteningMatrix * B;
511        W = transpose(B) * whiteningMatrix;
512
513        return;
514      }
515
516      B = B * mpower(transpose(B) * B , -0.5);
517
518      minAbsCos = min(abs(diag(transpose(B) * BOld)));
519      minAbsCos2 = min(abs(diag(transpose(B) * BOld2)));
520
521      if (1 - minAbsCos < epsilon) {
522
523        if (fineTuningEnabled && notFine) {
524
525          notFine = 0;
526          usedNlinearity = gFine;
527          myy = myyK * myyOrig;
528          BOld = zeros(B.rows(), B.cols());
529          BOld2 = zeros(B.rows(), B.cols());
530
531        }
532
533        else {
534
535          A = dewhiteningMatrix * B;
536          break;
537
538        }
539
540      } // IF epsilon
541
542      else if (stabilizationEnabled) {
543        if (!stroke && (1 - minAbsCos2 < epsilon)) {
544
545          stroke = myy;
546          myy /= 2;
547          if (mod(usedNlinearity, 2) == 0) usedNlinearity += 1 ;
548
549        }
550        else if (stroke) {
551
552          myy = stroke;
553          stroke = 0;
554          if (myy == 1 && mod(usedNlinearity, 2) != 0) usedNlinearity -= 1;
555
556        }
557        else if (!loong && (round > maxNumIterations / 2)) {
558
559          loong = 1;
560          myy /= 2;
561          if (mod(usedNlinearity, 2) == 0) usedNlinearity += 1;
562
563        }
564
565      } // stabilizationEnabled
566
567      BOld2 = BOld;
568      BOld = B;
569
570      switch (usedNlinearity) {
571
572        // pow3
573      case FICA_NONLIN_POW3 : {
574        B = (X * pow(transpose(X) * B, 3)) / numSamples - 3 * B;
575        break;
576      }
577      case(FICA_NONLIN_POW3+1) : {
578        mat Y = transpose(X) * B;
579        mat Gpow3 = pow(Y, 3);
580        vec Beta = sumcol(pow(Y, 4));
581        mat D = diag(pow(Beta - 3 * numSamples , -1));
582        B = B + myy * B * (transpose(Y) * Gpow3 - diag(Beta)) * D;
583        break;
584      }
585      case(FICA_NONLIN_POW3+2) : {
586        mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
587        B = (Xsub * pow(transpose(Xsub) * B, 3)) / Xsub.cols() - 3 * B;
588        break;
589      }
590      case(FICA_NONLIN_POW3+3) : {
591        mat Ysub = transpose(X.get_cols(getSamples(numSamples, sampleSize))) * B;
592        mat Gpow3 = pow(Ysub, 3);
593        vec Beta = sumcol(pow(Ysub, 4));
594        mat D = diag(pow(Beta - 3 * Ysub.rows() , -1));
595        B = B + myy * B * (transpose(Ysub) * Gpow3 - diag(Beta)) * D;
596        break;
597      }
598
599      // TANH
600      case FICA_NONLIN_TANH : {
601        mat hypTan = tanh(a1 * transpose(X) * B);
602        B = (X * hypTan) / numSamples - elem_mult(reshape(repeat(sumcol(1 - pow(hypTan, 2)), B.rows()), B.rows(), B.cols()), B) / numSamples * a1;
603        break;
604      }
605      case(FICA_NONLIN_TANH+1) : {
606        mat Y = transpose(X) * B;
607        mat hypTan = tanh(a1 * Y);
608        vec Beta = sumcol(elem_mult(Y, hypTan));
609        vec Beta2 = sumcol(1 - pow(hypTan, 2));
610        mat D = diag(pow(Beta - a1 * Beta2 , -1));
611        B = B + myy * B * (transpose(Y) * hypTan - diag(Beta)) * D;
612        break;
613      }
614      case(FICA_NONLIN_TANH+2) : {
615        mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
616        mat hypTan = tanh(a1 * transpose(Xsub) * B);
617        B = (Xsub * hypTan) / Xsub.cols() -  elem_mult(reshape(repeat(sumcol(1 - pow(hypTan, 2)), B.rows()), B.rows(), B.cols()), B) / Xsub.cols() * a1;
618        break;
619      }
620      case(FICA_NONLIN_TANH+3) : {
621        mat Ysub = transpose(X.get_cols(getSamples(numSamples, sampleSize))) * B;
622        mat hypTan = tanh(a1 * Ysub);
623        vec Beta = sumcol(elem_mult(Ysub, hypTan));
624        vec Beta2 = sumcol(1 - pow(hypTan, 2));
625        mat D = diag(pow(Beta - a1 * Beta2 , -1));
626        B = B + myy * B * (transpose(Ysub) * hypTan - diag(Beta)) * D;
627        break;
628      }
629
630      // GAUSS
631      case FICA_NONLIN_GAUSS : {
632        mat U = transpose(X) * B;
633        mat Usquared = pow(U, 2);
634        mat ex = exp(-a2 * Usquared / 2);
635        mat gauss = elem_mult(U, ex);
636        mat dGauss = elem_mult(1 - a2 * Usquared, ex);
637        B = (X * gauss) / numSamples - elem_mult(reshape(repeat(sumcol(dGauss), B.rows()), B.rows(), B.cols()), B) / numSamples;
638        break;
639      }
640      case(FICA_NONLIN_GAUSS+1) : {
641        mat Y = transpose(X) * B;
642        mat ex = exp(-a2 * pow(Y, 2) / 2);
643        mat gauss = elem_mult(Y, ex);
644        vec Beta = sumcol(elem_mult(Y, gauss));
645        vec Beta2 = sumcol(elem_mult(1 - a2 * pow(Y, 2), ex));
646        mat D = diag(pow(Beta - Beta2 , -1));
647        B = B + myy * B * (transpose(Y) * gauss - diag(Beta)) * D;
648        break;
649      }
650      case(FICA_NONLIN_GAUSS+2) : {
651        mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
652        mat U = transpose(Xsub) * B;
653        mat Usquared = pow(U, 2);
654        mat ex = exp(-a2 * Usquared / 2);
655        mat gauss = elem_mult(U, ex);
656        mat dGauss = elem_mult(1 - a2 * Usquared, ex);
657        B = (Xsub * gauss) / Xsub.cols() - elem_mult(reshape(repeat(sumcol(dGauss), B.rows()), B.rows(), B.cols()), B) / Xsub.cols();
658        break;
659      }
660      case(FICA_NONLIN_GAUSS+3) : {
661        mat Ysub = transpose(X.get_cols(getSamples(numSamples, sampleSize))) * B;
662        mat ex = exp(-a2 * pow(Ysub, 2) / 2);
663        mat gauss = elem_mult(Ysub, ex);
664        vec Beta = sumcol(elem_mult(Ysub, gauss));
665        vec Beta2 = sumcol(elem_mult(1 - a2 * pow(Ysub, 2), ex));
666        mat D = diag(pow(Beta - Beta2 , -1));
667        B = B + myy * B * (transpose(Ysub) * gauss - diag(Beta)) * D;
668        break;
669      }
670
671      // SKEW
672      case FICA_NONLIN_SKEW : {
673        B = (X * (pow(transpose(X) * B, 2))) / numSamples;
674        break;
675      }
676      case(FICA_NONLIN_SKEW+1) : {
677        mat Y = transpose(X) * B;
678        mat Gskew = pow(Y, 2);
679        vec Beta = sumcol(elem_mult(Y, Gskew));
680        mat D = diag(pow(Beta , -1));
681        B = B + myy * B * (transpose(Y) * Gskew - diag(Beta)) * D;
682        break;
683      }
684      case(FICA_NONLIN_SKEW+2) : {
685        mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
686        B = (Xsub * (pow(transpose(Xsub) * B, 2))) / Xsub.cols();
687        break;
688      }
689      case(FICA_NONLIN_SKEW+3) : {
690        mat Ysub = transpose(X.get_cols(getSamples(numSamples, sampleSize))) * B;
691        mat Gskew = pow(Ysub, 2);
692        vec Beta = sumcol(elem_mult(Ysub, Gskew));
693        mat D = diag(pow(Beta , -1));
694        B = B + myy * B * (transpose(Ysub) * Gskew - diag(Beta)) * D;
695        break;
696      }
697
698
699      } // SWITCH usedNlinearity
700
701    } // FOR maxIterations
702
703    W = transpose(B) * whiteningMatrix;
704
705
706  } // IF FICA_APPROACH_SYMM APPROACH
707
708  // DEFLATION
709  else {
710
711    // FC 01/12/05
712    A = zeros(whiteningMatrix.cols(), numOfIC);
713    //    A = zeros( vectorSize, numOfIC );
714    mat B = zeros(vectorSize, numOfIC);
715    W = transpose(B) * whiteningMatrix;
716    int round = 1;
717    int numFailures = 0;
718
719    while (round <= numOfIC) {
720
721      myy = myyOrig;
722
723      usedNlinearity = gOrig;
724      stroke = 0;
725
726      notFine = 1;
727      loong = 0;
728      int endFinetuning = 0;
729
730      vec w = zeros(vectorSize);
731
732      if (initialStateMode == 0)
733
734        w = randu(vectorSize) - 0.5;
735
736      else w = whiteningMatrix * guess.get_col(round);
737
738      w = w - B * transpose(B) * w;
739
740      w /= norm(w);
741
742      vec wOld = zeros(vectorSize);
743      vec wOld2 = zeros(vectorSize);
744
745      int i = 1;
746      int gabba = 1;
747
748      while (i <= maxNumIterations + gabba) {
749
750        w = w - B * transpose(B) * w;
751
752        w /= norm(w);
753
754        if (notFine) {
755
756          if (i == maxNumIterations + 1) {
757
758            round--;
759
760            numFailures++;
761
762            if (numFailures > failureLimit) {
763
764              if (round == 0) {
765
766                A = dewhiteningMatrix * B;
767                W = transpose(B) * whiteningMatrix;
768
769              } // IF round
770
771              break;
772
773            } // IF numFailures > failureLimit
774
775            break;
776
777          } // IF i == maxNumIterations+1
778
779        } // IF NOTFINE
780
781        else if (i >= endFinetuning) wOld = w;
782
783        if (norm(w - wOld) < epsilon || norm(w + wOld) < epsilon) {
784
785          if (fineTuningEnabled && notFine) {
786
787            notFine = 0;
788            gabba = maxFinetune;
789            wOld = zeros(vectorSize);
790            wOld2 = zeros(vectorSize);
791            usedNlinearity = gFine;
792            myy = myyK * myyOrig;
793            endFinetuning = maxFinetune + i;
794
795          } // IF finetuning
796
797          else {
798
799            numFailures = 0;
800
801            B.set_col(round - 1, w);
802
803            A.set_col(round - 1, dewhiteningMatrix*w);
804
805            W.set_row(round - 1, transpose(whiteningMatrix)*w);
806
807            break;
808
809          } // ELSE finetuning
810
811        } // IF epsilon
812
813        else if (stabilizationEnabled) {
814
815          if (stroke == 0.0 && (norm(w - wOld2) < epsilon || norm(w + wOld2) < epsilon)) {
816
817            stroke = myy;
818            myy /= 2.0 ;
819
820            if (mod(usedNlinearity, 2) == 0) {
821
822              usedNlinearity++;
823
824            } // IF MOD
825
826          }// IF !stroke
827
828          else if (stroke != 0.0) {
829
830            myy = stroke;
831            stroke = 0.0;
832
833            if (myy == 1 && mod(usedNlinearity, 2) != 0) {
834              usedNlinearity--;
835            }
836
837          } // IF Stroke
838
839          else if (notFine && !loong && i > maxNumIterations / 2) {
840
841            loong = 1;
842            myy /= 2.0;
843
844            if (mod(usedNlinearity, 2) == 0) {
845
846              usedNlinearity++;
847
848            } // IF MOD
849
850          } // IF notFine
851
852        } // IF stabilization
853
854
855        wOld2 = wOld;
856        wOld = w;
857
858        switch (usedNlinearity) {
859
860          // pow3
861        case FICA_NONLIN_POW3 : {
862          w = (X * pow(transpose(X) * w, 3)) / numSamples - 3 * w;
863          break;
864        }
865        case(FICA_NONLIN_POW3+1) : {
866          vec Y = transpose(X) * w;
867          vec Gpow3 = X * pow(Y, 3) / numSamples;
868          double Beta = dot(w, Gpow3);
869          w = w - myy * (Gpow3 - Beta * w) / (3 - Beta);
870          break;
871        }
872        case(FICA_NONLIN_POW3+2) : {
873          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
874          w = (Xsub * pow(transpose(Xsub) * w, 3)) / Xsub.cols() - 3 * w;
875          break;
876        }
877        case(FICA_NONLIN_POW3+3) : {
878          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
879          vec Gpow3 = Xsub * pow(transpose(Xsub) * w, 3) / (Xsub.cols());
880          double Beta = dot(w, Gpow3);
881          w = w - myy * (Gpow3 - Beta * w) / (3 - Beta);
882          break;
883        }
884
885        // TANH
886        case FICA_NONLIN_TANH : {
887          vec hypTan = tanh(a1 * transpose(X) * w);
888          w = (X * hypTan - a1 * sum(1 - pow(hypTan, 2)) * w) / numSamples;
889          break;
890        }
891        case(FICA_NONLIN_TANH+1) : {
892          vec Y = transpose(X) * w;
893          vec hypTan = tanh(a1 * Y);
894          double Beta = dot(w, X * hypTan);
895          w = w - myy * ((X * hypTan - Beta * w) / (a1 * sum(1 - pow(hypTan, 2)) - Beta));
896          break;
897        }
898        case(FICA_NONLIN_TANH+2) : {
899          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
900          vec hypTan = tanh(a1 * transpose(Xsub) * w);
901          w = (Xsub * hypTan - a1 * sum(1 - pow(hypTan, 2)) * w) / Xsub.cols();
902          break;
903        }
904        case(FICA_NONLIN_TANH+3) : {
905          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
906          vec hypTan = tanh(a1 * transpose(Xsub) * w);
907          double Beta = dot(w, Xsub * hypTan);
908          w = w - myy * ((Xsub * hypTan - Beta * w) / (a1 * sum(1 - pow(hypTan, 2)) - Beta));
909          break;
910        }
911
912        // GAUSS
913        case FICA_NONLIN_GAUSS : {
914          vec u = transpose(X) * w;
915          vec Usquared = pow(u, 2);
916          vec ex = exp(-a2 * Usquared / 2);
917          vec gauss = elem_mult(u, ex);
918          vec dGauss = elem_mult(1 - a2 * Usquared, ex);
919          w = (X * gauss - sum(dGauss) * w) / numSamples;
920          break;
921        }
922        case(FICA_NONLIN_GAUSS+1) : {
923          vec u = transpose(X) * w;
924          vec Usquared = pow(u, 2);
925
926          vec ex = exp(-a2 * Usquared / 2);
927          vec gauss = elem_mult(u, ex);
928          vec dGauss = elem_mult(1 - a2 * Usquared, ex);
929          double Beta = dot(w, X * gauss);
930          w = w - myy * ((X * gauss - Beta * w) / (sum(dGauss) - Beta));
931          break;
932        }
933        case(FICA_NONLIN_GAUSS+2) : {
934          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
935          vec u = transpose(Xsub) * w;
936          vec Usquared = pow(u, 2);
937          vec ex = exp(-a2 * Usquared / 2);
938          vec gauss = elem_mult(u, ex);
939          vec dGauss = elem_mult(1 - a2 * Usquared, ex);
940          w = (Xsub * gauss - sum(dGauss) * w) / Xsub.cols();
941          break;
942        }
943        case(FICA_NONLIN_GAUSS+3) : {
944          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
945          vec u = transpose(Xsub) * w;
946          vec Usquared = pow(u, 2);
947          vec ex = exp(-a2 * Usquared / 2);
948          vec gauss = elem_mult(u, ex);
949          vec dGauss = elem_mult(1 - a2 * Usquared, ex);
950          double Beta = dot(w, Xsub * gauss);
951          w = w - myy * ((Xsub * gauss - Beta * w) / (sum(dGauss) - Beta));
952          break;
953        }
954
955        // SKEW
956        case FICA_NONLIN_SKEW : {
957          w = (X * (pow(transpose(X) * w, 2))) / numSamples;
958          break;
959        }
960        case(FICA_NONLIN_SKEW+1) : {
961          vec Y = transpose(X) * w;
962          vec Gskew = X * pow(Y, 2) / numSamples;
963          double Beta = dot(w, Gskew);
964          w = w - myy * (Gskew - Beta * w / (-Beta));
965          break;
966        }
967        case(FICA_NONLIN_SKEW+2) : {
968          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
969          w = (Xsub * (pow(transpose(Xsub) * w, 2))) / Xsub.cols();
970          break;
971        }
972        case(FICA_NONLIN_SKEW+3) : {
973          mat Xsub = X.get_cols(getSamples(numSamples, sampleSize));
974          vec Gskew = Xsub * pow(transpose(Xsub) * w, 2) / Xsub.cols();
975          double Beta = dot(w, Gskew);
976          w = w - myy * (Gskew - Beta * w) / (-Beta);
977          break;
978        }
979
980        } // SWITCH nonLinearity
981
982        w /= norm(w);
983        i++;
984
985      } // WHILE i<= maxNumIterations+gabba
986
987      round++;
988
989    } // While round <= numOfIC
990
991  } // ELSE Deflation
992
993} // FPICA
Note: See TracBrowser for help on using the browser.