root/library/bdm/estim/mixtures.cpp @ 1013

Revision 1013, 4.4 kB (checked in by smidl, 14 years ago)

Flatten has an extra argument

  • Property svn:eol-style set to native
Line 
1#include <vector>
2#include "mixtures.h"
3
4namespace bdm {
5
6
7void MixEF::init ( BMEF* Com0, const mat &Data, const int c ) {
8        //prepare sizes
9        Coms.set_size ( c );
10        weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp.
11        multiBM weights0(weights);
12        //est will be done at the end
13        //
14        int i;
15        int ndat = Data.cols();
16        //Estimate  Com0 from all data
17        Coms ( 0 ) = (BMEF*) Com0->_copy();
18//      Coms(0)->set_evalll(false);
19        Coms ( 0 )->bayes_batch ( Data );
20        // Flatten it to its original shape
21        shared_ptr<BMEF> SharpCom((BMEF*)Coms(0)->_copy());
22        Coms ( 0 )->flatten ( Com0 ); 
23
24        //Copy it to the rest
25        for ( i = 1; i < Coms.length(); i++ ) {
26                //copy Com0 and create new rvs for them
27                Coms ( i ) =  (BMEF*) Coms ( 0 )->_copy ( );
28        }
29        //Pick some data for each component and update it
30        for ( i = 0; i < Coms.length(); i++ ) {
31                //pick one datum
32                if (ndat==Coms.length()){ //take the ith vector
33                        Coms ( i )->bayes ( Data.get_col ( i ), empty_vec );
34                } else { // pick at random
35                        int ind = (int) floor ( ndat * UniRNG.sample() );
36                        Coms ( i )->bayes_weighted ( Data.get_col ( ind ), empty_vec, ndat/Coms.length() );
37                }
38                //sharpen to the sharp component
39                //Coms ( i )->flatten ( SharpCom.get(), 1.0/Coms.length() );
40        }
41}
42
43double MixEF::bayes_batch_weighted ( const mat &data , const mat &cond, const vec &wData ) {
44        int ndat = data.cols();
45        int t, i, niter;
46        bool converged = false;
47
48        multiBM weights0 ( weights );
49
50        int n = Coms.length();
51        Array<BMEF*> Coms0 ( n );
52        for ( i = 0; i < n; i++ ) {
53                Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy();
54        }
55
56        niter = 0;
57        mat W = ones ( n, ndat ) / n;
58        mat Wlast = ones ( n, ndat ) / n;
59        vec w ( n );
60        vec ll ( n );
61        // tmp for weights
62        vec wtmp = zeros ( n );
63        int maxi;
64        double maxll;
65       
66        double levid=0.0;
67        //Estim
68        while ( !converged ) {
69                levid=0.0;
70                // Copy components back to their initial values
71                // All necessary information is now in w and Coms0.
72                Wlast = W;
73                //
74                //#pragma omp parallel for
75                for ( t = 0; t < ndat; t++ ) {
76                        //#pragma omp parallel for
77                        for ( i = 0; i < n; i++ ) {
78                                ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) , empty_vec);
79                                wtmp = 0.0;
80                                wtmp ( i ) = 1.0;
81                                ll ( i ) += weights.logpred ( wtmp );
82                        }
83
84                        maxll = max ( ll, maxi );
85                        switch ( method ) {
86                        case QB:
87                                w = exp ( ll - maxll );
88                                w /= sum ( w );
89                                break;
90                        case EM:
91                                w = 0.0;
92                                w ( maxi ) = 1.0;
93                                break;
94                        }
95
96                        W.set_col ( t, w );
97                }
98
99                // copy initial statistics
100                //#pragma omp parallel for
101                for ( i = 0; i < n; i++ ) {
102                        Coms ( i )-> set_statistics ( Coms0 ( i ) );
103                }
104                weights.set_statistics ( &weights0 );
105
106                // Update statistics
107                // !!!!    note  wData ==> this is extra weight of the data record
108                // !!!!    For typical cases wData=1.
109                vec logevid(n);
110                for ( t = 0; t < ndat; t++ ) {
111                        //#pragma omp parallel for
112                        for ( i = 0; i < n; i++ ) {
113                                Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) );
114                                logevid(i) = Coms(i)->_ll();
115                        }
116                        weights.bayes ( W.get_col ( t ) * wData ( t ) );
117                }
118                levid += weights._ll()+log(weights.posterior().mean() * exp(logevid)); // inner product w*exp(evid)
119               
120                niter++;
121                //TODO better convergence rule.
122                converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 );
123        }
124
125        //Clean Coms0
126        for ( i = 0; i < n; i++ ) {
127                delete Coms0 ( i );
128        }
129        return levid;
130}
131
132void MixEF::bayes ( const vec &data, const vec &cond = empty_vec ) {
133
134};
135
136double MixEF::logpred ( const vec &yt, const vec &cond =empty_vec) const {
137
138        vec w = weights.posterior().mean();
139        double exLL = 0.0;
140        for ( int i = 0; i < Coms.length(); i++ ) {
141                exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt , cond ) );
142        }
143        return log ( exLL );
144}
145
146emix* MixEF::epredictor ( const vec &vec) const {
147        Array<shared_ptr<epdf> > pC ( Coms.length() );
148        for ( int i = 0; i < Coms.length(); i++ ) {
149                pC ( i ) = Coms ( i )->epredictor ( );
150                pC (i) -> set_rv(_yrv());
151        }
152        emix* tmp;
153        tmp = new emix( );
154        tmp->_w() = weights.posterior().mean();
155        tmp->_Coms() = pC;
156        tmp->validate();
157        return tmp;
158}
159
160void MixEF::flatten ( const BMEF* M2, double weight=1.0 ) {
161        const MixEF* Mix2 = dynamic_cast<const MixEF*> ( M2 );
162        bdm_assert_debug ( Mix2->Coms.length() == Coms.length(), "Different no of coms" );
163        //Flatten each component
164        for ( int i = 0; i < Coms.length(); i++ ) {
165                Coms ( i )->flatten ( Mix2->Coms ( i ) , weight);
166        }
167        //Flatten weights = make them equal!!
168        weights.set_statistics ( & ( Mix2->weights ) );
169}
170}
Note: See TracBrowser for help on using the browser.