root/library/utia_legacy/ticket_12/straux1.m @ 678

Revision 571, 14.6 kB (checked in by smidl, 15 years ago)

matlab test for #12

Line 
function [strout, rgrsout, statistics] = straux1(L, d, nu, L0, d0, nu0, belief, nbest, max_nrep, lambda, order_k);
% STRAUX1  Structure estimation based on LD decomposition.
%
% This m/mex file is internally called by facstr, IT IS NOT TO BE CALLED
% BY USER!! Documentation given for reference.
%
% [strout, rgrsout, statistics] = straux1(L, d, nu, L0, d0, nu0, belief, nbest, max_nrep, lambda, order_k);
%
% L, d    : actual LD decomposition based on data, V = L'*diag(d)*L
%           (element 1 is the regressand, elements 2:length(d) are the
%           regressor candidates)
% nu      : actual data amount
% L0, d0  : prior information (decomposition of the same form)
% nu0     : prior data amount
% belief  : user's belief on the structure items 2:length(d)
%           (1 items must     be present, 2 items are probably     present,
%            4 items must not be present, 3 items are probably not present);
%           only 1 and 4 constrain the search, 2 and 3 are treated the same.
%           Default: [] (no constraints).
% nbest   : how many "best" structures are maintained in rgrsout
%           (default 2 if rgrsout is requested, otherwise 1)
% max_nrep: maximal number of random starts in search for the best
%           structure (default 3)
% lambda  : stopping rule threshold (default 0.75)
% order_k : order of k used in the stopping rule (default 2)
%
% strout  : estimated structure of the regressor (subset of 2:length(d))
% rgrsout : struct array of the kept structures, sorted by decreasing
%           loglikelihood
% statistics : search statistics (number of starts, unique local maxima,
%           cpu time, per-start history of the stopping rule)
%
% Design  : L. Tesar
% Updated : Feb-Apr 2003
% Project : post-ProDaCTool
% References: (only local inline functions)
%
% Todo: in add_new, we need to implement structure comparison, instead of
% loglikelihood comparison: ~any(logliks == new.loglik)

% Argument checking:
if nargin< 6, error('Incorrect number of input parameters in straux1'); end;
if nargin< 7, belief   = []; end;   % Don't believe anybody.
if nargin< 8;
   if nargout>=2;
      nbest = 2;
   else
      % If the second output is not needed it is better to avoid
      % calculating it at all, because it is very costly (5x slowdown).
      nbest = 1;
   end;
end;
if nargin< 9, max_nrep = 3; end;
if nargin<10, lambda   = 0.75; end;
if nargin<11, order_k  = 2; end;

% Use a row so that find() below returns rows and every later `for` loop
% over belief_in / belief_out visits the items one by one.
belief = belief(:).';

n_data = length(d);

belief_out = find(belief==4)+1; % items never put into the regressor
belief_in  = find(belief==1)+1; % items always kept in the regressor

full.d0  = d0;
full.nu0 = nu0;
full.L0  = L0;
full.L   = L;
full.d   = d;
full.nu  = nu;
full.strL   = 1:n_data;               % current order of the elements in L and d
full.strRgr = 2:n_data;               % elements inside the regressor (after the regressand)
full.strMis = [];                     % elements outside the regressor (before the regressand)
full.posit1 = 1;                      % regressand position
% Number of bits stored in each double word of bitstr.
% NOTE(review): bitmax is deprecated in recent MATLAB releases (flintmax
% is the replacement) -- confirm before switching.
full.nbits  = floor(log2(bitmax))-1;
full.bitstr = str_bitset(zeros(1,floor(n_data/full.nbits)+1),full.strRgr,full.nbits);
full.loglik = seloglik1(full);        % loglikelihood (non-constant part)

% Construct the full and the empty structure respecting the beliefs.
full  = sestrremove(full,belief_out);
empty = sestrremove(full,setdiff(full.strRgr,belief_in));

% Stopping rule bookkeeping: bit-coded unique local maxima and their count.
local_max = [];
muto = 0;

% statistics:
cputime0 = cputime;
if nargout>=3;
   mutos    = zeros(1,max_nrep+2);
   maxmutos = zeros(1,max_nrep+2);
end;

all_str = 1:n_data;

global_best = full;

