root/bdm/estim/mixef.cpp @ 311

Revision 311, 3.6 kB (checked in by smidl, 15 years ago)

merger

  • Property svn:eol-style set to native
Line 
1#include "mixef.h"
2#include <vector>
3
4namespace bdm{
5
6
7void MixEF::init ( BMEF* Com0, const mat &Data, int c ) {
8        //prepare sizes
9        Coms.set_size ( c );
10        n=c;
11        weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp.
12        //est will be done at the end
13        //
14        int i;
15        int ndat = Data.cols();
16        //Estimate  Com0 from all data
17        Coms ( 0 ) = Com0->_copy_();
18//      Coms(0)->set_evalll(false);
19        Coms ( 0 )->bayesB ( Data );
20        // Flatten it to its original shape
21        Coms ( 0 )->flatten ( Com0 );
22
23        //Copy it to the rest
24        for ( i=1;i<n;i++ ) {
25                //copy Com0 and create new rvs for them
26                Coms ( i ) =  Coms ( 0 )->_copy_ ( );
27        }
28        //Pick some data for each component and update it
29        for ( i=0;i<n;i++ ) {
30                //pick one datum
31                int ind=floor(ndat*UniRNG.sample());
32                Coms ( i )->bayes ( Data.get_col ( ind ),1.0 );
33                //flatten back to oringinal
34                Coms(i)->flatten(Com0);
35        }
36
37        //est already exists - must be deleted before build_est() can be used
38        delete est;
39        build_est();
40
41}
42
43void MixEF::bayesB ( const mat &data , const vec &wData ) {
44        int ndat=data.cols();
45        int t,i,niter;
46        bool converged=false;
47
48        multiBM weights0 ( weights );
49
50        Array<BMEF*> Coms0 ( n );
51        for ( i=0;i<n;i++ ) {Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy_();}
52
53        niter=0;
54        mat W=ones ( n,ndat ) / n;
55        mat Wlast=ones ( n,ndat ) / n;
56        vec w ( n );
57        vec ll ( n );
58        // tmp for weights
59        vec wtmp = zeros ( n );
60        int maxi;
61        double maxll;
62        //Estim
63        while ( !converged ) {
64                // Copy components back to their initial values
65                // All necessary information is now in w and Coms0.
66                Wlast = W;
67                //
68                //#pragma omp parallel for
69                for ( t=0;t<ndat;t++ ) {
70                        //#pragma omp parallel for
71                        for ( i=0;i<n;i++ ) {
72                                ll ( i ) =Coms ( i )->logpred ( data.get_col ( t ) );
73                                wtmp =0.0; wtmp ( i ) =1.0;
74                                ll ( i ) += weights.logpred ( wtmp );
75                        }
76                       
77                        maxll = max(ll,maxi);
78                        switch (method) {
79                                case QB:
80                                        w = exp ( ll-maxll );
81                                        w/=sum(w);
82                                        break;
83                                case EM:
84                                        w = 0.0;
85                                        w(maxi) = 1.0;
86                                        break;
87                        }
88                       
89                        W.set_col ( t, w );
90                }
91
92                // copy initial statistics
93                //#pragma omp parallel for
94                for ( i=0;i<n;i++ ) {
95                        Coms ( i )-> set_statistics ( Coms0 ( i ) );
96                }
97                weights.set_statistics ( &weights0 );
98
99                // Update statistics
100                // !!!!    note  wData ==> this is extra weight of the data record
101                // !!!!    For typical cases wData=1.
102                for ( t=0;t<ndat;t++ ) {
103                        //#pragma omp parallel for
104                        for ( i=0;i<n;i++ ) {
105                                Coms ( i )-> bayes ( data.get_col ( t ),W ( i,t ) * wData ( t ) );
106                        }
107                        weights.bayes ( W.get_col ( t ) * wData ( t ) );
108                }
109
110                niter++;
111                //TODO better convergence rule.
112                converged = (niter>10);//( sumsum ( abs ( W-Wlast ) ) /n<0.1 );
113        }
114
115        //Clean Coms0
116        for ( i=0;i<n;i++ ) {delete Coms0 ( i );}
117}
118
119void MixEF::bayes ( const vec &data ) {
120
121};
122
123void MixEF::bayes ( const mat &data ) {
124        this->bayesB ( data, ones ( data.cols() ) );
125};
126
127
128double MixEF::logpred ( const vec &dt ) const {
129
130        vec w=weights.posterior().mean();
131        double exLL=0.0;
132        for ( int i=0;i<n;i++ ) {
133                exLL+=w ( i ) *exp ( Coms ( i )->logpred ( dt ) );
134        }
135        return log ( exLL );
136}
137
138emix* MixEF::epredictor (  ) const {
139        Array<epdf*> pC ( n );
140        for ( int i=0;i<n;i++ ) {pC ( i ) =Coms ( i )->epredictor ( );}
141        emix* tmp;
142        tmp = new emix (  );
143        tmp->set_parameters ( weights.posterior().mean(), pC, false );
144        tmp->ownComs();
145        return tmp;
146}
147
148void MixEF::flatten(const BMEF* M2){
149        const MixEF* Mix2=dynamic_cast<const MixEF*>(M2);
150        it_assert_debug(Mix2->n==n,"Different no of coms");
151        //Flatten each component
152        for (int i=0; i<n;i++){
153                Coms(i)->flatten(Mix2->Coms(i));
154        }
155        //Flatten weights = make them equal!!
156        weights.set_statistics(&(Mix2->weights));
157}
158}
Note: See TracBrowser for help on using the browser.