root/library/bdm/stat/merger.cpp @ 1077

Revision 1068, 13.8 kB (checked in by mido, 14 years ago)

patch of documentation - all conditional pdfs revised

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm {
6
7merger_base::merger_base ( const Array<shared_ptr<pdf> > &S ) :
8    Npoints ( 0 ), DBG ( false ), dbg_file ( 0 ) {
9    set_sources ( S );
10}
11
12
13void merger_base::set_sources ( const Array<shared_ptr<pdf> > &Sources ) {
14    pdfs = Sources;
15    Nsources = pdfs.length();
16    //set sizes
17    dls.set_size ( Sources.length() );
18    rvzs.set_size ( Sources.length() );
19    zdls.set_size ( Sources.length() );
20
21    rv = get_composite_rv ( pdfs, /* checkoverlap = */ false );
22
23    RV rvc;
24    // Extend rv by rvc!
25    for ( int i = 0; i < pdfs.length(); i++ ) {
26        RV rvx = pdfs ( i )->_rvc().subt ( rv );
27        rvc.add ( rvx ); // add rv to common rvc
28    }
29
30    // join rv and rvc - see descriprion
31    rv.add ( rvc );
32    // get dimension
33    dim = rv._dsize();
34
35    // create links between sources and common rv
36    RV xytmp;
37    for ( int i = 0; i < pdfs.length(); i++ ) {
38        //Establich connection between pdfs and merger
39        dls ( i ) = new datalink_m2e;
40        dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), rv );
41
42        // find out what is missing in each pdf
43        xytmp = pdfs ( i )->_rv();
44        xytmp.add ( pdfs ( i )->_rvc() );
45        // z_i = common_rv-xy
46        rvzs ( i ) = rv.subt ( xytmp );
47        //establish connection between extension (z_i|x,y)s and common rv
48        zdls ( i ) = new datalink_m2e;
49        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
50    };
51}
52
53void merger_base::set_support ( rectangular_support &Sup ) {
54    Npoints = Sup.points();
55    eSmp.set_parameters ( Npoints, false );
56    Array<vec> &samples = eSmp._samples();
57    eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
58    //set samples
59    samples ( 0 ) = Sup.first_vec();
60    for ( int j = 1; j < Npoints; j++ ) {
61        samples ( j ) = Sup.next_vec();
62    }
63}
64
65void merger_base::merge () {
66    validate();
67
68    //check if sources overlap:
69    bool OK = true;
70    for ( int i = 0; i < pdfs.length(); i++ ) {
71        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
72        OK &= ( pdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
73    }
74
75    if ( OK ) {
76        mat lW = zeros ( pdfs.length(), eSmp._w().length() );
77
78        vec emptyvec ( 0 );
79        for ( int i = 0; i < pdfs.length(); i++ ) {
80            for ( int j = 0; j < eSmp._w().length(); j++ ) {
81                lW ( i, j ) = pdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
82            }
83        }
84
85        vec w_nn = merge_points ( lW );
86        vec wtmp = exp ( w_nn - max ( w_nn ) );
87        //renormalize
88        eSmp._w() = wtmp / sum ( wtmp );
89    } else {
90        bdm_error ( "Sources are not compatible - use merger_mix" );
91    }
92}
93
94vec merger_base::merge_points ( mat &lW ) {
95    int nu = lW.rows();
96    vec result;
97    ivec indW;
98    bool infexist = false;
99    switch ( METHOD ) {
100    case ARITHMETIC:
101        result = log ( sum ( exp ( lW ) ) ); //ugly!
102        break;
103    case GEOMETRIC:
104        result = sum ( lW ) / nu;
105        break;
106    case LOGNORMAL:
107        vec sumlW = sum ( lW ) ;
108        indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
109        infexist = ( indW.size() < lW.cols() );
110        vec mu;
111        vec lam;
112        if ( infexist ) {
113            mu = sumlW ( indW ) / nu; //mean of logs
114            //
115            mat validlW = lW.get_cols ( indW );
116            lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
117        } else {
118            mu = sum ( lW ) / nu; //mean of logs
119            lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
120        }
121        //
122        double coef = 0.0;
123        vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
124        switch ( nu ) {
125        case 2:
126            coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
127            result = coef * sq2bl + mu ;
128            break;
129            // case 4: == can be done similar to case 2 - is it worth it???
130        default: // see accompanying document merge_lognorm_derivation.lyx
131            coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
132            result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
133            break;
134        }
135        break;
136    }
137    if ( infexist ) {
138        vec tmp = -inf * ones ( lW.cols() );
139        set_subvector ( tmp, indW, result );
140        return tmp;
141    } else {
142        return result;
143    }
144}
145
146vec merger_base::mean() const {
147    const Vec<double> &w = eSmp._w();
148    const Array<vec> &S = eSmp._samples();
149    vec tmp = zeros ( dim );
150    for ( int i = 0; i < Npoints; i++ ) {
151        tmp += w ( i ) * S ( i );
152    }
153    return tmp;
154}
155
156mat merger_base::covariance() const {
157    const vec &w = eSmp._w();
158    const Array<vec> &S = eSmp._samples();
159
160    vec mea = mean();
161
162//             cout << sum (w) << "," << w*w << endl;
163
164    mat Tmp = zeros ( dim, dim );
165    for ( int i = 0; i < Npoints; i++ ) {
166        vec tmp=S ( i )-mea; //inefficient but numerically stable
167        Tmp += w ( i ) * outer_product (tmp , tmp );
168    }
169    return Tmp;
170}
171
172vec merger_base::variance() const {
173    return eSmp.variance();
174}
175
176
177void merger_base::from_setting ( const Setting& set ) {
178    // get support
179    // find which method to use
180    epdf::from_setting (set);
181    string meth_str;
182    UI::get( meth_str, set, "method", UI::compulsory );
183    if ( meth_str == "arithmetic" )
184        set_method ( ARITHMETIC );
185    else if ( meth_str == "geometric" )
186        set_method ( GEOMETRIC );
187    else if ( meth_str ==  "lognormal" ) {
188        set_method ( LOGNORMAL );
189        UI::get(beta, set, "beta", UI::compulsory );
190    }
191
192
193    string dbg_filename;
194    if ( UI::get ( dbg_filename, set, "dbg_file" ) )
195        set_debug_file( dbg_filename );
196
197}
198
199void merger_base::to_setting  (Setting  &set) const {
200    epdf::to_setting(set);
201
202    UI::save( METHOD, set, "method");
203
204    if( METHOD == LOGNORMAL )
205        UI::save (beta, set, "beta" );
206
207    if( DBG )
208        UI::save ( dbg_file->get_fname(), set, "dbg_file" );
209}
210
211void merger_base::validate() {
212//        bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
213//        bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
214    epdf::validate();
215    bdm_assert ( isnamed(), "mergers must be named" );
216}
217
218// DEFAULTS FOR MERGER_BASE
219const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL;
220const double merger_base::DFLT_beta = 1.2;
221
222void merger_mix::merge ( ) {
223    if(Npoints<1) {
224        set_support(enorm<fsqmat>(zeros(dim), eye(dim)), 1000);
225    }
226
227    bdm_assert(Npoints>0,"No points in support");
228    bdm_assert(Nsources>0,"No Sources");
229
230    Array<vec> &Smp = eSmp._samples(); //aux
231    vec &w = eSmp._w(); //aux
232
233    mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones
234    for ( int i = 0; i < Npoints; i++ ) {
235        set_col_part ( Smp_ex, i, Smp ( i ) );
236    }
237
238    if ( DBG )    *dbg_file << Name ( "Smp_0" ) << Smp_ex;
239
240    // Stuff for merging
241    vec lw_src ( Npoints );        // weights of the ith source
242    vec lw_mix ( Npoints );        // weights of the approximating mixture
243    vec lw ( Npoints );            // tmp
244    mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources
245    vec vec0 ( 0 );
246
247    //initialize importance weights
248    lw_mix = 1.0; // assuming uniform grid density -- otherwise
249
250    // Initial component in the mixture model
251    mat V0 = 1e-8 * eye ( dim + 1 );
252    ARX A0;
253    A0.set_statistics ( dim, V0 ); //initial guess of Mix:
254    A0.validate();
255
256    Mix.init ( &A0, Smp_ex, Ncoms );
257    //Preserve initial mixture for repetitive estimation via flattening
258    MixEF Mix_init ( Mix );
259
260    // ============= MAIN LOOP ==================
261    bool converged = false;
262    int niter = 0;
263    char dbg_str[100];
264
265    emix* Mpred = Mix.epredictor ( );
266    vec Mix_pdf ( Npoints );
267    while ( !converged ) {
268        //Re-estimate Mix
269        //Re-Initialize Mixture model
270        Mix.flatten ( &Mix_init , 1.0);
271        Mix.bayes_batch_weighted ( Smp_ex, empty_vec, w*Npoints );
272        delete Mpred;
273        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
274        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
275
276        // This will be active only later in iterations!!!
277        if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
278            // Generate new samples
279            eSmp.set_samples ( Mpred );
280            for ( int i = 0; i < Npoints; i++ ) {
281                //////////// !!!!!!!!!!!!!
282                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
283                set_col_part ( Smp_ex, i, Smp ( i ) );
284                //Importance of the mixture
285                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
286                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
287            }
288            if ( DBG ) {
289                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
290                cout << Mix.posterior().mean() << endl;
291                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
292                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
293            }
294        }
295        if ( DBG ) {
296            sprintf ( dbg_str, "Mpred_mean%d", niter );
297            *dbg_file << Name ( dbg_str ) << Mpred->mean();
298            sprintf ( dbg_str, "Mpred_var%d", niter );
299            *dbg_file << Name ( dbg_str ) << Mpred->variance();
300            sprintf ( dbg_str, "Mpred_cov%d", niter );
301            *dbg_file << Name ( dbg_str ) << covariance();
302
303
304            sprintf ( dbg_str, "pdf%d", niter );
305            for ( int i = 0; i < Npoints; i++ ) {
306                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ), empty_vec );
307            }
308            *dbg_file << Name ( dbg_str ) << Mix_pdf;
309
310            sprintf ( dbg_str, "Smp%d", niter );
311            *dbg_file << Name ( dbg_str ) << Smp_ex;
312
313        }
314        //Importace weighting
315        for ( int i = 0; i < pdfs.length(); i++ ) {
316            lw_src = 0.0;
317            //======== Same RVs ===========
318            //Split according to dependency in rvs
319            if ( pdfs ( i )->dimension() == dim ) {
320                // no need for conditioning or marginalization
321                lw_src = pdfs ( i )->evallogcond_mat ( Smp , vec ( 0 ) );
322            } else {
323                // compute likelihood of marginal on the conditional variable
324                if ( pdfs ( i )->dimensionc() > 0 ) {
325                    // Make marginal on rvc_i
326                    shared_ptr<epdf> tmp_marg = Mpred->marginal ( pdfs ( i )->_rvc() );
327                    //compute vector of lw_src
328                    for ( int k = 0; k < Npoints; k++ ) {
329                        // Here val of tmp_marg = cond of pdfs(i) ==> calling dls->get_cond
330                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
331                    }
332
333//                     sprintf ( str,"marg%d",niter );
334//                     *dbg << Name ( str ) << lw_src;
335
336                }
337                // Compute likelihood of the missing variable
338                if ( dim > ( pdfs ( i )->dimension() + pdfs ( i )->dimensionc() ) ) {
339                    ///////////////
340                    // There are variales unknown to pdfs(i) : rvzs
341                    shared_ptr<pdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
342                    // Compute likelihood
343                    vec lw_dbg = lw_src;
344                    for ( int k = 0; k < Npoints; k++ ) {
345                        lw_src ( k ) += log (
346                                            tmp_cond->evallogcond (
347                                                zdls ( i )->pushdown ( Smp ( k ) ),
348                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
349                        if ( !std::isfinite ( lw_src ( k ) ) ) {
350                            lw_src ( k ) = -1e16;
351                            cout << "!";
352                        }
353                    }
354                }
355                // Compute likelihood of the partial source
356                for ( int k = 0; k < Npoints; k++ ) {
357                    lw_src ( k ) += pdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ),
358                                    dls ( i )->get_cond ( Smp ( k ) ) );
359                }
360
361            }
362
363            lW.set_row ( i, lw_src ); // do not divide by mix
364        }
365        lw = merger_base::merge_points ( lW ); //merge
366
367        //Importance weighting
368        lw -=  lw_mix; // hoping that it is not numerically sensitive...
369        w = exp ( lw - max ( lw ) );
370
371        //renormalize
372        double sumw = sum ( w );
373        if ( std::isfinite ( sumw ) ) {
374            w = w / sumw;
375        } else {
376            it_file itf ( "merg_err.it" );
377            itf << Name ( "w" ) << w;
378        }
379
380        if ( DBG ) {
381            sprintf ( dbg_str, "lW%d", niter );
382            *dbg_file << Name ( dbg_str ) << lW;
383            sprintf ( dbg_str, "w%d", niter );
384            *dbg_file << Name ( dbg_str ) << w;
385            sprintf ( dbg_str, "lw_m%d", niter );
386            *dbg_file << Name ( dbg_str ) << lw_mix;
387        }
388        // ==== stopping rule ===
389        niter++;
390        converged = ( niter > stop_niter );
391    }
392    delete Mpred;
393//        cout << endl;
394
395}
396
397void merger_mix::from_setting ( const Setting& set ) {
398    merger_base::from_setting ( set );
399    Ncoms=DFLT_Ncoms;
400    UI::get( Ncoms, set, "ncoms", UI::optional );
401    effss_coef=DFLT_effss_coef;
402    UI::get (effss_coef , set,  "effss_coef", UI::optional);
403    stop_niter=10;
404    UI::get ( stop_niter, set,"stop_niter", UI::optional );
405}
406
407void merger_mix::to_setting  (Setting  &set) const  {
408    merger_base::to_setting(set);
409    UI::save( Ncoms, set, "ncoms");
410    UI::save (effss_coef , set,  "effss_coef");
411    UI::save ( stop_niter, set,"stop_niter");
412}
413
414void merger_mix::validate() {
415    merger_base::validate();
416    bdm_assert(Ncoms>0,"Ncoms too small");
417}
418
419// DEFAULTS FOR MERGER_MIX
420const int merger_mix::DFLT_Ncoms = 10;
421const double merger_mix::DFLT_effss_coef = 0.9;
422
423}
Note: See TracBrowser for help on using the browser.