root/library/bdm/stat/merger.cpp @ 787

Revision 787, 10.1 kB (checked in by smidl, 14 years ago)

arena experiment + numerical fixes

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm {
6
7merger_base::merger_base ( const Array<shared_ptr<pdf> > &S ) :
8                Npoints ( 0 ), DBG ( false ), dbg_file ( 0 ) {
9        set_sources ( S );
10}
11
12void merger_base::set_sources ( const Array<shared_ptr<pdf> > &Sources ) {
13        pdfs = Sources;
14        Nsources = pdfs.length();
15        //set sizes
16        dls.set_size ( Sources.length() );
17        rvzs.set_size ( Sources.length() );
18        zdls.set_size ( Sources.length() );
19
20        rv = get_composite_rv ( pdfs, /* checkoverlap = */ false );
21
22        RV rvc;
23        // Extend rv by rvc!
24        for ( int i = 0; i < pdfs.length(); i++ ) {
25                RV rvx = pdfs ( i )->_rvc().subt ( rv );
26                rvc.add ( rvx ); // add rv to common rvc
27        }
28
29        // join rv and rvc - see descriprion
30        rv.add ( rvc );
31        // get dimension
32        dim = rv._dsize();
33
34        // create links between sources and common rv
35        RV xytmp;
36        for ( int i = 0; i < pdfs.length(); i++ ) {
37                //Establich connection between pdfs and merger
38                dls ( i ) = new datalink_m2e;
39                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), rv );
40
41                // find out what is missing in each pdf
42                xytmp = pdfs ( i )->_rv();
43                xytmp.add ( pdfs ( i )->_rvc() );
44                // z_i = common_rv-xy
45                rvzs ( i ) = rv.subt ( xytmp );
46                //establish connection between extension (z_i|x,y)s and common rv
47                zdls ( i ) = new datalink_m2e;
48                zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
49        };
50}
51
52void merger_base::set_support ( rectangular_support &Sup ) {
53        Npoints = Sup.points();
54        eSmp.set_parameters ( Npoints, false );
55        Array<vec> &samples = eSmp._samples();
56        eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
57        //set samples
58        samples ( 0 ) = Sup.first_vec();
59        for ( int j = 1; j < Npoints; j++ ) {
60                samples ( j ) = Sup.next_vec();
61        }
62}
63
64void merger_base::merge () {
65        validate();
66
67        //check if sources overlap:
68        bool OK = true;
69        for ( int i = 0; i < pdfs.length(); i++ ) {
70                OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
71                OK &= ( pdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
72        }
73
74        if ( OK ) {
75                mat lW = zeros ( pdfs.length(), eSmp._w().length() );
76
77                vec emptyvec ( 0 );
78                for ( int i = 0; i < pdfs.length(); i++ ) {
79                        for ( int j = 0; j < eSmp._w().length(); j++ ) {
80                                lW ( i, j ) = pdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
81                        }
82                }
83
84                vec w_nn = merge_points ( lW );
85                vec wtmp = exp ( w_nn - max ( w_nn ) );
86                //renormalize
87                eSmp._w() = wtmp / sum ( wtmp );
88        } else {
89                bdm_error ( "Sources are not compatible - use merger_mix" );
90        }
91}
92
93vec merger_base::merge_points ( mat &lW ) {
94        int nu = lW.rows();
95        vec result;
96        ivec indW;
97        bool infexist = false;
98        switch ( METHOD ) {
99        case ARITHMETIC:
100                result = log ( sum ( exp ( lW ) ) ); //ugly!
101                break;
102        case GEOMETRIC:
103                result = sum ( lW ) / nu;
104                break;
105        case LOGNORMAL:
106                vec sumlW = sum ( lW ) ;
107                indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
108                infexist = ( indW.size() < lW.cols() );
109                vec mu;
110                vec lam;
111                if ( infexist ) {
112                        mu = sumlW ( indW ) / nu; //mean of logs
113                        //
114                        mat validlW = lW.get_cols ( indW );
115                        lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
116                } else {
117                        mu = sum ( lW ) / nu; //mean of logs
118                        lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
119                }
120                //
121                double coef = 0.0;
122                vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
123                switch ( nu ) {
124                case 2:
125                        coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
126                        result = coef * sq2bl + mu ;
127                        break;
128                        // case 4: == can be done similar to case 2 - is it worth it???
129                default: // see accompanying document merge_lognorm_derivation.lyx
130                        coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
131                        result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
132                        break;
133                }
134                break;
135        }
136        if ( infexist ) {
137                vec tmp = -inf * ones ( lW.cols() );
138                set_subvector ( tmp, indW, result );
139                return tmp;
140        } else {
141                return result;
142        }
143}
144
145vec merger_base::mean() const {
146        const Vec<double> &w = eSmp._w();
147        const Array<vec> &S = eSmp._samples();
148        vec tmp = zeros ( dim );
149        for ( int i = 0; i < Npoints; i++ ) {
150                tmp += w ( i ) * S ( i );
151        }
152        return tmp;
153}
154
155mat merger_base::covariance() const {
156        const vec &w = eSmp._w();
157        const Array<vec> &S = eSmp._samples();
158
159        vec mea = mean();
160
161//                      cout << sum (w) << "," << w*w << endl;
162
163        mat Tmp = zeros ( dim, dim );
164        for ( int i = 0; i < Npoints; i++ ) {
165                vec tmp=S ( i )-mea; //inefficient but numerically stable
166                Tmp += w ( i ) * outer_product (tmp , tmp );
167        }
168        return Tmp;
169}
170
171vec merger_base::variance() const {
172        return eSmp.variance();
173}
174
175void merger_mix::merge ( ) {
176        if(Npoints<1){
177                set_support(enorm<fsqmat>(zeros(dim), eye(dim)), 1000);
178        }
179               
180        bdm_assert(Npoints>0,"No points in support");
181        bdm_assert(Nsources>0,"No Sources");
182       
183        Array<vec> &Smp = eSmp._samples(); //aux
184        vec &w = eSmp._w(); //aux
185
186        mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones
187        for ( int i = 0; i < Npoints; i++ ) {
188                set_col_part ( Smp_ex, i, Smp ( i ) );
189        }
190
191        if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
192
193        // Stuff for merging
194        vec lw_src ( Npoints );         // weights of the ith source
195        vec lw_mix ( Npoints );         // weights of the approximating mixture
196        vec lw ( Npoints );                     // tmp
197        mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources
198        vec vec0 ( 0 );
199
200        //initialize importance weights
201        lw_mix = 1.0; // assuming uniform grid density -- otherwise
202
203        // Initial component in the mixture model
204        mat V0 = 1e-8 * eye ( dim + 1 );
205        ARX A0;
206        A0.set_statistics ( dim, V0 ); //initial guess of Mix:
207        A0.validate();
208
209        Mix.init ( &A0, Smp_ex, Ncoms );
210        //Preserve initial mixture for repetitive estimation via flattening
211        MixEF Mix_init ( Mix );
212
213        // ============= MAIN LOOP ==================
214        bool converged = false;
215        int niter = 0;
216        char dbg_str[100];
217
218        emix* Mpred = Mix.epredictor ( );
219        vec Mix_pdf ( Npoints );
220        while ( !converged ) {
221                //Re-estimate Mix
222                //Re-Initialize Mixture model
223                Mix.flatten ( &Mix_init );
224                Mix.bayes_batch ( Smp_ex, empty_vec, w*Npoints );
225                delete Mpred;
226                Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
227                Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
228
229                // This will be active only later in iterations!!!
230                if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
231                        // Generate new samples
232                        eSmp.set_samples ( Mpred );
233                        for ( int i = 0; i < Npoints; i++ ) {
234                                //////////// !!!!!!!!!!!!!
235                                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
236                                set_col_part ( Smp_ex, i, Smp ( i ) );
237                                //Importance of the mixture
238                                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
239                                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
240                        }
241                        if ( DBG ) {
242                                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
243                                cout << Mix.posterior().mean() << endl;
244                                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
245                                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
246                        }
247                }
248                if ( DBG ) {
249                        sprintf ( dbg_str, "Mpred_mean%d", niter );
250                        *dbg_file << Name ( dbg_str ) << Mpred->mean();
251                        sprintf ( dbg_str, "Mpred_var%d", niter );
252                        *dbg_file << Name ( dbg_str ) << Mpred->variance();
253                        sprintf ( dbg_str, "Mpred_cov%d", niter );
254                        *dbg_file << Name ( dbg_str ) << covariance();
255                       
256
257                        sprintf ( dbg_str, "pdf%d", niter );
258                        for ( int i = 0; i < Npoints; i++ ) {
259                                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );
260                        }
261                        *dbg_file << Name ( dbg_str ) << Mix_pdf;
262
263                        sprintf ( dbg_str, "Smp%d", niter );
264                        *dbg_file << Name ( dbg_str ) << Smp_ex;
265
266                }
267                //Importace weighting
268                for ( int i = 0; i < pdfs.length(); i++ ) {
269                        lw_src = 0.0;
270                        //======== Same RVs ===========
271                        //Split according to dependency in rvs
272                        if ( pdfs ( i )->dimension() == dim ) {
273                                // no need for conditioning or marginalization
274                                lw_src = pdfs ( i )->evallogcond_mat ( Smp , vec ( 0 ) );
275                        } else {
276                                // compute likelihood of marginal on the conditional variable
277                                if ( pdfs ( i )->dimensionc() > 0 ) {
278                                        // Make marginal on rvc_i
279                                        shared_ptr<epdf> tmp_marg = Mpred->marginal ( pdfs ( i )->_rvc() );
280                                        //compute vector of lw_src
281                                        for ( int k = 0; k < Npoints; k++ ) {
282                                                // Here val of tmp_marg = cond of pdfs(i) ==> calling dls->get_cond
283                                                lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
284                                        }
285
286//                                      sprintf ( str,"marg%d",niter );
287//                                      *dbg << Name ( str ) << lw_src;
288
289                                }
290                                // Compute likelihood of the missing variable
291                                if ( dim > ( pdfs ( i )->dimension() + pdfs ( i )->dimensionc() ) ) {
292                                        ///////////////
293                                        // There are variales unknown to pdfs(i) : rvzs
294                                        shared_ptr<pdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
295                                        // Compute likelihood
296                                        vec lw_dbg = lw_src;
297                                        for ( int k = 0; k < Npoints; k++ ) {
298                                                lw_src ( k ) += log (
299                                                                    tmp_cond->evallogcond (
300                                                                        zdls ( i )->pushdown ( Smp ( k ) ),
301                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
302                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
303                                                        lw_src ( k ) = -1e16;
304                                                        cout << "!";
305                                                }
306                                        }
307                                }
308                                // Compute likelihood of the partial source
309                                for ( int k = 0; k < Npoints; k++ ) {
310                                        lw_src ( k ) += pdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ),
311                                                        dls ( i )->get_cond ( Smp ( k ) ) );
312                                }
313
314                        }
315
316                        lW.set_row ( i, lw_src ); // do not divide by mix
317                }
318                lw = merger_base::merge_points ( lW ); //merge
319
320                //Importance weighting
321                lw -=  lw_mix; // hoping that it is not numerically sensitive...
322                w = exp ( lw - max ( lw ) );
323
324                //renormalize
325                double sumw = sum ( w );
326                if ( std::isfinite ( sumw ) ) {
327                        w = w / sumw;
328                } else {
329                        it_file itf ( "merg_err.it" );
330                        itf << Name ( "w" ) << w;
331                }
332
333                if ( DBG ) {
334                        sprintf ( dbg_str, "lW%d", niter );
335                        *dbg_file << Name ( dbg_str ) << lW;
336                        sprintf ( dbg_str, "w%d", niter );
337                        *dbg_file << Name ( dbg_str ) << w;
338                        sprintf ( dbg_str, "lw_m%d", niter );
339                        *dbg_file << Name ( dbg_str ) << lw_mix;
340                }
341                // ==== stopping rule ===
342                niter++;
343                converged = ( niter > stop_niter );
344        }
345        delete Mpred;
346//              cout << endl;
347
348}
349
350// DEFAULTS FOR MERGER_BASE
351const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL;
352const double merger_base::DFLT_beta = 1.2;
353// DEFAULTS FOR MERGER_MIX
354const int merger_mix::DFLT_Ncoms = 10;
355const double merger_mix::DFLT_effss_coef = 0.9;
356
357}
Note: See TracBrowser for help on using the browser.