#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>
#include "mixtures.h"
---|
3 | |
---|
4 | namespace bdm { |
---|
5 | |
---|
6 | |
---|
7 | void MixEF::init ( BMEF* Com0, const mat &Data, const int c ) { |
---|
8 | //prepare sizes |
---|
9 | Coms.set_size ( c ); |
---|
10 | n = c; |
---|
11 | weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp. |
---|
12 | //est will be done at the end |
---|
13 | // |
---|
14 | int i; |
---|
15 | int ndat = Data.cols(); |
---|
16 | //Estimate Com0 from all data |
---|
17 | Coms ( 0 ) = (BMEF*) Com0->_copy(); |
---|
18 | // Coms(0)->set_evalll(false); |
---|
19 | Coms ( 0 )->bayes_batch ( Data ); |
---|
20 | // Flatten it to its original shape |
---|
21 | Coms ( 0 )->flatten ( Com0 ); |
---|
22 | |
---|
23 | //Copy it to the rest |
---|
24 | for ( i = 1; i < n; i++ ) { |
---|
25 | //copy Com0 and create new rvs for them |
---|
26 | Coms ( i ) = (BMEF*) Coms ( 0 )->_copy ( ); |
---|
27 | } |
---|
28 | //Pick some data for each component and update it |
---|
29 | for ( i = 0; i < n; i++ ) { |
---|
30 | //pick one datum |
---|
31 | int ind = (int) floor ( ndat * UniRNG.sample() ); |
---|
32 | Coms ( i )->bayes ( Data.get_col ( ind ), empty_vec ); |
---|
33 | //flatten back to oringinal |
---|
34 | Coms ( i )->flatten ( Com0 ); |
---|
35 | } |
---|
36 | |
---|
37 | } |
---|
38 | |
---|
39 | void MixEF::bayes_batch ( const mat &data , const mat &cond, const vec &wData ) { |
---|
40 | int ndat = data.cols(); |
---|
41 | int t, i, niter; |
---|
42 | bool converged = false; |
---|
43 | |
---|
44 | multiBM weights0 ( weights ); |
---|
45 | |
---|
46 | Array<BMEF*> Coms0 ( n ); |
---|
47 | for ( i = 0; i < n; i++ ) { |
---|
48 | Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy(); |
---|
49 | } |
---|
50 | |
---|
51 | niter = 0; |
---|
52 | mat W = ones ( n, ndat ) / n; |
---|
53 | mat Wlast = ones ( n, ndat ) / n; |
---|
54 | vec w ( n ); |
---|
55 | vec ll ( n ); |
---|
56 | // tmp for weights |
---|
57 | vec wtmp = zeros ( n ); |
---|
58 | int maxi; |
---|
59 | double maxll; |
---|
60 | //Estim |
---|
61 | while ( !converged ) { |
---|
62 | // Copy components back to their initial values |
---|
63 | // All necessary information is now in w and Coms0. |
---|
64 | Wlast = W; |
---|
65 | // |
---|
66 | //#pragma omp parallel for |
---|
67 | for ( t = 0; t < ndat; t++ ) { |
---|
68 | //#pragma omp parallel for |
---|
69 | for ( i = 0; i < n; i++ ) { |
---|
70 | ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) ); |
---|
71 | wtmp = 0.0; |
---|
72 | wtmp ( i ) = 1.0; |
---|
73 | ll ( i ) += weights.logpred ( wtmp ); |
---|
74 | } |
---|
75 | |
---|
76 | maxll = max ( ll, maxi ); |
---|
77 | switch ( method ) { |
---|
78 | case QB: |
---|
79 | w = exp ( ll - maxll ); |
---|
80 | w /= sum ( w ); |
---|
81 | break; |
---|
82 | case EM: |
---|
83 | w = 0.0; |
---|
84 | w ( maxi ) = 1.0; |
---|
85 | break; |
---|
86 | } |
---|
87 | |
---|
88 | W.set_col ( t, w ); |
---|
89 | } |
---|
90 | |
---|
91 | // copy initial statistics |
---|
92 | //#pragma omp parallel for |
---|
93 | for ( i = 0; i < n; i++ ) { |
---|
94 | Coms ( i )-> set_statistics ( Coms0 ( i ) ); |
---|
95 | } |
---|
96 | weights.set_statistics ( &weights0 ); |
---|
97 | |
---|
98 | // Update statistics |
---|
99 | // !!!! note wData ==> this is extra weight of the data record |
---|
100 | // !!!! For typical cases wData=1. |
---|
101 | for ( t = 0; t < ndat; t++ ) { |
---|
102 | //#pragma omp parallel for |
---|
103 | for ( i = 0; i < n; i++ ) { |
---|
104 | Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) ); |
---|
105 | } |
---|
106 | weights.bayes ( W.get_col ( t ) * wData ( t ) ); |
---|
107 | } |
---|
108 | |
---|
109 | niter++; |
---|
110 | //TODO better convergence rule. |
---|
111 | converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 ); |
---|
112 | } |
---|
113 | |
---|
114 | //Clean Coms0 |
---|
115 | for ( i = 0; i < n; i++ ) { |
---|
116 | delete Coms0 ( i ); |
---|
117 | } |
---|
118 | } |
---|
119 | |
---|
120 | void MixEF::bayes ( const vec &data, const vec &cond = empty_vec ) { |
---|
121 | |
---|
122 | }; |
---|
123 | |
---|
124 | void MixEF::bayes ( const mat &data, const vec &cond = empty_vec ) { |
---|
125 | this->bayes_batch ( data, cond, ones ( data.cols() ) ); |
---|
126 | }; |
---|
127 | |
---|
128 | |
---|
129 | double MixEF::logpred ( const vec &yt ) const { |
---|
130 | |
---|
131 | vec w = weights.posterior().mean(); |
---|
132 | double exLL = 0.0; |
---|
133 | for ( int i = 0; i < n; i++ ) { |
---|
134 | exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt ) ); |
---|
135 | } |
---|
136 | return log ( exLL ); |
---|
137 | } |
---|
138 | |
---|
139 | emix* MixEF::epredictor ( const vec &vec) const { |
---|
140 | Array<shared_ptr<epdf> > pC ( n ); |
---|
141 | for ( int i = 0; i < n; i++ ) { |
---|
142 | pC ( i ) = Coms ( i )->epredictor ( ); |
---|
143 | pC (i) -> set_rv(_yrv()); |
---|
144 | } |
---|
145 | emix* tmp; |
---|
146 | tmp = new emix( ); |
---|
147 | tmp->_w() = weights.posterior().mean(); |
---|
148 | tmp->_Coms() = pC; |
---|
149 | tmp->validate(); |
---|
150 | return tmp; |
---|
151 | } |
---|
152 | |
---|
153 | void MixEF::flatten ( const BMEF* M2 ) { |
---|
154 | const MixEF* Mix2 = dynamic_cast<const MixEF*> ( M2 ); |
---|
155 | bdm_assert_debug ( Mix2->n == n, "Different no of coms" ); |
---|
156 | //Flatten each component |
---|
157 | for ( int i = 0; i < n; i++ ) { |
---|
158 | Coms ( i )->flatten ( Mix2->Coms ( i ) ); |
---|
159 | } |
---|
160 | //Flatten weights = make them equal!! |
---|
161 | weights.set_statistics ( & ( Mix2->weights ) ); |
---|
162 | } |
---|
163 | } |
---|