% MAIN LOOP: start -1 from the full structure, start 0 from the empty
% structure, starts 1..max_nrep from random structures.
for n_start = -1:max_nrep;
   to = n_start+2;                    % number of starts performed so far

   if n_start == -1;
      last = full;
   elseif n_start == 0;
      last = empty;
   else
      % Random subset of 2:n_data; belief_in items are always kept.
      last_str = find([ 0 floor(2*randun(1,n_data-1))]);
      last = sestrremove(full,setdiff(all_str,[1 last_str empty.strRgr]));
   end;

   % DEBUGging print:
   %fprintf('STRUCTURE generated            in loop %2i was %s\n', n_start, strPrintstr(last));

   % Hill climbing: move to the best neighbour (one element removed or
   % added) until the likelihood stops growing.
   while 1;
      best = last;
      % Nesting by removing elements (impoverishment)
      for removed_item = setdiff(last.strRgr,belief_in);
         new = sestrremove(last,removed_item);
         if nbest>1;
            global_best = add_new(global_best,new,nbest);
         end;
         if new.loglik>best.loglik;
            best = new;
         end;
      end;
      % Nesting by adding elements (enrichment)
      for added_item = setdiff(last.strMis,belief_out);
         new = sestrinsert(last,added_item);
         if nbest>1;
            global_best = add_new(global_best,new,nbest);
         end;
         if new.loglik>best.loglik;
            best = new;
         end;
      end;

      % Stop when no neighbour improves the likelihood.
      if best.loglik <= last.loglik;
         break;
      end;
      last = best;
   end;

   % DEBUGging print:
   %fprintf('STRUCTURE found (local maxima) in loop %2i was %s randun_seed=%11lu randun_counter=%4lu\n', n_start, strPrintstr(best), randn('seed'), RANDUN_COUNTER);

   % Collect the local maximum found from this start.
   if nbest<=1;
      if best.loglik>global_best.loglik;
         global_best = best;
      end;
   elseif ~ismember(best.bitstr, vertcat(global_best.bitstr), 'rows');
      % add_new above is only offered *neighbours* of `last`, so a start
      % structure that is itself a local maximum (e.g. the empty or a
      % random start) would otherwise never enter global_best.
      global_best = add_new(global_best,best,nbest);
   end;

   % uniqueness of the structure found
   if ~ismember(best.bitstr,local_max,'rows');
      local_max = [local_max ; best.bitstr];
      muto = muto + 1;
   end;

   % stopping rule:
   maxmuto = (to-order_k-1)/lambda-to+1;

   % Record statistics before a possible break, so that the terminating
   % start is included in mutos/maxmutos as well.
   if nargout>=3;
      mutos(to)    = muto;
      maxmutos(to) = maxmuto;
   end;

   if to>2 && maxmuto>=muto;
      break;
   end;
end;

% Aftermath: the best structures are in global_best.

% Add the constant part of the loglikelihood to every kept structure.
for f=1:length(global_best);
   global_best(f).loglik = global_best(f).loglik + seloglik2(global_best(f));
end;

% First output: the structure with the highest loglikelihood.
[lik, i] = max([global_best.loglik]);
best   = global_best(i);
strout = best.strRgr;

% Second output: kept structures by decreasing loglikelihood. 'descend'
% keeps ties in their original order, so rgrsout(1) agrees with strout.
[lik, order] = sort([global_best.loglik], 'descend');
rgrsout = global_best(order);

if (nargout>=3);
   statistics.allstrs = 2^(n_data -1-length(belief_in) - length(belief_out));
   statistics.nrand   = to-2;
   statistics.unique  = muto;
   statistics.to      = to;
   statistics.cputime_seconds = cputime - cputime0;
   statistics.itemspeed       = statistics.to / statistics.cputime_seconds;
   statistics.muto     = muto;
   statistics.mutos    = mutos;
   statistics.maxmutos = maxmutos;
end;

% --------------------- END of MAIN program --------------------
220
% This is needed for bitstr manipulations
function out = str_bitset(in,ns,nbits)
% STR_BITSET  Set the bits representing structure elements ns in bitstr.
%
%   in    : row vector of double words holding the bit-coded structure
%   ns    : indices of structure elements (>= 2) to be marked as present
%   nbits : number of bits used in each word
%   out   : copy of in with the corresponding bits set
%
%   Element n is stored in word 1+floor((n-2)/nbits), bit 1+rem(n-2,nbits).
   out = in;
   % Iterate over a row: with a column ns, `for` would run once with the
   % whole column, and repeated word indices would overwrite each other.
   for n = ns(:).';
      index = 1+floor((n-2)/nbits);
      bitindex = 1+rem(n-2,nbits);
      out(index) = bitset(out(index),bitindex);
   end;
