root/library/bdm/stat/merger.cpp

Revision 1410, 13.2 kB (checked in by smidl, 13 years ago)

oprava pro novou verzi itpp

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm {
6
7MergerDiscrete::MergerDiscrete ( const Array<shared_ptr<pdf> > &S ) :
8    Npoints ( 0 ), DBG ( false ), dbg_file ( 0 ) {
9    set_sources ( S );
10}
11
12
13void MergerDiscrete::set_sources ( const Array<shared_ptr<pdf> > &PartSources ) {
14    part_sources = PartSources;
15    Nsources = part_sources.length();
16    //set sizes
17    sources.set_size ( Nsources);
18        dls.set_size ( Nsources );
19        rvzs.set_size ( Nsources );
20        zdls.set_size ( Nsources );
21
22    RV rv = get_composite_rv ( part_sources, /* checkoverlap = */ false );
23
24    RV rvc;
25    // Extend rv by rvc!
26    for ( int i = 0; i < sources.length(); i++ ) {
27        RV rvx = part_sources ( i )->_rvc().subt ( rv );
28        rvc.add ( rvx ); // add rv to common rvc
29    }
30
31    // join rv and rvc - see descriprion
32    rv.add ( rvc );
33       
34    // create links between sources and common rv
35    RV xytmp;
36    for ( int i = 0; i < part_sources.length(); i++ ) {         
37        //Establich connection between pdfs and merger
38        dls ( i ) = new datalink_m2e;
39                dls ( i )->set_connection ( part_sources ( i )->_rv(), part_sources( i )->_rvc(), rv );
40
41        // find out what is missing in each pdf
42        xytmp = part_sources ( i )->_rv();
43                xytmp.add ( part_sources ( i )->_rvc() );
44        // z_i = common_rv-xy
45        rvzs ( i ) = rv.subt ( xytmp );
46        //establish connection between extension (z_i|x,y)s and common rv
47        zdls ( i ) = new datalink_m2e;
48        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
49    };
50       
51        //
52        merger().set_rv(rv);
53}
54
55void MergerDiscrete::set_support ( rectangular_support &Sup ) {
56    Npoints = Sup.points();
57    eSmp.set_parameters ( Npoints, false );
58    Array<vec> &samples = eSmp._samples();
59    eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
60    //set samples
61    samples ( 0 ) = Sup.first_vec();
62    for ( int j = 1; j < Npoints; j++ ) {
63        samples ( j ) = Sup.next_vec();
64    }
65    eSmp.validate();
66}
67
68void MergerDiscrete::merge () {
69    validate();
70
71    //check if sources overlap:
72    bool OK = true;
73    for ( int i = 0; i < part_sources.length(); i++ ) {
74        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
75        OK &= ( part_sources ( i )->_rvc()._dsize() == 0 ); // y_i is empty
76    }
77
78    if ( OK ) {
79        mat lW = zeros ( part_sources.length(), eSmp._w().length() );
80
81        vec emptyvec ( 0 );
82        for ( int i = 0; i < part_sources.length(); i++ ) {
83            for ( int j = 0; j < eSmp._w().length(); j++ ) {
84                lW ( i, j ) = part_sources ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
85            }
86        }
87
88        vec w_nn = merge_points ( lW );
89        vec wtmp = exp ( w_nn - max ( w_nn ) );
90        //renormalize
91        eSmp._w() = wtmp / sum ( wtmp );
92    } else {
93        bdm_error ( "Sources are not compatible - use merger_mix" );
94    }
95}
96
97vec MergerDiscrete::merge_points ( mat &lW ) {
98    int nu = lW.rows();
99    vec result;
100    ivec indW;
101    bool infexist = false;
102    switch ( METHOD ) {
103    case ARITHMETIC:
104        result = log ( sum ( exp ( lW ) ) ); //ugly!
105        break;
106    case GEOMETRIC:
107        result = sum ( lW ) / nu;
108        break;
109    case LOGNORMAL:
110        vec sumlW = sum ( lW ) ;
111        indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
112        infexist = ( indW.size() < lW.cols() );
113        vec mu;
114        vec lam;
115        if ( infexist ) {
116            mu = sumlW ( indW ) / nu; //mean of logs
117            //
118            mat validlW = lW.get_cols ( indW );
119            lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
120        } else {
121            mu = sum ( lW ) / nu; //mean of logs
122            lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
123        }
124        //
125        double coef = 0.0;
126        vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
127        switch ( nu ) {
128        case 2:
129            coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
130            result = coef * sq2bl + mu ;
131            break;
132            // case 4: == can be done similar to case 2 - is it worth it???
133        default: // see accompanying document merge_lognorm_derivation.lyx
134            coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
135            result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
136            break;
137        }
138        break;
139    }
140    if ( infexist ) {
141        vec tmp = -inf * ones ( lW.cols() );
142        set_subvector ( tmp, indW, result );
143        return tmp;
144    } else {
145        return result;
146    }
147}
148
149void MergerDiscrete::from_setting ( const Setting& set ) {
150    // get support
151    // find which method to use
152    string meth_str;
153    UI::get( meth_str, set, "method", UI::compulsory );
154    if ( meth_str == "arithmetic" )
155        set_method ( ARITHMETIC );
156    else if ( meth_str == "geometric" )
157        set_method ( GEOMETRIC );
158    else if ( meth_str ==  "lognormal" ) {
159        set_method ( LOGNORMAL );
160        UI::get(beta, set, "beta", UI::compulsory );
161    }
162
163
164    string dbg_filename;
165    if ( UI::get ( dbg_filename, set, "dbg_file" ) )
166        set_debug_file( dbg_filename );
167
168}
169
170void MergerDiscrete::to_setting  (Setting  &set) const {
171    UI::save( METHOD, set, "method");
172
173    if( METHOD == LOGNORMAL )
174        UI::save (beta, set, "beta" );
175
176/*    if( DBG )
177        UI::save ( dbg_file->fname(), set, "dbg_file" );*/
178}
179
180void MergerDiscrete::validate() {
181//        bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
182//        bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
183    MergerBase::validate();
184    //bdm_assert ( merger().isnamed(), "mergers must be named" );
185}
186
187void merger_mix::merge ( ) {
188        int dim = eSmp.dimension();
189    if(Npoints<1) {
190        set_support(enorm<fsqmat>(zeros(dim), eye(dim)), 1000);
191    }
192
193    bdm_assert(Npoints>0,"No points in support");
194    bdm_assert(Nsources>0,"No Sources");
195
196    Array<vec> &Smp = eSmp._samples(); //aux
197    vec &w = eSmp._w(); //aux
198
199    mat Smp_ex = ones ( dim, Npoints ); // Extended samples for the ARX model
200    for ( int i = 0; i < Npoints; i++ ) {
201        set_col_part ( Smp_ex, i, Smp ( i ) );
202    }
203
204    if ( DBG )    *dbg_file << Name ( "Smp_0" ) << Smp_ex;
205
206    // Stuff for merging
207    vec lw_src ( Npoints );        // weights of the ith source
208    vec lw_mix ( Npoints );        // weights of the approximating mixture
209    vec lw ( Npoints );            // tmp
210    mat lW = zeros ( Nsources, Npoints ); // array of weights of all part_sources
211    vec vec0 ( 0 );
212
213    //initialize importance weights
214    lw_mix = 1.0; // assuming uniform grid density -- otherwise
215
216    // Initial component in the mixture model
217    mat V0 = 1e-8 * eye ( dim + 1 );
218    ARX A0;
219    A0.set_statistics ( dim, V0 ); //initial guess of Mix:
220    A0.validate();
221
222    Mix.init ( &A0, Smp_ex, Ncoms );
223    //Preserve initial mixture for repetitive estimation via flattening
224    MixEF Mix_init ( Mix );
225
226    // ============= MAIN LOOP ==================
227    bool converged = false;
228    int niter = 0;
229    char dbg_str[100];
230
231        emix *Mpred = Mix.epredictor ( );
232    vec Mix_pdf ( Npoints );
233    while ( !converged ) {
234        //Re-estimate Mix
235        //Re-Initialize Mixture model
236        Mix.flatten ( &Mix_init , 1.0);
237        Mix.bayes_batch_weighted ( Smp_ex, empty_vec, w*Npoints );
238        delete Mpred;
239        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
240        Mpred->set_rv ( merger()._rv() ); //the predictor predicts rv of this merger
241
242        // This will be active only later in iterations!!!
243        if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
244            // Generate new samples
245            eSmp.set_samples ( Mpred );
246            for ( int i = 0; i < Npoints; i++ ) {
247                //////////// !!!!!!!!!!!!!
248                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
249                set_col_part ( Smp_ex, i, Smp ( i ) );
250                //Importance of the mixture
251                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
252                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
253            }
254            if ( DBG ) {
255                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
256                cout << Mix.posterior().mean() << endl;
257                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
258                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
259            }
260        }
261        if ( DBG ) {
262            sprintf ( dbg_str, "Mpred_mean%d", niter );
263            *dbg_file << Name ( dbg_str ) << Mpred->mean();
264            sprintf ( dbg_str, "Mpred_var%d", niter );
265            *dbg_file << Name ( dbg_str ) << Mpred->variance();
266            sprintf ( dbg_str, "Mpred_cov%d", niter );
267            *dbg_file << Name ( dbg_str ) << Mpred->covariance();
268
269
270            sprintf ( dbg_str, "pdf%d", niter );
271            for ( int i = 0; i < Npoints; i++ ) {
272                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ), empty_vec );
273            }
274            *dbg_file << Name ( dbg_str ) << Mix_pdf;
275
276            sprintf ( dbg_str, "Smp%d", niter );
277            *dbg_file << Name ( dbg_str ) << Smp_ex;
278
279        }
280        //Importace weighting
281        for ( int i = 0; i < part_sources.length(); i++ ) {
282            lw_src = 0.0;
283            //======== Same RVs ===========
284            //Split according to dependency in rvs
285            if ( part_sources ( i )->dimension() == dim ) {
286                // no need for conditioning or marginalization
287                lw_src = part_sources ( i )->evallogcond_mat ( Smp , vec ( 0 ) );
288            } else {
289                // compute likelihood of marginal on the conditional variable
290                if ( part_sources ( i )->dimensionc() > 0 ) {
291                    // Make marginal on rvc_i
292                    shared_ptr<epdf> tmp_marg = Mpred->marginal ( part_sources ( i )->_rvc() );
293                    //compute vector of lw_src
294                    for ( int k = 0; k < Npoints; k++ ) {
295                        // Here val of tmp_marg = cond of pdfs(i) ==> calling dls->get_cond
296                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
297                    }
298
299//                     sprintf ( str,"marg%d",niter );
300//                     *dbg << Name ( str ) << lw_src;
301
302                }
303                // Compute likelihood of the missing variable
304                if ( dim > ( part_sources ( i )->dimension() + part_sources ( i )->dimensionc() ) ) {
305                    ///////////////
306                    // There are variales unknown to pdfs(i) : rvzs
307                    shared_ptr<pdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
308                    // Compute likelihood
309                    vec lw_dbg = lw_src;
310                    for ( int k = 0; k < Npoints; k++ ) {
311                        lw_src ( k ) += log (
312                                            tmp_cond->evallogcond (
313                                                zdls ( i )->pushdown ( Smp ( k ) ),
314                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
315                        if ( !std::isfinite ( lw_src ( k ) ) ) {
316                            lw_src ( k ) = -1e16;
317                            cout << "!";
318                        }
319                    }
320                }
321                // Compute likelihood of the partial source
322                for ( int k = 0; k < Npoints; k++ ) {
323                    lw_src ( k ) += part_sources ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ),
324                                    dls ( i )->get_cond ( Smp ( k ) ) );
325                }
326
327            }
328
329            lW.set_row ( i, lw_src ); // do not divide by mix
330        }
331        lw = MergerDiscrete::merge_points ( lW ); //merge
332
333        //Importance weighting
334        lw -=  lw_mix; // hoping that it is not numerically sensitive...
335        w = exp ( lw - max ( lw ) );
336
337        //renormalize
338        double sumw = sum ( w );
339        if ( std::isfinite ( sumw ) ) {
340            w = w / sumw;
341        } else {
342            it_file itf ( "merg_err.it" );
343            itf << Name ( "w" ) << w;
344        }
345
346        if ( DBG ) {
347            sprintf ( dbg_str, "lW%d", niter );
348            *dbg_file << Name ( dbg_str ) << lW;
349            sprintf ( dbg_str, "w%d", niter );
350            *dbg_file << Name ( dbg_str ) << w;
351            sprintf ( dbg_str, "lw_m%d", niter );
352            *dbg_file << Name ( dbg_str ) << lw_mix;
353        }
354        // ==== stopping rule ===
355        niter++;
356        converged = ( niter > stop_niter );
357    }
358    delete Mpred;
359}
360
361void merger_mix::from_setting ( const Setting& set ) {
362    MergerDiscrete::from_setting ( set );
363    Ncoms=DFLT_Ncoms;
364    UI::get( Ncoms, set, "ncoms", UI::optional );
365    effss_coef=DFLT_effss_coef;
366    UI::get (effss_coef , set,  "effss_coef", UI::optional);
367    stop_niter=10;
368    UI::get ( stop_niter, set,"stop_niter", UI::optional );
369}
370
371void merger_mix::to_setting  (Setting  &set) const  {
372    MergerDiscrete::to_setting(set);
373    UI::save( Ncoms, set, "ncoms");
374    UI::save (effss_coef , set,  "effss_coef");
375    UI::save ( stop_niter, set,"stop_niter");
376}
377
378void merger_mix::validate() {
379    MergerDiscrete::validate();
380    bdm_assert(Ncoms>0,"Ncoms too small");
381}
382
383// DEFAULTS FOR MERGER_MIX
384const int merger_mix::DFLT_Ncoms = 10;
385const double merger_mix::DFLT_effss_coef = 0.9;
386
387}
Note: See TracBrowser for help on using the browser.