root/library/bdm/estim/mixtures.cpp @ 679

Revision 679, 3.8 kB (checked in by smidl, 15 years ago)

Major changes in BM -- OK is only test suite and tests/tutorial -- the rest is broken!!!

  • Property svn:eol-style set to native
RevLine 
[176]1#include <vector>
[384]2#include "mixtures.h"
[176]3
[477]4namespace bdm {
[176]5
6
7void MixEF::init ( BMEF* Com0, const mat &Data, int c ) {
8        //prepare sizes
9        Coms.set_size ( c );
[477]10        n = c;
[180]11        weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp.
[176]12        //est will be done at the end
13        //
14        int i;
15        int ndat = Data.cols();
16        //Estimate  Com0 from all data
[286]17        Coms ( 0 ) = Com0->_copy_();
[176]18//      Coms(0)->set_evalll(false);
19        Coms ( 0 )->bayesB ( Data );
20        // Flatten it to its original shape
21        Coms ( 0 )->flatten ( Com0 );
22
[189]23        //Copy it to the rest
[477]24        for ( i = 1; i < n; i++ ) {
[176]25                //copy Com0 and create new rvs for them
[286]26                Coms ( i ) =  Coms ( 0 )->_copy_ ( );
[176]27        }
28        //Pick some data for each component and update it
[477]29        for ( i = 0; i < n; i++ ) {
[176]30                //pick one datum
[477]31                int ind = floor ( ndat * UniRNG.sample() );
[679]32                Coms ( i )->bayes ( Data.get_col ( ind ), empty_vec );
[197]33                //flatten back to oringinal
[477]34                Coms ( i )->flatten ( Com0 );
[176]35        }
36
37        //est already exists - must be deleted before build_est() can be used
38        delete est;
39        build_est();
40
41}
42
[679]43void MixEF::bayesB ( const mat &data , const mat &cond, const vec &wData ) {
[477]44        int ndat = data.cols();
45        int t, i, niter;
46        bool converged = false;
[176]47
48        multiBM weights0 ( weights );
49
50        Array<BMEF*> Coms0 ( n );
[477]51        for ( i = 0; i < n; i++ ) {
52                Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy_();
53        }
[176]54
[477]55        niter = 0;
56        mat W = ones ( n, ndat ) / n;
57        mat Wlast = ones ( n, ndat ) / n;
[176]58        vec w ( n );
59        vec ll ( n );
60        // tmp for weights
61        vec wtmp = zeros ( n );
[189]62        int maxi;
63        double maxll;
[176]64        //Estim
65        while ( !converged ) {
66                // Copy components back to their initial values
67                // All necessary information is now in w and Coms0.
68                Wlast = W;
69                //
[311]70                //#pragma omp parallel for
[477]71                for ( t = 0; t < ndat; t++ ) {
[311]72                        //#pragma omp parallel for
[477]73                        for ( i = 0; i < n; i++ ) {
74                                ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) );
75                                wtmp = 0.0;
76                                wtmp ( i ) = 1.0;
[176]77                                ll ( i ) += weights.logpred ( wtmp );
78                        }
[477]79
80                        maxll = max ( ll, maxi );
81                        switch ( method ) {
82                        case QB:
83                                w = exp ( ll - maxll );
84                                w /= sum ( w );
85                                break;
86                        case EM:
87                                w = 0.0;
88                                w ( maxi ) = 1.0;
89                                break;
[189]90                        }
[477]91
[189]92                        W.set_col ( t, w );
[176]93                }
94
[189]95                // copy initial statistics
[311]96                //#pragma omp parallel for
[477]97                for ( i = 0; i < n; i++ ) {
[176]98                        Coms ( i )-> set_statistics ( Coms0 ( i ) );
99                }
100                weights.set_statistics ( &weights0 );
101
[189]102                // Update statistics
103                // !!!!    note  wData ==> this is extra weight of the data record
104                // !!!!    For typical cases wData=1.
[477]105                for ( t = 0; t < ndat; t++ ) {
[311]106                        //#pragma omp parallel for
[477]107                        for ( i = 0; i < n; i++ ) {
[679]108                                Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) );
[176]109                        }
[189]110                        weights.bayes ( W.get_col ( t ) * wData ( t ) );
[176]111                }
112
113                niter++;
114                //TODO better convergence rule.
[477]115                converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 );
[176]116        }
117
118        //Clean Coms0
[477]119        for ( i = 0; i < n; i++ ) {
120                delete Coms0 ( i );
121        }
[189]122}
123
[679]124void MixEF::bayes ( const vec &data, const vec &cond=empty_vec ) {
[189]125
[176]126};
127
[679]128void MixEF::bayes ( const mat &data, const vec &cond=empty_vec ) {
129        this->bayesB ( data, cond, ones ( data.cols() ) );
[189]130};
[176]131
[189]132
[679]133double MixEF::logpred ( const vec &yt ) const {
[176]134
[477]135        vec w = weights.posterior().mean();
136        double exLL = 0.0;
137        for ( int i = 0; i < n; i++ ) {
[679]138                exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt ) );
[176]139        }
140        return log ( exLL );
141}
[180]142
[477]143emix* MixEF::epredictor ( ) const {
[504]144        Array<shared_ptr<epdf> > pC ( n );
[477]145        for ( int i = 0; i < n; i++ ) {
146                pC ( i ) = Coms ( i )->epredictor ( );
147        }
[389]148        emix* tmp;
[477]149        tmp = new emix( );
[504]150        tmp->set_parameters ( weights.posterior().mean(), pC );
[180]151        return tmp;
152}
[197]153
[477]154void MixEF::flatten ( const BMEF* M2 ) {
155        const MixEF* Mix2 = dynamic_cast<const MixEF*> ( M2 );
[565]156        bdm_assert_debug ( Mix2->n == n, "Different no of coms" );
[197]157        //Flatten each component
[477]158        for ( int i = 0; i < n; i++ ) {
159                Coms ( i )->flatten ( Mix2->Coms ( i ) );
[197]160        }
[198]161        //Flatten weights = make them equal!!
[477]162        weights.set_statistics ( & ( Mix2->weights ) );
[254]163}
[271]164}
Note: See TracBrowser for help on using the browser.