root/library/bdm/estim/mixtures.cpp @ 943

Revision 943, 3.8 kB (checked in by smidl, 14 years ago)

syntax of epredictor

  • Property svn:eol-style set to native
Line 
1#include <vector>
2#include "mixtures.h"
3
4namespace bdm {
5
6
7void MixEF::init ( BMEF* Com0, const mat &Data, const int c ) {
8        //prepare sizes
9        Coms.set_size ( c );
10        n = c;
11        weights.set_parameters ( ones ( c ) ); //assume at least one observation in each comp.
12        //est will be done at the end
13        //
14        int i;
15        int ndat = Data.cols();
16        //Estimate  Com0 from all data
17        Coms ( 0 ) = (BMEF*) Com0->_copy();
18//      Coms(0)->set_evalll(false);
19        Coms ( 0 )->bayes_batch ( Data );
20        // Flatten it to its original shape
21        Coms ( 0 )->flatten ( Com0 );
22
23        //Copy it to the rest
24        for ( i = 1; i < n; i++ ) {
25                //copy Com0 and create new rvs for them
26                Coms ( i ) =  (BMEF*) Coms ( 0 )->_copy ( );
27        }
28        //Pick some data for each component and update it
29        for ( i = 0; i < n; i++ ) {
30                //pick one datum
31                int ind = (int) floor ( ndat * UniRNG.sample() );
32                Coms ( i )->bayes ( Data.get_col ( ind ), empty_vec );
33                //flatten back to oringinal
34                Coms ( i )->flatten ( Com0 );
35        }
36
37}
38
39void MixEF::bayes_batch ( const mat &data , const mat &cond, const vec &wData ) {
40        int ndat = data.cols();
41        int t, i, niter;
42        bool converged = false;
43
44        multiBM weights0 ( weights );
45
46        Array<BMEF*> Coms0 ( n );
47        for ( i = 0; i < n; i++ ) {
48                Coms0 ( i ) = ( BMEF* ) Coms ( i )->_copy();
49        }
50
51        niter = 0;
52        mat W = ones ( n, ndat ) / n;
53        mat Wlast = ones ( n, ndat ) / n;
54        vec w ( n );
55        vec ll ( n );
56        // tmp for weights
57        vec wtmp = zeros ( n );
58        int maxi;
59        double maxll;
60        //Estim
61        while ( !converged ) {
62                // Copy components back to their initial values
63                // All necessary information is now in w and Coms0.
64                Wlast = W;
65                //
66                //#pragma omp parallel for
67                for ( t = 0; t < ndat; t++ ) {
68                        //#pragma omp parallel for
69                        for ( i = 0; i < n; i++ ) {
70                                ll ( i ) = Coms ( i )->logpred ( data.get_col ( t ) );
71                                wtmp = 0.0;
72                                wtmp ( i ) = 1.0;
73                                ll ( i ) += weights.logpred ( wtmp );
74                        }
75
76                        maxll = max ( ll, maxi );
77                        switch ( method ) {
78                        case QB:
79                                w = exp ( ll - maxll );
80                                w /= sum ( w );
81                                break;
82                        case EM:
83                                w = 0.0;
84                                w ( maxi ) = 1.0;
85                                break;
86                        }
87
88                        W.set_col ( t, w );
89                }
90
91                // copy initial statistics
92                //#pragma omp parallel for
93                for ( i = 0; i < n; i++ ) {
94                        Coms ( i )-> set_statistics ( Coms0 ( i ) );
95                }
96                weights.set_statistics ( &weights0 );
97
98                // Update statistics
99                // !!!!    note  wData ==> this is extra weight of the data record
100                // !!!!    For typical cases wData=1.
101                for ( t = 0; t < ndat; t++ ) {
102                        //#pragma omp parallel for
103                        for ( i = 0; i < n; i++ ) {
104                                Coms ( i )-> bayes_weighted ( data.get_col ( t ), empty_vec, W ( i, t ) * wData ( t ) );
105                        }
106                        weights.bayes ( W.get_col ( t ) * wData ( t ) );
107                }
108
109                niter++;
110                //TODO better convergence rule.
111                converged = ( niter > 10 );//( sumsum ( abs ( W-Wlast ) ) /n<0.1 );
112        }
113
114        //Clean Coms0
115        for ( i = 0; i < n; i++ ) {
116                delete Coms0 ( i );
117        }
118}
119
120void MixEF::bayes ( const vec &data, const vec &cond = empty_vec ) {
121
122};
123
124void MixEF::bayes ( const mat &data, const vec &cond = empty_vec ) {
125        this->bayes_batch ( data, cond, ones ( data.cols() ) );
126};
127
128
129double MixEF::logpred ( const vec &yt ) const {
130
131        vec w = weights.posterior().mean();
132        double exLL = 0.0;
133        for ( int i = 0; i < n; i++ ) {
134                exLL += w ( i ) * exp ( Coms ( i )->logpred ( yt ) );
135        }
136        return log ( exLL );
137}
138
139emix* MixEF::epredictor ( const vec &vec) const {
140        Array<shared_ptr<epdf> > pC ( n );
141        for ( int i = 0; i < n; i++ ) {
142                pC ( i ) = Coms ( i )->epredictor ( );
143                pC (i) -> set_rv(_yrv());
144        }
145        emix* tmp;
146        tmp = new emix( );
147        tmp->_w() = weights.posterior().mean();
148        tmp->_Coms() = pC;
149        tmp->validate();
150        return tmp;
151}
152
153void MixEF::flatten ( const BMEF* M2 ) {
154        const MixEF* Mix2 = dynamic_cast<const MixEF*> ( M2 );
155        bdm_assert_debug ( Mix2->n == n, "Different no of coms" );
156        //Flatten each component
157        for ( int i = 0; i < n; i++ ) {
158                Coms ( i )->flatten ( Mix2->Coms ( i ) );
159        }
160        //Flatten weights = make them equal!!
161        weights.set_statistics ( & ( Mix2->weights ) );
162}
163}
Note: See TracBrowser for help on using the browser.