root/library/bdm/estim/mixtures.cpp @ 1064

Revision 1064, 5.3 kB (checked in by mido, 14 years ago)

astyle applied all over the library

  • Property svn:eol-style set to native
Line 
1#include <vector>
2#include "mixtures.h"
3
4namespace bdm {
5
6
7void MixEF::init ( BMEF* Com0, const mat &Data, const int c ) {
8    //prepare sizes
9    Coms.set_size ( c );
10    weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp.
11    multiBM weights0(weights);
12    //est will be done at the end
13    //
14    int i;
15    int ndat = Data.cols();
16    //Estimate  Com0 from all data
17    Coms ( 0 ) = (BMEF*) Com0->_copy();
18//      Coms(0)->set_evalll(false);
19    Coms ( 0 )->bayes_batch ( Data );
20
21    Coms ( 0 )->flatten ( Com0 );
22
23    //Copy it to the rest
24    for ( i = 1; i < Coms.length(); i++ ) {
25        //copy Com0 and create new rvs for them
26        Coms ( i ) =  (BMEF*) Coms ( 0 )->_copy ( );
27    }
28    //Pick some data for each component and update it
29    for ( i = 0; i < Coms.length(); i++ ) {
30        //pick one datum
31        if (ndat==Coms.length()) { //take the ith vector
32            Coms ( i )->bayes ( Data.get_col ( i ), empty_vec );
33        } else { // pick at random
34            int ind = (int) floor ( ndat * UniRNG.sample() );
35            Coms ( i )->bayes_weighted ( Data.get_col ( ind ), empty_vec, ndat );
36            Coms (i)->flatten(Com0,ndat/Coms.length());
37        }
38        //sharpen to the sharp component
39        //Coms ( i )->flatten ( SharpCom.get(), 1.0/Coms.length() );
40    }
41    MixEF_options old_opt =options;
42    MixEF_options ini_opt=options;
43    ini_opt.method = EM;
44    ini_opt.max_niter= 1;
45    bayes_batch(Data, empty_vec);
46
47    for ( i = 0; i < Coms.length(); i++ ) {
48        Coms (i)->flatten(Com0,ndat/Coms.length());
49    }
50
51    options = old_opt;
52}
53
54double MixEF::bayes_batch_weighted ( const mat &data , const mat &cond, const vec &wData ) {
55    int ndat = data.cols();
56    int t, i, niter;
57    bool converged = false;
58
59    multiBM weights0 ( weights );
60
61    int n = Coms.length();
62    Array<BMEF*> Coms0 ( n );
63    for ( i = 0; i < n; i++ ) {
64        Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy();
65    }
66
67    niter = 0;
68    mat W = ones ( n, ndat ) / n;
69    mat Wlast = ones ( n, ndat ) / n;
70    vec w ( n );
71    vec ll ( n );
72    // tmp for weights
73    vec wtmp = zeros ( n );
74    int maxi;
75    double maxll;
76
77    double levid=0.0;
78    //Estim
79    while ( !converged ) {
80        levid=0.0;
81        // Copy components back to their initial values
82        // All necessary information is now in w and Coms0.
83        Wlast = W;
84        //
85        //#pragma omp parallel for
86        for ( t = 0; t < ndat; t++ ) {
87            //#pragma omp parallel for
88            for ( i = 0; i < n; i++ ) {
89                ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) , empty_vec);
90                wtmp = 0.0;
91                wtmp ( i ) = 1.0;
92                ll ( i ) += weights.logpred ( wtmp );
93            }
94
95            maxll = max ( ll, maxi );
96            switch ( options.method ) {
97            case QB:
98                w = exp ( ll - maxll );
99                w /= sum ( w );
100                break;
101            case EM:
102                w = 0.0;
103                w ( maxi ) = 1.0;
104                break;
105            }
106
107            W.set_col ( t, w );
108        }
109
110        // copy initial statistics
111        //#pragma omp parallel for
112        for ( i = 0; i < n; i++ ) {
113            Coms ( i )-> set_statistics ( Coms0 ( i ) );
114        }
115        weights.set_statistics ( &weights0 );
116
117        // Update statistics
118        // !!!!    note  wData ==> this is extra weight of the data record
119        // !!!!    For typical cases wData=1.
120        vec logevid(n);
121        for ( t = 0; t < ndat; t++ ) {
122            //#pragma omp parallel for
123            for ( i = 0; i < n; i++ ) {
124                Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) );
125                logevid(i) = Coms(i)->_ll();
126            }
127            weights.bayes ( W.get_col ( t ) * wData ( t ) );
128        }
129        levid += weights._ll()+log(weights.posterior().mean() * exp(logevid)); // inner product w*exp(evid)
130
131        niter++;
132        //TODO better convergence rule.
133        converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 );
134    }
135
136    //Clean Coms0
137    for ( i = 0; i < n; i++ ) {
138        delete Coms0 ( i );
139    }
140    return levid;
141}
142
143void MixEF::bayes ( const vec &data, const vec &cond = empty_vec ) {
144
145};
146
147double MixEF::logpred ( const vec &yt, const vec &cond =empty_vec) const {
148
149    vec w = weights.posterior().mean();
150    double exLL = 0.0;
151    for ( int i = 0; i < Coms.length(); i++ ) {
152        exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt , cond ) );
153    }
154    return log ( exLL );
155}
156
157emix* MixEF::epredictor ( const vec &vec) const {
158    Array<shared_ptr<epdf> > pC ( Coms.length() );
159    for ( int i = 0; i < Coms.length(); i++ ) {
160        pC ( i ) = Coms ( i )->epredictor ( );
161        pC (i) -> set_rv(_yrv());
162    }
163    emix* tmp;
164    tmp = new emix( );
165    tmp->_w() = weights.posterior().mean();
166    tmp->_Coms() = pC;
167    tmp->validate();
168    return tmp;
169}
170
171void MixEF::flatten ( const BMEF* M2, double weight=1.0 ) {
172    const MixEF* Mix2 = dynamic_cast<const MixEF*> ( M2 );
173    bdm_assert_debug ( Mix2->Coms.length() == Coms.length(), "Different no of coms" );
174    //Flatten each component
175    for ( int i = 0; i < Coms.length(); i++ ) {
176        Coms ( i )->flatten ( Mix2->Coms ( i ) , weight);
177    }
178    //Flatten weights = make them equal!!
179    weights.set_statistics ( & ( Mix2->weights ) );
180}
181}
Note: See TracBrowser for help on using the browser.