function out = str_bitres(in,ns,nbits)
% STR_BITRES  Clear the bits representing structure elements ns in bitstr.
%
%   in    : row vector of double words holding the bit-coded structure
%   ns    : indices of structure elements (>= 2) to be marked as absent
%   nbits : number of bits used in each word
%   out   : copy of in with the corresponding bits cleared
%
%   Element n is stored in word 1+floor((n-2)/nbits), bit 1+rem(n-2,nbits).
   out = in;
   % Iterate over a row: with a column ns, `for` would run once with the
   % whole column, and repeated word indices would overwrite each other.
   for n = ns(:).';
      index = 1+floor((n-2)/nbits);
      bitindex = 1+rem(n-2,nbits);
      mask = bitset(0,bitindex);
      % bitor then bitxor with the same mask clears the bit
      out(index) = bitxor(bitor(out(index),mask),mask);
   end;

function out = strPrintstr(in)
% STRPRINTSTR  Render the bit-coded structure of `in` as a '0'/'1' string.
%
%   Character 1 (the regressand) is always '0'; character k (k >= 2) is
%   '1' when element k is set in in.bitstr and '0' otherwise. The string
%   has length(in.d0) characters, but at least one. In this file it is
%   referenced only from commented-out debug prints.
   len = length(in.d0);
   out = repmat('0', 1, max(len, 1));
   for k = 2:len;
      word = in.bitstr(1 + floor((k-2)/in.nbits));
      if bitget(word, 1 + rem(k-2, in.nbits));
         out(k) = '1';
      end;
   end;

function global_best_out = add_new(global_best,new,nbest)
% ADD_NEW  Offer a candidate structure to the list of the nbest best ones.
%
%   global_best_out = add_new(global_best, new, nbest)
%
%   global_best : struct array of structures kept so far (fields loglik,
%                 bitstr, ...)
%   new         : candidate structure with the same fields
%   nbest       : maximal number of structures to keep
%
%   While fewer than nbest structures are kept, the candidate is appended.
%   Otherwise it replaces the kept structure with the lowest loglik, but
%   only if its own loglik is strictly higher. In both cases a candidate
%   whose bitstr equals that of a kept structure is rejected, so the list
%   never holds the same structure twice.
   global_best_out = global_best;
   if length(global_best)>=nbest;
      [loglik, slot] = min([global_best.loglik]);
      if ~(loglik<new.loglik);
         return;   % not better than the worst kept structure
      end;
   else
      slot = length(global_best)+1;
   end;
   % Compare whole bitstr rows: bitstr may span several words, and words
   % of different structures must not be compared against each other.
   kept = vertcat(global_best.bitstr);
   if ~isempty(kept) && ismember(new.bitstr, kept, 'rows');
      return;
   end;
   global_best_out(slot) = new;
   % DEBUGging print:
   % fprintf('ADDED structure, add_new: %s, loglik=%g\n', strPrintstr(new), new.loglik);

function out = sestrremove(in,removed_elements);
% SESTRREMOVE  Move structure elements out of the regressor.
%
%   out = sestrremove(in, removed_elements)
%
%   in               : structure record (fields L, d, L0, d0, strL, nbits,
%                      bitstr, ...)
%   removed_elements : element indices to move out of the regressor
%
%   Each element currently behind the regressand (element 1) is bubbled,
%   by swaps of neighbouring rows/columns of both the data (L,d) and the
%   prior (L0,d0) decompositions, to the position just in front of the
%   regressand. Elements already in front of the regressand are left in
%   place. posit1, strRgr, strMis, bitstr and loglik are then refreshed.
   n_strL = length(in.strL);
   % Iterate over a row: with a column argument `for` would run only once,
   % with the whole column as the loop variable.
   removed_elements = removed_elements(:).';
   out = in;
   for f = removed_elements;
      posit1 = find(out.strL==1);
      positf = find(out.strL==f);
      for g = (positf-1):-1:posit1;
         % Swap positions g and g+1 (moves f one position to the left).
         [out.L,  out.d]  = seswapudl(out.L,  out.d,  g);
         [out.L0, out.d0] = seswapudl(out.L0, out.d0, g);
         out.strL([g g+1]) = out.strL([g+1 g]);
      end;
   end;
   out.posit1 = find(out.strL==1);
   out.strRgr = out.strL((out.posit1+1):n_strL);
   out.strMis = out.strL(1:(out.posit1-1));
   out.bitstr = str_bitres(out.bitstr,removed_elements,out.nbits);
   out.loglik = seloglik1(out);

