[176] | 1 | #include <vector> |
---|
[384] | 2 | #include "mixtures.h" |
---|
[176] | 3 | |
---|
[477] | 4 | namespace bdm { |
---|
[176] | 5 | |
---|
| 6 | |
---|
[735] | 7 | void MixEF::init ( BMEF* Com0, const mat &Data, const int c ) { |
---|
[1064] | 8 | //prepare sizes |
---|
| 9 | Coms.set_size ( c ); |
---|
| 10 | weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp. |
---|
| 11 | multiBM weights0(weights); |
---|
| 12 | //est will be done at the end |
---|
| 13 | // |
---|
| 14 | int i; |
---|
| 15 | int ndat = Data.cols(); |
---|
| 16 | //Estimate Com0 from all data |
---|
| 17 | Coms ( 0 ) = (BMEF*) Com0->_copy(); |
---|
[176] | 18 | // Coms(0)->set_evalll(false); |
---|
[1064] | 19 | Coms ( 0 )->bayes_batch ( Data ); |
---|
[1014] | 20 | |
---|
[1064] | 21 | Coms ( 0 )->flatten ( Com0 ); |
---|
[176] | 22 | |
---|
[1064] | 23 | //Copy it to the rest |
---|
| 24 | for ( i = 1; i < Coms.length(); i++ ) { |
---|
| 25 | //copy Com0 and create new rvs for them |
---|
| 26 | Coms ( i ) = (BMEF*) Coms ( 0 )->_copy ( ); |
---|
| 27 | } |
---|
| 28 | //Pick some data for each component and update it |
---|
| 29 | for ( i = 0; i < Coms.length(); i++ ) { |
---|
| 30 | //pick one datum |
---|
| 31 | if (ndat==Coms.length()) { //take the ith vector |
---|
| 32 | Coms ( i )->bayes ( Data.get_col ( i ), empty_vec ); |
---|
| 33 | } else { // pick at random |
---|
| 34 | int ind = (int) floor ( ndat * UniRNG.sample() ); |
---|
| 35 | Coms ( i )->bayes_weighted ( Data.get_col ( ind ), empty_vec, ndat ); |
---|
| 36 | Coms (i)->flatten(Com0,ndat/Coms.length()); |
---|
| 37 | } |
---|
| 38 | //sharpen to the sharp component |
---|
| 39 | //Coms ( i )->flatten ( SharpCom.get(), 1.0/Coms.length() ); |
---|
| 40 | } |
---|
[1077] | 41 | Options old_opt =options; |
---|
| 42 | Options ini_opt=options; |
---|
[1064] | 43 | ini_opt.method = EM; |
---|
| 44 | ini_opt.max_niter= 1; |
---|
| 45 | bayes_batch(Data, empty_vec); |
---|
| 46 | |
---|
| 47 | for ( i = 0; i < Coms.length(); i++ ) { |
---|
| 48 | Coms (i)->flatten(Com0,ndat/Coms.length()); |
---|
| 49 | } |
---|
| 50 | |
---|
| 51 | options = old_opt; |
---|
[176] | 52 | } |
---|
| 53 | |
---|
[1013] | 54 | double MixEF::bayes_batch_weighted ( const mat &data , const mat &cond, const vec &wData ) { |
---|
[1064] | 55 | int ndat = data.cols(); |
---|
| 56 | int t, i, niter; |
---|
| 57 | bool converged = false; |
---|
[176] | 58 | |
---|
[1064] | 59 | multiBM weights0 ( weights ); |
---|
[176] | 60 | |
---|
[1064] | 61 | int n = Coms.length(); |
---|
| 62 | Array<BMEF*> Coms0 ( n ); |
---|
| 63 | for ( i = 0; i < n; i++ ) { |
---|
| 64 | Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy(); |
---|
| 65 | } |
---|
[176] | 66 | |
---|
[1064] | 67 | niter = 0; |
---|
| 68 | mat W = ones ( n, ndat ) / n; |
---|
| 69 | mat Wlast = ones ( n, ndat ) / n; |
---|
| 70 | vec w ( n ); |
---|
| 71 | vec ll ( n ); |
---|
| 72 | // tmp for weights |
---|
| 73 | vec wtmp = zeros ( n ); |
---|
| 74 | int maxi; |
---|
| 75 | double maxll; |
---|
[477] | 76 | |
---|
[1064] | 77 | double levid=0.0; |
---|
| 78 | //Estim |
---|
| 79 | while ( !converged ) { |
---|
| 80 | levid=0.0; |
---|
| 81 | // Copy components back to their initial values |
---|
| 82 | // All necessary information is now in w and Coms0. |
---|
| 83 | Wlast = W; |
---|
| 84 | // |
---|
| 85 | //#pragma omp parallel for |
---|
| 86 | for ( t = 0; t < ndat; t++ ) { |
---|
| 87 | //#pragma omp parallel for |
---|
| 88 | for ( i = 0; i < n; i++ ) { |
---|
| 89 | ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) , empty_vec); |
---|
| 90 | wtmp = 0.0; |
---|
| 91 | wtmp ( i ) = 1.0; |
---|
| 92 | ll ( i ) += weights.logpred ( wtmp ); |
---|
| 93 | } |
---|
[477] | 94 | |
---|
[1064] | 95 | maxll = max ( ll, maxi ); |
---|
| 96 | switch ( options.method ) { |
---|
| 97 | case QB: |
---|
| 98 | w = exp ( ll - maxll ); |
---|
| 99 | w /= sum ( w ); |
---|
| 100 | break; |
---|
| 101 | case EM: |
---|
| 102 | w = 0.0; |
---|
| 103 | w ( maxi ) = 1.0; |
---|
| 104 | break; |
---|
| 105 | } |
---|
[176] | 106 | |
---|
[1064] | 107 | W.set_col ( t, w ); |
---|
| 108 | } |
---|
[176] | 109 | |
---|
[1064] | 110 | // copy initial statistics |
---|
| 111 | //#pragma omp parallel for |
---|
| 112 | for ( i = 0; i < n; i++ ) { |
---|
| 113 | Coms ( i )-> set_statistics ( Coms0 ( i ) ); |
---|
| 114 | } |
---|
| 115 | weights.set_statistics ( &weights0 ); |
---|
[176] | 116 | |
---|
[1064] | 117 | // Update statistics |
---|
| 118 | // !!!! note wData ==> this is extra weight of the data record |
---|
| 119 | // !!!! For typical cases wData=1. |
---|
| 120 | vec logevid(n); |
---|
| 121 | for ( t = 0; t < ndat; t++ ) { |
---|
| 122 | //#pragma omp parallel for |
---|
| 123 | for ( i = 0; i < n; i++ ) { |
---|
| 124 | Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) ); |
---|
| 125 | logevid(i) = Coms(i)->_ll(); |
---|
| 126 | } |
---|
| 127 | weights.bayes ( W.get_col ( t ) * wData ( t ) ); |
---|
| 128 | } |
---|
| 129 | levid += weights._ll()+log(weights.posterior().mean() * exp(logevid)); // inner product w*exp(evid) |
---|
| 130 | |
---|
| 131 | niter++; |
---|
| 132 | //TODO better convergence rule. |
---|
| 133 | converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 ); |
---|
| 134 | } |
---|
| 135 | |
---|
| 136 | //Clean Coms0 |
---|
| 137 | for ( i = 0; i < n; i++ ) { |
---|
| 138 | delete Coms0 ( i ); |
---|
| 139 | } |
---|
| 140 | return levid; |
---|
[189] | 141 | } |
---|
| 142 | |
---|
[737] | 143 | void MixEF::bayes ( const vec &data, const vec &cond = empty_vec ) { |
---|
[189] | 144 | |
---|
[176] | 145 | }; |
---|
| 146 | |
---|
[1009] | 147 | double MixEF::logpred ( const vec &yt, const vec &cond =empty_vec) const { |
---|
[176] | 148 | |
---|
[1064] | 149 | vec w = weights.posterior().mean(); |
---|
| 150 | double exLL = 0.0; |
---|
| 151 | for ( int i = 0; i < Coms.length(); i++ ) { |
---|
| 152 | exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt , cond ) ); |
---|
| 153 | } |
---|
| 154 | return log ( exLL ); |
---|
[176] | 155 | } |
---|
[180] | 156 | |
---|
[943] | 157 | emix* MixEF::epredictor ( const vec &vec) const { |
---|
[1064] | 158 | Array<shared_ptr<epdf> > pC ( Coms.length() ); |
---|
| 159 | for ( int i = 0; i < Coms.length(); i++ ) { |
---|
| 160 | pC ( i ) = Coms ( i )->epredictor ( ); |
---|
| 161 | pC (i) -> set_rv(_yrv()); |
---|
| 162 | } |
---|
| 163 | emix* tmp; |
---|
| 164 | tmp = new emix( ); |
---|
| 165 | tmp->_w() = weights.posterior().mean(); |
---|
| 166 | tmp->_Coms() = pC; |
---|
| 167 | tmp->validate(); |
---|
| 168 | return tmp; |
---|
[180] | 169 | } |
---|
[197] | 170 | |
---|
[1013] | 171 | void MixEF::flatten ( const BMEF* M2, double weight=1.0 ) { |
---|
[1064] | 172 | const MixEF* Mix2 = dynamic_cast<const MixEF*> ( M2 ); |
---|
| 173 | bdm_assert_debug ( Mix2->Coms.length() == Coms.length(), "Different no of coms" ); |
---|
| 174 | //Flatten each component |
---|
| 175 | for ( int i = 0; i < Coms.length(); i++ ) { |
---|
| 176 | Coms ( i )->flatten ( Mix2->Coms ( i ) , weight); |
---|
| 177 | } |
---|
| 178 | //Flatten weights = make them equal!! |
---|
| 179 | weights.set_statistics ( & ( Mix2->weights ) ); |
---|
[254] | 180 | } |
---|
[271] | 181 | } |
---|