#include "mixef.h"
#include <cmath>
#include <vector>

using namespace itpp;
|---|
| 5 | |
|---|
| 6 | |
|---|
| 7 | void MixEF::init ( BMEF* Com0, const mat &Data, int c ) { |
|---|
| 8 | //prepare sizes |
|---|
| 9 | Coms.set_size ( c ); |
|---|
| 10 | n=c; |
|---|
| 11 | weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp. |
|---|
| 12 | //est will be done at the end |
|---|
| 13 | // |
|---|
| 14 | int i; |
|---|
| 15 | int ndat = Data.cols(); |
|---|
| 16 | //Estimate Com0 from all data |
|---|
| 17 | Coms ( 0 ) = ( BMEF* ) Com0->_copy_(); |
|---|
| 18 | // Coms(0)->set_evalll(false); |
|---|
| 19 | Coms ( 0 )->bayesB ( Data ); |
|---|
| 20 | // Flatten it to its original shape |
|---|
| 21 | Coms ( 0 )->flatten ( Com0 ); |
|---|
| 22 | |
|---|
| 23 | //Copy it to the rest |
|---|
| 24 | for ( i=1;i<n;i++ ) { |
|---|
| 25 | //copy Com0 and create new rvs for them |
|---|
| 26 | Coms ( i ) = ( BMEF* ) Coms ( 0 )->_copy_ ( true ); |
|---|
| 27 | } |
|---|
| 28 | //Pick some data for each component and update it |
|---|
| 29 | for ( i=0;i<n;i++ ) { |
|---|
| 30 | //pick one datum |
|---|
| 31 | int ind=ndat*UniRNG.sample(); |
|---|
| 32 | Coms ( i )->bayes ( Data.get_col ( ind ),1.0 ); |
|---|
| 33 | } |
|---|
| 34 | |
|---|
| 35 | //est already exists - must be deleted before build_est() can be used |
|---|
| 36 | delete est; |
|---|
| 37 | build_est(); |
|---|
| 38 | |
|---|
| 39 | } |
|---|
| 40 | void MixEF::bayesB ( const mat &Data ) { |
|---|
| 41 | this->bayes ( Data ); |
|---|
| 42 | } |
|---|
| 43 | |
|---|
| 44 | void MixEF::bayes ( const vec &data ) { |
|---|
| 45 | |
|---|
| 46 | }; |
|---|
| 47 | |
|---|
| 48 | void MixEF::bayes ( const mat &data ) { |
|---|
| 49 | int ndat=data.cols(); |
|---|
| 50 | int t,i,niter; |
|---|
| 51 | bool converged; |
|---|
| 52 | |
|---|
| 53 | multiBM weights0 ( weights ); |
|---|
| 54 | |
|---|
| 55 | Array<BMEF*> Coms0 ( n ); |
|---|
| 56 | for ( i=0;i<n;i++ ) {Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy_();} |
|---|
| 57 | |
|---|
| 58 | niter=0; |
|---|
| 59 | mat W=ones ( n,ndat ) / n; |
|---|
| 60 | mat Wlast=ones ( n,ndat ) / n; |
|---|
| 61 | vec w ( n ); |
|---|
| 62 | vec ll ( n ); |
|---|
| 63 | // tmp for weights |
|---|
| 64 | vec wtmp = zeros ( n ); |
|---|
| 65 | //Estim |
|---|
| 66 | while ( !converged ) { |
|---|
| 67 | // Copy components back to their initial values |
|---|
| 68 | // All necessary information is now in w and Coms0. |
|---|
| 69 | Wlast = W; |
|---|
| 70 | // |
|---|
| 71 | for ( t=0;t<ndat;t++ ) { |
|---|
| 72 | for ( i=0;i<n;i++ ) { |
|---|
| 73 | ll ( i ) =Coms ( i )->logpred ( data.get_col ( t ) ); |
|---|
| 74 | wtmp =0.0; wtmp ( i ) =1.0; |
|---|
| 75 | ll ( i ) += weights.logpred ( wtmp ); |
|---|
| 76 | } |
|---|
| 77 | w = exp ( ll-max ( ll ) ); |
|---|
| 78 | W.set_col ( t, w/sum ( w ) ); |
|---|
| 79 | } |
|---|
| 80 | |
|---|
| 81 | for ( i=0;i<n;i++ ) { |
|---|
| 82 | Coms ( i )-> set_statistics ( Coms0 ( i ) ); |
|---|
| 83 | } |
|---|
| 84 | weights.set_statistics ( &weights0 ); |
|---|
| 85 | |
|---|
| 86 | for ( t=0;t<ndat;t++ ) { |
|---|
| 87 | for ( i=0;i<n;i++ ) { |
|---|
| 88 | Coms ( i )-> bayes ( data.get_col ( t ),W ( i,t ) ); |
|---|
| 89 | } |
|---|
| 90 | weights.bayes ( W.get_col ( t ) ); |
|---|
| 91 | } |
|---|
| 92 | |
|---|
| 93 | niter++; |
|---|
| 94 | //TODO better convergence rule. |
|---|
| 95 | converged = ( sumsum ( abs ( W-Wlast ) ) /n<0.001 ); |
|---|
| 96 | } |
|---|
| 97 | |
|---|
| 98 | //Clean Coms0 |
|---|
| 99 | for ( i=0;i<n;i++ ) {delete Coms0 ( i );} |
|---|
| 100 | }; |
|---|
| 101 | |
|---|
| 102 | |
|---|
| 103 | double MixEF::logpred ( const vec &dt ) const { |
|---|
| 104 | |
|---|
| 105 | vec w=weights._epdf().mean(); |
|---|
| 106 | double exLL=0.0; |
|---|
| 107 | for ( int i=0;i<n;i++ ) { |
|---|
| 108 | exLL+=w ( i ) *exp ( Coms ( i )->logpred ( dt ) ); |
|---|
| 109 | } |
|---|
| 110 | return log ( exLL ); |
|---|
| 111 | } |
|---|
| 112 | |
|---|
| 113 | emix* MixEF::predictor(const RV &rv){ |
|---|
| 114 | Array<epdf*> pC(n); |
|---|
| 115 | for(int i=0;i<n;i++){pC(i)=Coms(i)->predictor(rv);} |
|---|
| 116 | emix* tmp; |
|---|
| 117 | tmp = new emix(rv); |
|---|
| 118 | tmp->set_parameters(weights._epdf().mean(), pC, false); |
|---|
| 119 | tmp->ownComs(); |
|---|
| 120 | return tmp; |
|---|
| 121 | } |
|---|