function out = sestrinsert(in,inserted_elements);
% SESTRINSERT  Move structure elements into the regressor.
%
%   out = sestrinsert(in, inserted_elements)
%
%   in                : structure record (fields L, d, L0, d0, strL, nbits,
%                       bitstr, ...)
%   inserted_elements : element indices to move into the regressor
%
%   Each element currently in front of the regressand (element 1) is
%   bubbled, by swaps of neighbouring rows/columns of both the data (L,d)
%   and the prior (L0,d0) decompositions, to the position just behind the
%   regressand. Elements already behind the regressand are left in place.
%   posit1, strRgr, strMis, bitstr and loglik are then refreshed.
   n_strL = length(in.strL);
   % Iterate over a row: with a column argument `for` would run only once,
   % with the whole column as the loop variable.
   inserted_elements = inserted_elements(:).';
   out = in;
   for f = inserted_elements;
      posit1 = find(out.strL==1);
      positf = find(out.strL==f);
      for g = positf:(posit1-1);
         % Swap positions g and g+1 (moves f one position to the right).
         [out.L,  out.d]  = seswapudl(out.L,  out.d,  g);
         [out.L0, out.d0] = seswapudl(out.L0, out.d0, g);
         out.strL([g g+1]) = out.strL([g+1 g]);
      end;
   end;
   out.posit1 = find(out.strL==1);
   out.strRgr = out.strL((out.posit1+1):n_strL);
   out.strMis = out.strL(1:(out.posit1-1));
   out.bitstr = str_bitset(out.bitstr,inserted_elements,out.nbits);
   out.loglik = seloglik1(out);

%
% seloglik_real = seloglik1 + seloglik2
%

function l = seloglik1(in)
% SELOGLIK1  Non-constant (structure-dependent) part of the loglikelihood.
%   Cheap enough to be evaluated for every candidate structure.
%
%   With p = in.posit1 (regressand position) and tail = p+1:length(in.d):
%     l = [-nu/2 *log(d(p))  - 1/2*sum(log(d(tail)))]
%       - [-nu0/2*log(d0(p)) - 1/2*sum(log(d0(tail)))]
   p    = in.posit1;
   tail = (p+1):length(in.d);

   post  = -0.5*in.nu *log(in.d (p)) -0.5*sum(log(in.d (tail)));
   prior = -0.5*in.nu0*log(in.d0(p)) -0.5*sum(log(in.d0(tail)));
   l = post-prior;

   % DEBUGging print:
   % fprintf('SELOGLIK1: str=%s loglik=%g\n', strPrintstr(in), l);

function l = seloglik2(in)
% SELOGLIK2  Constant (structure-independent) part of the loglikelihood.
%   It reads only in.nu and in.nu0, so it is the same for every structure
%   and is added to the final candidates only, not during the search:
%     l = [gammaln(nu/2) - nu/2*log(pi)] - [gammaln(nu0/2) - nu0/2*log(pi)]
   lpi = log(pi);

   post  = gammaln(in.nu /2) - 0.5*in.nu *lpi;
   prior = gammaln(in.nu0/2) - 0.5*in.nu0*lpi;
   l = post-prior;

function [Lout, dout] = seswapudl(L,d,i);
%SESWAPUDL swaps information matrix in decomposition V=L^T diag(d) L
%
%  [Lout, dout] = seswapudl(L,d,i);
%
% L : lower triangular matrix with 1's on the diagonal of the decomposition
% d : diagonal vector of the diagonal matrix of the decomposition
% i : index of the row to be swapped with the next one (i+1)
% Lout : output lower triangular matrix
% dout : output diagonal vector of the diagonal matrix D
%
% Description:
%  Lout' * diag(dout) * Lout = P(i,i+1) * L' * diag(d) * L * P(i,i+1);
%
%  where the permutation matrix P(i,j) permutes columns if applied from
%  the right and rows if applied from the left.
%
%  Rows/columns i and i+1 of L (and entries i, i+1 of d) are exchanged
%  first; element i+1 of row i is then eliminated by dyadic reduction
%  (sedydr) of rows i and i+1, and the unit diagonal is restored.
%
% Note: naming:
%       se = structure estimation
%       lite = light, simple
%       udl = U*D*L, or more precisely, L'*D*L, also called as ld
%
% Design  : L. Tesar
% Updated : Feb 2003
% Project : post-ProDaCTool
% Reference: sedydr

