root/library/bdm/estim/mixtures.cpp @ 1009

Revision 1009, 4.3 kB (checked in by smidl, 14 years ago)

changes in bayes_batch

  • Property svn:eol-style set to native
Line 
1#include <vector>
2#include "mixtures.h"
3
4namespace bdm {
5
6
7void MixEF::init ( BMEF* Com0, const mat &Data, const int c ) {
8        //prepare sizes
9        Coms.set_size ( c );
10        weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp.
11        //est will be done at the end
12        //
13        int i;
14        int ndat = Data.cols();
15        //Estimate  Com0 from all data
16        Coms ( 0 ) = (BMEF*) Com0->_copy();
17//      Coms(0)->set_evalll(false);
18        Coms ( 0 )->bayes_batch ( Data );
19        // Flatten it to its original shape
20        Coms ( 0 )->flatten ( Com0 );
21
22        //Copy it to the rest
23        for ( i = 1; i < Coms.length(); i++ ) {
24                //copy Com0 and create new rvs for them
25                Coms ( i ) =  (BMEF*) Coms ( 0 )->_copy ( );
26        }
27        //Pick some data for each component and update it
28        for ( i = 0; i < Coms.length(); i++ ) {
29                //pick one datum
30                if (ndat==Coms.length()){ //take the ith vector
31                        Coms ( i )->bayes ( Data.get_col ( i ), empty_vec );
32                } else { // pick at random
33                        int ind = (int) floor ( ndat * UniRNG.sample() );
34                        Coms ( i )->bayes ( Data.get_col ( ind ), empty_vec );
35                }
36                //flatten back to oringinal
37                Coms ( i )->flatten ( Com0 );
38        }
39
40}
41
42double MixEF::bayes_batch ( const mat &data , const mat &cond, const vec &wData ) {
43        int ndat = data.cols();
44        int t, i, niter;
45        bool converged = false;
46
47        multiBM weights0 ( weights );
48
49        int n = Coms.length();
50        Array<BMEF*> Coms0 ( n );
51        for ( i = 0; i < n; i++ ) {
52                Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy();
53        }
54
55        niter = 0;
56        mat W = ones ( n, ndat ) / n;
57        mat Wlast = ones ( n, ndat ) / n;
58        vec w ( n );
59        vec ll ( n );
60        // tmp for weights
61        vec wtmp = zeros ( n );
62        int maxi;
63        double maxll;
64       
65        double levid=0.0;
66        //Estim
67        while ( !converged ) {
68                levid=0.0;
69                // Copy components back to their initial values
70                // All necessary information is now in w and Coms0.
71                Wlast = W;
72                //
73                //#pragma omp parallel for
74                for ( t = 0; t < ndat; t++ ) {
75                        //#pragma omp parallel for
76                        for ( i = 0; i < n; i++ ) {
77                                ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) , empty_vec);
78                                wtmp = 0.0;
79                                wtmp ( i ) = 1.0;
80                                ll ( i ) += weights.logpred ( wtmp );
81                        }
82
83                        maxll = max ( ll, maxi );
84                        switch ( method ) {
85                        case QB:
86                                w = exp ( ll - maxll );
87                                w /= sum ( w );
88                                break;
89                        case EM:
90                                w = 0.0;
91                                w ( maxi ) = 1.0;
92                                break;
93                        }
94
95                        W.set_col ( t, w );
96                }
97
98                // copy initial statistics
99                //#pragma omp parallel for
100                for ( i = 0; i < n; i++ ) {
101                        Coms ( i )-> set_statistics ( Coms0 ( i ) );
102                }
103                weights.set_statistics ( &weights0 );
104
105                // Update statistics
106                // !!!!    note  wData ==> this is extra weight of the data record
107                // !!!!    For typical cases wData=1.
108                vec logevid(n);
109                for ( t = 0; t < ndat; t++ ) {
110                        //#pragma omp parallel for
111                        for ( i = 0; i < n; i++ ) {
112                                Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) );
113                                logevid(i) = Coms(i)->_ll();
114                        }
115                        weights.bayes ( W.get_col ( t ) * wData ( t ) );
116                }
117                levid += weights._ll()+log(weights.posterior().mean() * exp(logevid)); // inner product w*exp(evid)
118               
119                niter++;
120                //TODO better convergence rule.
121                converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 );
122        }
123
124        //Clean Coms0
125        for ( i = 0; i < n; i++ ) {
126                delete Coms0 ( i );
127        }
128        return levid;
129}
130
131void MixEF::bayes ( const vec &data, const vec &cond = empty_vec ) {
132
133};
134
135void MixEF::bayes ( const mat &data, const vec &cond = empty_vec ) {
136        this->bayes_batch ( data, cond, ones ( data.cols() ) );
137};
138
139
140double MixEF::logpred ( const vec &yt, const vec &cond =empty_vec) const {
141
142        vec w = weights.posterior().mean();
143        double exLL = 0.0;
144        for ( int i = 0; i < Coms.length(); i++ ) {
145                exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt , cond ) );
146        }
147        return log ( exLL );
148}
149
150emix* MixEF::epredictor ( const vec &vec) const {
151        Array<shared_ptr<epdf> > pC ( Coms.length() );
152        for ( int i = 0; i < Coms.length(); i++ ) {
153                pC ( i ) = Coms ( i )->epredictor ( );
154                pC (i) -> set_rv(_yrv());
155        }
156        emix* tmp;
157        tmp = new emix( );
158        tmp->_w() = weights.posterior().mean();
159        tmp->_Coms() = pC;
160        tmp->validate();
161        return tmp;
162}
163
164void MixEF::flatten ( const BMEF* M2 ) {
165        const MixEF* Mix2 = dynamic_cast<const MixEF*> ( M2 );
166        bdm_assert_debug ( Mix2->Coms.length() == Coms.length(), "Different no of coms" );
167        //Flatten each component
168        for ( int i = 0; i < Coms.length(); i++ ) {
169                Coms ( i )->flatten ( Mix2->Coms ( i ) );
170        }
171        //Flatten weights = make them equal!!
172        weights.set_statistics ( & ( Mix2->weights ) );
173}
174}
Note: See TracBrowser for help on using the browser.