root/library/bdm/stat/merger.cpp @ 1068

Revision 1068, 13.8 kB (checked in by mido, 15 years ago)

patch of documentation - all conditional pdfs revised

  • Property svn:eol-style set to native
RevLine 
[262]1
[176]2#include "merger.h"
[384]3#include "../estim/arx.h"
[176]4
[477]5namespace bdm {
[423]6
[737]7merger_base::merger_base ( const Array<shared_ptr<pdf> > &S ) :
[1064]8    Npoints ( 0 ), DBG ( false ), dbg_file ( 0 ) {
9    set_sources ( S );
[477]10}
[423]11
[956]12
[739]13void merger_base::set_sources ( const Array<shared_ptr<pdf> > &Sources ) {
[1064]14    pdfs = Sources;
15    Nsources = pdfs.length();
16    //set sizes
17    dls.set_size ( Sources.length() );
18    rvzs.set_size ( Sources.length() );
19    zdls.set_size ( Sources.length() );
[739]20
[1064]21    rv = get_composite_rv ( pdfs, /* checkoverlap = */ false );
[739]22
[1064]23    RV rvc;
24    // Extend rv by rvc!
25    for ( int i = 0; i < pdfs.length(); i++ ) {
26        RV rvx = pdfs ( i )->_rvc().subt ( rv );
27        rvc.add ( rvx ); // add rv to common rvc
28    }
[739]29
[1064]30    // join rv and rvc - see descriprion
31    rv.add ( rvc );
32    // get dimension
33    dim = rv._dsize();
[739]34
[1064]35    // create links between sources and common rv
36    RV xytmp;
37    for ( int i = 0; i < pdfs.length(); i++ ) {
38        //Establich connection between pdfs and merger
39        dls ( i ) = new datalink_m2e;
40        dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), rv );
[739]41
[1064]42        // find out what is missing in each pdf
43        xytmp = pdfs ( i )->_rv();
44        xytmp.add ( pdfs ( i )->_rvc() );
45        // z_i = common_rv-xy
46        rvzs ( i ) = rv.subt ( xytmp );
47        //establish connection between extension (z_i|x,y)s and common rv
48        zdls ( i ) = new datalink_m2e;
49        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
50    };
[739]51}
52
53void merger_base::set_support ( rectangular_support &Sup ) {
[1064]54    Npoints = Sup.points();
55    eSmp.set_parameters ( Npoints, false );
56    Array<vec> &samples = eSmp._samples();
57    eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
58    //set samples
59    samples ( 0 ) = Sup.first_vec();
60    for ( int j = 1; j < Npoints; j++ ) {
61        samples ( j ) = Sup.next_vec();
62    }
[739]63}
64
65void merger_base::merge () {
[1064]66    validate();
[739]67
[1064]68    //check if sources overlap:
69    bool OK = true;
70    for ( int i = 0; i < pdfs.length(); i++ ) {
71        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
72        OK &= ( pdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
73    }
[739]74
[1064]75    if ( OK ) {
76        mat lW = zeros ( pdfs.length(), eSmp._w().length() );
[739]77
[1064]78        vec emptyvec ( 0 );
79        for ( int i = 0; i < pdfs.length(); i++ ) {
80            for ( int j = 0; j < eSmp._w().length(); j++ ) {
81                lW ( i, j ) = pdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
82            }
83        }
[739]84
[1064]85        vec w_nn = merge_points ( lW );
86        vec wtmp = exp ( w_nn - max ( w_nn ) );
87        //renormalize
88        eSmp._w() = wtmp / sum ( wtmp );
89    } else {
90        bdm_error ( "Sources are not compatible - use merger_mix" );
91    }
[739]92}
93
[477]94vec merger_base::merge_points ( mat &lW ) {
[1064]95    int nu = lW.rows();
96    vec result;
97    ivec indW;
98    bool infexist = false;
99    switch ( METHOD ) {
100    case ARITHMETIC:
101        result = log ( sum ( exp ( lW ) ) ); //ugly!
102        break;
103    case GEOMETRIC:
104        result = sum ( lW ) / nu;
105        break;
106    case LOGNORMAL:
107        vec sumlW = sum ( lW ) ;
108        indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
109        infexist = ( indW.size() < lW.cols() );
110        vec mu;
111        vec lam;
112        if ( infexist ) {
113            mu = sumlW ( indW ) / nu; //mean of logs
114            //
115            mat validlW = lW.get_cols ( indW );
116            lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
117        } else {
118            mu = sum ( lW ) / nu; //mean of logs
119            lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
120        }
121        //
122        double coef = 0.0;
123        vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
124        switch ( nu ) {
125        case 2:
126            coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
127            result = coef * sq2bl + mu ;
128            break;
129            // case 4: == can be done similar to case 2 - is it worth it???
130        default: // see accompanying document merge_lognorm_derivation.lyx
131            coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
132            result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
133            break;
134        }
135        break;
136    }
137    if ( infexist ) {
138        vec tmp = -inf * ones ( lW.cols() );
139        set_subvector ( tmp, indW, result );
140        return tmp;
141    } else {
142        return result;
143    }
[477]144}
[176]145
[739]146vec merger_base::mean() const {
[1064]147    const Vec<double> &w = eSmp._w();
148    const Array<vec> &S = eSmp._samples();
149    vec tmp = zeros ( dim );
150    for ( int i = 0; i < Npoints; i++ ) {
151        tmp += w ( i ) * S ( i );
152    }
153    return tmp;
[739]154}
155
156mat merger_base::covariance() const {
[1064]157    const vec &w = eSmp._w();
158    const Array<vec> &S = eSmp._samples();
[739]159
[1064]160    vec mea = mean();
[739]161
[1068]162//             cout << sum (w) << "," << w*w << endl;
[739]163
[1064]164    mat Tmp = zeros ( dim, dim );
165    for ( int i = 0; i < Npoints; i++ ) {
166        vec tmp=S ( i )-mea; //inefficient but numerically stable
167        Tmp += w ( i ) * outer_product (tmp , tmp );
168    }
169    return Tmp;
[739]170}
171
172vec merger_base::variance() const {
[1064]173    return eSmp.variance();
[739]174}
175
[1068]176
177void merger_base::from_setting ( const Setting& set ) {
178    // get support
179    // find which method to use
180    epdf::from_setting (set);
181    string meth_str;
182    UI::get( meth_str, set, "method", UI::compulsory );
183    if ( meth_str == "arithmetic" )
184        set_method ( ARITHMETIC );
185    else if ( meth_str == "geometric" )
186        set_method ( GEOMETRIC );
187    else if ( meth_str ==  "lognormal" ) {
188        set_method ( LOGNORMAL );
189        UI::get(beta, set, "beta", UI::compulsory );
190    }
191
192
193    string dbg_filename;
194    if ( UI::get ( dbg_filename, set, "dbg_file" ) )
195        set_debug_file( dbg_filename );
196
197}
198
199void merger_base::to_setting  (Setting  &set) const {
200    epdf::to_setting(set);
201
202    UI::save( METHOD, set, "method");
203
204    if( METHOD == LOGNORMAL )
205        UI::save (beta, set, "beta" );
206
207    if( DBG )
208        UI::save ( dbg_file->get_fname(), set, "dbg_file" );
209}
210
211void merger_base::validate() {
212//        bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
213//        bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
214    epdf::validate();
215    bdm_assert ( isnamed(), "mergers must be named" );
216}
217
218// DEFAULTS FOR MERGER_BASE
219const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL;
220const double merger_base::DFLT_beta = 1.2;
221
[477]222void merger_mix::merge ( ) {
[1064]223    if(Npoints<1) {
224        set_support(enorm<fsqmat>(zeros(dim), eye(dim)), 1000);
225    }
[180]226
[1064]227    bdm_assert(Npoints>0,"No points in support");
228    bdm_assert(Nsources>0,"No Sources");
[180]229
[1064]230    Array<vec> &Smp = eSmp._samples(); //aux
231    vec &w = eSmp._w(); //aux
[180]232
[1064]233    mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones
234    for ( int i = 0; i < Npoints; i++ ) {
235        set_col_part ( Smp_ex, i, Smp ( i ) );
236    }
[176]237
[1068]238    if ( DBG )    *dbg_file << Name ( "Smp_0" ) << Smp_ex;
[300]239
[1064]240    // Stuff for merging
[1068]241    vec lw_src ( Npoints );        // weights of the ith source
242    vec lw_mix ( Npoints );        // weights of the approximating mixture
243    vec lw ( Npoints );            // tmp
[1064]244    mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources
245    vec vec0 ( 0 );
[176]246
[1064]247    //initialize importance weights
248    lw_mix = 1.0; // assuming uniform grid density -- otherwise
[197]249
[1064]250    // Initial component in the mixture model
251    mat V0 = 1e-8 * eye ( dim + 1 );
252    ARX A0;
253    A0.set_statistics ( dim, V0 ); //initial guess of Mix:
254    A0.validate();
[182]255
[1064]256    Mix.init ( &A0, Smp_ex, Ncoms );
257    //Preserve initial mixture for repetitive estimation via flattening
258    MixEF Mix_init ( Mix );
[213]259
[1064]260    // ============= MAIN LOOP ==================
261    bool converged = false;
262    int niter = 0;
263    char dbg_str[100];
[197]264
[1064]265    emix* Mpred = Mix.epredictor ( );
266    vec Mix_pdf ( Npoints );
267    while ( !converged ) {
268        //Re-estimate Mix
269        //Re-Initialize Mixture model
270        Mix.flatten ( &Mix_init , 1.0);
271        Mix.bayes_batch_weighted ( Smp_ex, empty_vec, w*Npoints );
272        delete Mpred;
273        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
274        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
[180]275
[1064]276        // This will be active only later in iterations!!!
277        if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
278            // Generate new samples
279            eSmp.set_samples ( Mpred );
280            for ( int i = 0; i < Npoints; i++ ) {
281                //////////// !!!!!!!!!!!!!
282                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
283                set_col_part ( Smp_ex, i, Smp ( i ) );
284                //Importance of the mixture
285                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
286                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
287            }
288            if ( DBG ) {
289                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
290                cout << Mix.posterior().mean() << endl;
291                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
292                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
293            }
294        }
295        if ( DBG ) {
296            sprintf ( dbg_str, "Mpred_mean%d", niter );
297            *dbg_file << Name ( dbg_str ) << Mpred->mean();
298            sprintf ( dbg_str, "Mpred_var%d", niter );
299            *dbg_file << Name ( dbg_str ) << Mpred->variance();
300            sprintf ( dbg_str, "Mpred_cov%d", niter );
301            *dbg_file << Name ( dbg_str ) << covariance();
[180]302
[198]303
[1064]304            sprintf ( dbg_str, "pdf%d", niter );
305            for ( int i = 0; i < Npoints; i++ ) {
306                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ), empty_vec );
307            }
308            *dbg_file << Name ( dbg_str ) << Mix_pdf;
309
310            sprintf ( dbg_str, "Smp%d", niter );
311            *dbg_file << Name ( dbg_str ) << Smp_ex;
312
313        }
314        //Importace weighting
315        for ( int i = 0; i < pdfs.length(); i++ ) {
316            lw_src = 0.0;
317            //======== Same RVs ===========
318            //Split according to dependency in rvs
319            if ( pdfs ( i )->dimension() == dim ) {
320                // no need for conditioning or marginalization
321                lw_src = pdfs ( i )->evallogcond_mat ( Smp , vec ( 0 ) );
322            } else {
323                // compute likelihood of marginal on the conditional variable
324                if ( pdfs ( i )->dimensionc() > 0 ) {
325                    // Make marginal on rvc_i
326                    shared_ptr<epdf> tmp_marg = Mpred->marginal ( pdfs ( i )->_rvc() );
327                    //compute vector of lw_src
328                    for ( int k = 0; k < Npoints; k++ ) {
329                        // Here val of tmp_marg = cond of pdfs(i) ==> calling dls->get_cond
330                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
331                    }
332
[1068]333//                     sprintf ( str,"marg%d",niter );
334//                     *dbg << Name ( str ) << lw_src;
[198]335
[1064]336                }
337                // Compute likelihood of the missing variable
338                if ( dim > ( pdfs ( i )->dimension() + pdfs ( i )->dimensionc() ) ) {
339                    ///////////////
340                    // There are variales unknown to pdfs(i) : rvzs
341                    shared_ptr<pdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
342                    // Compute likelihood
343                    vec lw_dbg = lw_src;
344                    for ( int k = 0; k < Npoints; k++ ) {
345                        lw_src ( k ) += log (
346                                            tmp_cond->evallogcond (
347                                                zdls ( i )->pushdown ( Smp ( k ) ),
348                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
349                        if ( !std::isfinite ( lw_src ( k ) ) ) {
350                            lw_src ( k ) = -1e16;
351                            cout << "!";
352                        }
353                    }
354                }
355                // Compute likelihood of the partial source
356                for ( int k = 0; k < Npoints; k++ ) {
357                    lw_src ( k ) += pdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ),
358                                    dls ( i )->get_cond ( Smp ( k ) ) );
359                }
[477]360
[1064]361            }
[565]362
[1064]363            lW.set_row ( i, lw_src ); // do not divide by mix
364        }
365        lw = merger_base::merge_points ( lW ); //merge
[197]366
[1064]367        //Importance weighting
368        lw -=  lw_mix; // hoping that it is not numerically sensitive...
369        w = exp ( lw - max ( lw ) );
[300]370
[1064]371        //renormalize
372        double sumw = sum ( w );
373        if ( std::isfinite ( sumw ) ) {
374            w = w / sumw;
375        } else {
376            it_file itf ( "merg_err.it" );
377            itf << Name ( "w" ) << w;
378        }
[180]379
[1064]380        if ( DBG ) {
381            sprintf ( dbg_str, "lW%d", niter );
382            *dbg_file << Name ( dbg_str ) << lW;
383            sprintf ( dbg_str, "w%d", niter );
384            *dbg_file << Name ( dbg_str ) << w;
385            sprintf ( dbg_str, "lw_m%d", niter );
386            *dbg_file << Name ( dbg_str ) << lw_mix;
387        }
388        // ==== stopping rule ===
389        niter++;
390        converged = ( niter > stop_niter );
391    }
392    delete Mpred;
[1068]393//        cout << endl;
[205]394
[477]395}
[176]396
[956]397void merger_mix::from_setting ( const Setting& set ) {
[1064]398    merger_base::from_setting ( set );
399    Ncoms=DFLT_Ncoms;
400    UI::get( Ncoms, set, "ncoms", UI::optional );
401    effss_coef=DFLT_effss_coef;
402    UI::get (effss_coef , set,  "effss_coef", UI::optional);
403    stop_niter=10;
404    UI::get ( stop_niter, set,"stop_niter", UI::optional );
[957]405}
[956]406
[1068]407void merger_mix::to_setting  (Setting  &set) const  {
[1064]408    merger_base::to_setting(set);
409    UI::save( Ncoms, set, "ncoms");
410    UI::save (effss_coef , set,  "effss_coef");
411    UI::save ( stop_niter, set,"stop_niter");
[957]412}
[956]413
414void merger_mix::validate() {
[1064]415    merger_base::validate();
416    bdm_assert(Ncoms>0,"Ncoms too small");
[957]417}
[956]418
[477]419// DEFAULTS FOR MERGER_MIX
420const int merger_mix::DFLT_Ncoms = 10;
[773]421const double merger_mix::DFLT_effss_coef = 0.9;
[399]422
[176]423}
Note: See TracBrowser for help on using the browser.