nxt  = i + 1;
pair = [i nxt];
swap = [nxt i];

% Permute rows and columns i <-> i+1 of L, and the matching entries of d.
d(pair)    = d(swap);
L(pair, :) = L(swap, :);
L(:, pair) = L(:, swap);

% L'*diag(d)*L is the sum of the row dyads L(k,:)'*d(k)*L(k,:), so work
% with ROWS of L: reduce element nxt of row i by means of row nxt.
[rowI, rowN, wI, wN] = sedydr(L(i,:)', L(nxt,:)', d(i), d(nxt), nxt);

% Rescale the reduced row so that its diagonal element is 1 again.
piv = rowI(i);
L(i,:)   = (rowI/piv)';
L(nxt,:) = rowN';
d(i)     = wI*piv*piv;
d(nxt)   = wN;

L(i,i)     = 1;
L(nxt,nxt) = 1;

Lout = L;
dout = d;

function [rout, fout, Drout, Dfout, kr] = sedydr(r,f,Dr,Df,R,jl,jh);
%SEDYDR dyadic reduction, performs transformation of sum of 2 dyads
%
% [rout, fout, Drout, Dfout, kr] = sedydr(r,f,Dr,Df,R,jl,jh);
% [rout, fout, Drout, Dfout] = sedydr(r,f,Dr,Df,R);
%
% Description: transforms the sum of 2 dyads r*Dr*r' + f*Df*f' into
%  rout*Drout*rout' + fout*Dfout*fout' so that the element of r indexed
%  by R is zeroed.
%
%     r    : column vector of reduced dyad
%     f    : column vector of reducing dyad; f(R) is assumed to be 1
%     Dr   : scalar weight of reduced dyad (values below mzero are
%            treated as 0)
%     Df   : scalar weight of reducing dyad
%     R    : scalar 1-based index of the element of r which is to be
%            reduced to zero
%     jl   : lower index of the range within which the dyads are
%            modified (can be omitted, then everything is updated)
%     jh   : upper index of that range (defaults to length(r) when only
%            jl is given)
%     rout,fout,Drout,Dfout : resulting two dyads
%     kr   : coefficient of the transformation of f:
%            rout = r - r(R)*f,   fout = f + kr*rout
%
% Remark1: Constant mzero means machine zero and should be modified
%           according to the precision of particular machine
% Remark2: jl and jh are, in fact, obsolete. It takes longer time to
%           compute them compared to plain version. The reason is that we
%           are doing vector operations in m-file. Other reason is that
%           we need to copy whole vector anyway. It can save half of time for
%           c-file, if you use it correctly. (please do tests)
%
% Note: naming:
%       se = structure estimation
%       dydr = dyadic reduction
%
% Original Fortran design: V. Peterka 17-7-89
% Modified for c-language: probably R. Kulhavy
% Modified for m-language: L. Tesar 2/2003
% Updated: Feb 2003
% Project: post-ProDaCTool
% Reference: none

if nargin<6;
   update_whole=1;
else
   update_whole=0;
   if nargin<7;
      jh = length(r);   % only jl given: update up to the end of the vectors
   end;
end;

mzero = 1e-32;

if Dr<mzero;
   Dr=0;
end;

r0   = r(R);
kD   = Df;
kr   = r0 * Dr;
Dfout   = kD + r0 * kr;

if Dfout > mzero;
    kD = kD / Dfout;
    kr = kr / Dfout;
else;
    % Both weights vanish: leave f unchanged.
    kD = 1;
    kr = 0;
end;

Drout = Dr * kD;

% Try to uncomment marked stuff (*) if in numerical problems, but I don't
% think it can make any difference for normal healthy floating-point unit
if update_whole;
   rout = r - r0*f;
%   rout(R) = 0;   % * could be needed for some nonsense cases(or numeric reasons?), normally not
   fout = f + kr*rout;
%   fout(R) = 1;   % * could be needed for some nonsense cases(or numeric reasons?), normally not
else;
   rout = r;
   fout = f;
   rout(jl:jh) = r(jl:jh) - r0 * f(jl:jh);
   rout(R) = 0;
   fout(jl:jh) = f(jl:jh) + kr * rout(jl:jh);
end;

Note: See TracBrowser for help on using the browser.