#include "mixef.h"
#include <cmath>
#include <vector>
---|
3 | |
---|
4 | using namespace itpp; |
---|
5 | |
---|
6 | |
---|
7 | void MixEF::init ( BMEF* Com0, const mat &Data, int c ) { |
---|
8 | //prepare sizes |
---|
9 | Coms.set_size ( c ); |
---|
10 | n=c; |
---|
11 | weights.set_parameters ( 1/ ( double ) c*ones ( c ) ); //assume at least one observation in each comp. |
---|
12 | //est will be done at the end |
---|
13 | // |
---|
14 | int i; |
---|
15 | int ndat = Data.cols(); |
---|
16 | //Estimate Com0 from all data |
---|
17 | Coms ( 0 ) = ( BMEF* ) Com0->_copy_(); |
---|
18 | // Coms(0)->set_evalll(false); |
---|
19 | Coms ( 0 )->bayesB ( Data ); |
---|
20 | // Flatten it to its original shape |
---|
21 | Coms ( 0 )->flatten ( Com0 ); |
---|
22 | |
---|
23 | //Copy it to the rest |
---|
24 | for ( i=1;i<n;i++ ) { |
---|
25 | //copy Com0 and create new rvs for them |
---|
26 | Coms ( i ) = ( BMEF* ) Coms ( 0 )->_copy_ ( true ); |
---|
27 | } |
---|
28 | //Pick some data for each component and update it |
---|
29 | for ( i=0;i<n;i++ ) { |
---|
30 | //pick one datum |
---|
31 | int ind=ndat*UniRNG.sample(); |
---|
32 | Coms ( i )->bayes ( Data.get_col ( ind ),ndat/n ); |
---|
33 | } |
---|
34 | |
---|
35 | //est already exists - must be deleted before build_est() can be used |
---|
36 | delete est; |
---|
37 | build_est(); |
---|
38 | |
---|
39 | } |
---|
40 | void MixEF::bayesB ( const mat &Data ) { |
---|
41 | this->bayes ( Data ); |
---|
42 | } |
---|
43 | |
---|
44 | void MixEF::bayes ( const vec &data ) { |
---|
45 | |
---|
46 | }; |
---|
47 | |
---|
48 | void MixEF::bayes ( const mat &data ) { |
---|
49 | int ndat=data.cols(); |
---|
50 | int t,i,niter; |
---|
51 | bool converged; |
---|
52 | |
---|
53 | multiBM weights0 ( weights ); |
---|
54 | |
---|
55 | Array<BMEF*> Coms0 ( n ); |
---|
56 | for ( i=0;i<n;i++ ) {Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy_();} |
---|
57 | |
---|
58 | niter=0; |
---|
59 | mat W=ones ( n,ndat ) / n; |
---|
60 | mat Wlast=ones ( n,ndat ) / n; |
---|
61 | vec w ( n ); |
---|
62 | vec ll ( n ); |
---|
63 | // tmp for weights |
---|
64 | vec wtmp = zeros ( n ); |
---|
65 | //Estim |
---|
66 | while ( !converged ) { |
---|
67 | // Copy components back to their initial values |
---|
68 | // All necessary information is now in w and Coms0. |
---|
69 | Wlast = W; |
---|
70 | // |
---|
71 | for ( t=0;t<ndat;t++ ) { |
---|
72 | for ( i=0;i<n;i++ ) { |
---|
73 | ll ( i ) =Coms ( i )->logpred ( data.get_col ( t ) ); |
---|
74 | wtmp =0.0; wtmp ( i ) =1.0; |
---|
75 | ll ( i ) += weights.logpred ( wtmp ); |
---|
76 | } |
---|
77 | w = exp ( ll-max ( ll ) ); |
---|
78 | W.set_col ( t, w/sum ( w ) ); |
---|
79 | } |
---|
80 | |
---|
81 | for ( i=0;i<n;i++ ) { |
---|
82 | Coms ( i )-> set_statistics ( Coms0 ( i ) ); |
---|
83 | } |
---|
84 | weights.set_statistics ( &weights0 ); |
---|
85 | |
---|
86 | for ( t=0;t<ndat;t++ ) { |
---|
87 | for ( i=0;i<n;i++ ) { |
---|
88 | Coms ( i )-> bayes ( data.get_col ( t ),W ( i,t ) ); |
---|
89 | } |
---|
90 | weights.bayes ( W.get_col ( t ) ); |
---|
91 | } |
---|
92 | |
---|
93 | niter++; |
---|
94 | //TODO better convergence rule. |
---|
95 | converged = ( sumsum ( abs ( W-Wlast ) ) /n<0.001 ); |
---|
96 | } |
---|
97 | |
---|
98 | //Clean Coms0 |
---|
99 | for ( i=0;i<n;i++ ) {delete Coms0 ( i );} |
---|
100 | }; |
---|
101 | |
---|
102 | |
---|
103 | double MixEF::logpred ( const vec &dt ) const { |
---|
104 | |
---|
105 | vec w=weights._epdf().mean(); |
---|
106 | double exLL=0.0; |
---|
107 | for ( int i=0;i<n;i++ ) { |
---|
108 | exLL+=w ( i ) *exp ( Coms ( i )->logpred ( dt ) ); |
---|
109 | } |
---|
110 | return log ( exLL ); |
---|
111 | } |
---|