root/library/bdm/stat/merger.cpp @ 741

Revision 741, 10.0 kB (checked in by smidl, 14 years ago)

Stress tests are passing now. Missing validate calls are filled...

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm {
6
7merger_base::merger_base ( const Array<shared_ptr<pdf> > &S ) :
8                Npoints ( 0 ), DBG ( false ), dbg_file ( 0 ) {
9        set_sources ( S );
10}
11
12void merger_base::set_sources ( const Array<shared_ptr<pdf> > &Sources ) {
13        pdfs = Sources;
14        Nsources = pdfs.length();
15        //set sizes
16        dls.set_size ( Sources.length() );
17        rvzs.set_size ( Sources.length() );
18        zdls.set_size ( Sources.length() );
19
20        rv = get_composite_rv ( pdfs, /* checkoverlap = */ false );
21
22        RV rvc;
23        // Extend rv by rvc!
24        for ( int i = 0; i < pdfs.length(); i++ ) {
25                RV rvx = pdfs ( i )->_rvc().subt ( rv );
26                rvc.add ( rvx ); // add rv to common rvc
27        }
28
29        // join rv and rvc - see descriprion
30        rv.add ( rvc );
31        // get dimension
32        dim = rv._dsize();
33
34        // create links between sources and common rv
35        RV xytmp;
36        for ( int i = 0; i < pdfs.length(); i++ ) {
37                //Establich connection between pdfs and merger
38                dls ( i ) = new datalink_m2e;
39                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), rv );
40
41                // find out what is missing in each pdf
42                xytmp = pdfs ( i )->_rv();
43                xytmp.add ( pdfs ( i )->_rvc() );
44                // z_i = common_rv-xy
45                rvzs ( i ) = rv.subt ( xytmp );
46                //establish connection between extension (z_i|x,y)s and common rv
47                zdls ( i ) = new datalink_m2e;
48                zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
49        };
50}
51
52void merger_base::set_support ( rectangular_support &Sup ) {
53        Npoints = Sup.points();
54        eSmp.set_parameters ( Npoints, false );
55        Array<vec> &samples = eSmp._samples();
56        eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
57        //set samples
58        samples ( 0 ) = Sup.first_vec();
59        for ( int j = 1; j < Npoints; j++ ) {
60                samples ( j ) = Sup.next_vec();
61        }
62}
63
64void merger_base::merge () {
65        validate();
66
67        //check if sources overlap:
68        bool OK = true;
69        for ( int i = 0; i < pdfs.length(); i++ ) {
70                OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
71                OK &= ( pdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
72        }
73
74        if ( OK ) {
75                mat lW = zeros ( pdfs.length(), eSmp._w().length() );
76
77                vec emptyvec ( 0 );
78                for ( int i = 0; i < pdfs.length(); i++ ) {
79                        for ( int j = 0; j < eSmp._w().length(); j++ ) {
80                                lW ( i, j ) = pdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
81                        }
82                }
83
84                vec w_nn = merge_points ( lW );
85                vec wtmp = exp ( w_nn - max ( w_nn ) );
86                //renormalize
87                eSmp._w() = wtmp / sum ( wtmp );
88        } else {
89                bdm_error ( "Sources are not compatible - use merger_mix" );
90        }
91}
92
93vec merger_base::merge_points ( mat &lW ) {
94        int nu = lW.rows();
95        vec result;
96        ivec indW;
97        bool infexist = false;
98        switch ( METHOD ) {
99        case ARITHMETIC:
100                result = log ( sum ( exp ( lW ) ) ); //ugly!
101                break;
102        case GEOMETRIC:
103                result = sum ( lW ) / nu;
104                break;
105        case LOGNORMAL:
106                vec sumlW = sum ( lW ) ;
107                indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
108                infexist = ( indW.size() < lW.cols() );
109                vec mu;
110                vec lam;
111                if ( infexist ) {
112                        mu = sumlW ( indW ) / nu; //mean of logs
113                        //
114                        mat validlW = lW.get_cols ( indW );
115                        lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
116                } else {
117                        mu = sum ( lW ) / nu; //mean of logs
118                        lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
119                }
120                //
121                double coef = 0.0;
122                vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
123                switch ( nu ) {
124                case 2:
125                        coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
126                        result = coef * sq2bl + mu ;
127                        break;
128                        // case 4: == can be done similar to case 2 - is it worth it???
129                default: // see accompanying document merge_lognorm_derivation.lyx
130                        coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
131                        result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
132                        break;
133                }
134                break;
135        }
136        if ( infexist ) {
137                vec tmp = -inf * ones ( lW.cols() );
138                set_subvector ( tmp, indW, result );
139                return tmp;
140        } else {
141                return result;
142        }
143}
144
145vec merger_base::mean() const {
146        const Vec<double> &w = eSmp._w();
147        const Array<vec> &S = eSmp._samples();
148        vec tmp = zeros ( dim );
149        for ( int i = 0; i < Npoints; i++ ) {
150                tmp += w ( i ) * S ( i );
151        }
152        return tmp;
153}
154
155mat merger_base::covariance() const {
156        const vec &w = eSmp._w();
157        const Array<vec> &S = eSmp._samples();
158
159        vec mea = mean();
160
161//                      cout << sum (w) << "," << w*w << endl;
162
163        mat Tmp = zeros ( dim, dim );
164        for ( int i = 0; i < Npoints; i++ ) {
165                Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
166        }
167        return Tmp - outer_product ( mea, mea );
168}
169
170vec merger_base::variance() const {
171        const vec &w = eSmp._w();
172        const Array<vec> &S = eSmp._samples();
173
174        vec tmp = zeros ( dim );
175        for ( int i = 0; i < Nsources; i++ ) {
176                tmp += w ( i ) * pow ( S ( i ), 2 );
177        }
178        return tmp - pow ( mean(), 2 );
179}
180
181void merger_mix::merge ( ) {
182        Array<vec> &Smp = eSmp._samples(); //aux
183        vec &w = eSmp._w(); //aux
184
185        mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones
186        for ( int i = 0; i < Npoints; i++ ) {
187                set_col_part ( Smp_ex, i, Smp ( i ) );
188        }
189
190        if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
191
192        // Stuff for merging
193        vec lw_src ( Npoints );         // weights of the ith source
194        vec lw_mix ( Npoints );         // weights of the approximating mixture
195        vec lw ( Npoints );                     // tmp
196        mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources
197        vec vec0 ( 0 );
198
199        //initialize importance weights
200        lw_mix = 1.0; // assuming uniform grid density -- otherwise
201
202        // Initial component in the mixture model
203        mat V0 = 1e-8 * eye ( dim + 1 );
204        ARX A0;
205        A0.set_statistics ( dim, V0 ); //initial guess of Mix:
206        A0.validate();
207
208        Mix.init ( &A0, Smp_ex, Ncoms );
209        //Preserve initial mixture for repetitive estimation via flattening
210        MixEF Mix_init ( Mix );
211
212        // ============= MAIN LOOP ==================
213        bool converged = false;
214        int niter = 0;
215        char dbg_str[100];
216
217        emix* Mpred = Mix.epredictor ( );
218        vec Mix_pdf ( Npoints );
219        while ( !converged ) {
220                //Re-estimate Mix
221                //Re-Initialize Mixture model
222                Mix.flatten ( &Mix_init );
223                Mix.bayes_batch ( Smp_ex, empty_vec, w*Npoints );
224                delete Mpred;
225                Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
226                Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
227
228                // This will be active only later in iterations!!!
229                if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
230                        // Generate new samples
231                        eSmp.set_samples ( Mpred );
232                        for ( int i = 0; i < Npoints; i++ ) {
233                                //////////// !!!!!!!!!!!!!
234                                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
235                                set_col_part ( Smp_ex, i, Smp ( i ) );
236                                //Importance of the mixture
237                                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
238                                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
239                        }
240                        if ( DBG ) {
241                                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
242                                cout << Mix.posterior().mean() << endl;
243                                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
244                                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
245                        }
246                }
247                if ( DBG ) {
248                        sprintf ( dbg_str, "Mpred_mean%d", niter );
249                        *dbg_file << Name ( dbg_str ) << Mpred->mean();
250                        sprintf ( dbg_str, "Mpred_var%d", niter );
251                        *dbg_file << Name ( dbg_str ) << Mpred->variance();
252
253
254                        sprintf ( dbg_str, "pdf%d", niter );
255                        for ( int i = 0; i < Npoints; i++ ) {
256                                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );
257                        }
258                        *dbg_file << Name ( dbg_str ) << Mix_pdf;
259
260                        sprintf ( dbg_str, "Smp%d", niter );
261                        *dbg_file << Name ( dbg_str ) << Smp_ex;
262
263                }
264                //Importace weighting
265                for ( int i = 0; i < pdfs.length(); i++ ) {
266                        lw_src = 0.0;
267                        //======== Same RVs ===========
268                        //Split according to dependency in rvs
269                        if ( pdfs ( i )->dimension() == dim ) {
270                                // no need for conditioning or marginalization
271                                lw_src = pdfs ( i )->evallogcond_mat ( Smp , vec ( 0 ) );
272                        } else {
273                                // compute likelihood of marginal on the conditional variable
274                                if ( pdfs ( i )->dimensionc() > 0 ) {
275                                        // Make marginal on rvc_i
276                                        shared_ptr<epdf> tmp_marg = Mpred->marginal ( pdfs ( i )->_rvc() );
277                                        //compute vector of lw_src
278                                        for ( int k = 0; k < Npoints; k++ ) {
279                                                // Here val of tmp_marg = cond of pdfs(i) ==> calling dls->get_cond
280                                                lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
281                                        }
282
283//                                      sprintf ( str,"marg%d",niter );
284//                                      *dbg << Name ( str ) << lw_src;
285
286                                }
287                                // Compute likelihood of the missing variable
288                                if ( dim > ( pdfs ( i )->dimension() + pdfs ( i )->dimensionc() ) ) {
289                                        ///////////////
290                                        // There are variales unknown to pdfs(i) : rvzs
291                                        shared_ptr<pdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
292                                        // Compute likelihood
293                                        vec lw_dbg = lw_src;
294                                        for ( int k = 0; k < Npoints; k++ ) {
295                                                lw_src ( k ) += log (
296                                                                    tmp_cond->evallogcond (
297                                                                        zdls ( i )->pushdown ( Smp ( k ) ),
298                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
299                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
300                                                        lw_src ( k ) = -1e16;
301                                                        cout << "!";
302                                                }
303                                        }
304                                }
305                                // Compute likelihood of the partial source
306                                for ( int k = 0; k < Npoints; k++ ) {
307                                        lw_src ( k ) += pdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ),
308                                                        dls ( i )->get_cond ( Smp ( k ) ) );
309                                }
310
311                        }
312
313                        lW.set_row ( i, lw_src ); // do not divide by mix
314                }
315                lw = merger_base::merge_points ( lW ); //merge
316
317                //Importance weighting
318                lw -=  lw_mix; // hoping that it is not numerically sensitive...
319                w = exp ( lw - max ( lw ) );
320
321                //renormalize
322                double sumw = sum ( w );
323                if ( std::isfinite ( sumw ) ) {
324                        w = w / sumw;
325                } else {
326                        it_file itf ( "merg_err.it" );
327                        itf << Name ( "w" ) << w;
328                }
329
330                if ( DBG ) {
331                        sprintf ( dbg_str, "lW%d", niter );
332                        *dbg_file << Name ( dbg_str ) << lW;
333                        sprintf ( dbg_str, "w%d", niter );
334                        *dbg_file << Name ( dbg_str ) << w;
335                        sprintf ( dbg_str, "lw_m%d", niter );
336                        *dbg_file << Name ( dbg_str ) << lw_mix;
337                }
338                // ==== stopping rule ===
339                niter++;
340                converged = ( niter > stop_niter );
341        }
342        delete Mpred;
343//              cout << endl;
344
345}
346
347// DEFAULTS FOR MERGER_BASE
348const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL;
349const double merger_base::DFLT_beta = 1.2;
350// DEFAULTS FOR MERGER_MIX
351const int merger_mix::DFLT_Ncoms = 10;
352const double merger_mix::DFLT_effss_coef = 0.5;
353
354}
Note: See TracBrowser for help on using the browser.