root/library/bdm/stat/merger.cpp @ 957

Revision 957, 11.8 kB (checked in by mido, 14 years ago)

a small patches of previous commit as arranged yesterday

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm {
6
7merger_base::merger_base ( const Array<shared_ptr<pdf> > &S ) :
8                Npoints ( 0 ), DBG ( false ), dbg_file ( 0 ) {
9        set_sources ( S );
10}
11
12
13void merger_base::set_sources ( const Array<shared_ptr<pdf> > &Sources ) {
14        pdfs = Sources;
15        Nsources = pdfs.length();
16        //set sizes
17        dls.set_size ( Sources.length() );
18        rvzs.set_size ( Sources.length() );
19        zdls.set_size ( Sources.length() );
20
21        rv = get_composite_rv ( pdfs, /* checkoverlap = */ false );
22
23        RV rvc;
24        // Extend rv by rvc!
25        for ( int i = 0; i < pdfs.length(); i++ ) {
26                RV rvx = pdfs ( i )->_rvc().subt ( rv );
27                rvc.add ( rvx ); // add rv to common rvc
28        }
29
30        // join rv and rvc - see descriprion
31        rv.add ( rvc );
32        // get dimension
33        dim = rv._dsize();
34
35        // create links between sources and common rv
36        RV xytmp;
37        for ( int i = 0; i < pdfs.length(); i++ ) {
38                //Establich connection between pdfs and merger
39                dls ( i ) = new datalink_m2e;
40                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), rv );
41
42                // find out what is missing in each pdf
43                xytmp = pdfs ( i )->_rv();
44                xytmp.add ( pdfs ( i )->_rvc() );
45                // z_i = common_rv-xy
46                rvzs ( i ) = rv.subt ( xytmp );
47                //establish connection between extension (z_i|x,y)s and common rv
48                zdls ( i ) = new datalink_m2e;
49                zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
50        };
51}
52
53void merger_base::set_support ( rectangular_support &Sup ) {
54        Npoints = Sup.points();
55        eSmp.set_parameters ( Npoints, false );
56        Array<vec> &samples = eSmp._samples();
57        eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
58        //set samples
59        samples ( 0 ) = Sup.first_vec();
60        for ( int j = 1; j < Npoints; j++ ) {
61                samples ( j ) = Sup.next_vec();
62        }
63}
64
65void merger_base::merge () {
66        validate();
67
68        //check if sources overlap:
69        bool OK = true;
70        for ( int i = 0; i < pdfs.length(); i++ ) {
71                OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
72                OK &= ( pdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
73        }
74
75        if ( OK ) {
76                mat lW = zeros ( pdfs.length(), eSmp._w().length() );
77
78                vec emptyvec ( 0 );
79                for ( int i = 0; i < pdfs.length(); i++ ) {
80                        for ( int j = 0; j < eSmp._w().length(); j++ ) {
81                                lW ( i, j ) = pdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
82                        }
83                }
84
85                vec w_nn = merge_points ( lW );
86                vec wtmp = exp ( w_nn - max ( w_nn ) );
87                //renormalize
88                eSmp._w() = wtmp / sum ( wtmp );
89        } else {
90                bdm_error ( "Sources are not compatible - use merger_mix" );
91        }
92}
93
94vec merger_base::merge_points ( mat &lW ) {
95        int nu = lW.rows();
96        vec result;
97        ivec indW;
98        bool infexist = false;
99        switch ( METHOD ) {
100        case ARITHMETIC:
101                result = log ( sum ( exp ( lW ) ) ); //ugly!
102                break;
103        case GEOMETRIC:
104                result = sum ( lW ) / nu;
105                break;
106        case LOGNORMAL:
107                vec sumlW = sum ( lW ) ;
108                indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
109                infexist = ( indW.size() < lW.cols() );
110                vec mu;
111                vec lam;
112                if ( infexist ) {
113                        mu = sumlW ( indW ) / nu; //mean of logs
114                        //
115                        mat validlW = lW.get_cols ( indW );
116                        lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
117                } else {
118                        mu = sum ( lW ) / nu; //mean of logs
119                        lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
120                }
121                //
122                double coef = 0.0;
123                vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
124                switch ( nu ) {
125                case 2:
126                        coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
127                        result = coef * sq2bl + mu ;
128                        break;
129                        // case 4: == can be done similar to case 2 - is it worth it???
130                default: // see accompanying document merge_lognorm_derivation.lyx
131                        coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
132                        result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
133                        break;
134                }
135                break;
136        }
137        if ( infexist ) {
138                vec tmp = -inf * ones ( lW.cols() );
139                set_subvector ( tmp, indW, result );
140                return tmp;
141        } else {
142                return result;
143        }
144}
145
146vec merger_base::mean() const {
147        const Vec<double> &w = eSmp._w();
148        const Array<vec> &S = eSmp._samples();
149        vec tmp = zeros ( dim );
150        for ( int i = 0; i < Npoints; i++ ) {
151                tmp += w ( i ) * S ( i );
152        }
153        return tmp;
154}
155
156mat merger_base::covariance() const {
157        const vec &w = eSmp._w();
158        const Array<vec> &S = eSmp._samples();
159
160        vec mea = mean();
161
162//                      cout << sum (w) << "," << w*w << endl;
163
164        mat Tmp = zeros ( dim, dim );
165        for ( int i = 0; i < Npoints; i++ ) {
166                vec tmp=S ( i )-mea; //inefficient but numerically stable
167                Tmp += w ( i ) * outer_product (tmp , tmp );
168        }
169        return Tmp;
170}
171
172vec merger_base::variance() const {
173        return eSmp.variance();
174}
175
176void merger_mix::merge ( ) {
177        if(Npoints<1){
178                set_support(enorm<fsqmat>(zeros(dim), eye(dim)), 1000);
179        }
180               
181        bdm_assert(Npoints>0,"No points in support");
182        bdm_assert(Nsources>0,"No Sources");
183       
184        Array<vec> &Smp = eSmp._samples(); //aux
185        vec &w = eSmp._w(); //aux
186
187        mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones
188        for ( int i = 0; i < Npoints; i++ ) {
189                set_col_part ( Smp_ex, i, Smp ( i ) );
190        }
191
192        if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
193
194        // Stuff for merging
195        vec lw_src ( Npoints );         // weights of the ith source
196        vec lw_mix ( Npoints );         // weights of the approximating mixture
197        vec lw ( Npoints );                     // tmp
198        mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources
199        vec vec0 ( 0 );
200
201        //initialize importance weights
202        lw_mix = 1.0; // assuming uniform grid density -- otherwise
203
204        // Initial component in the mixture model
205        mat V0 = 1e-8 * eye ( dim + 1 );
206        ARX A0;
207        A0.set_statistics ( dim, V0 ); //initial guess of Mix:
208        A0.validate();
209
210        Mix.init ( &A0, Smp_ex, Ncoms );
211        //Preserve initial mixture for repetitive estimation via flattening
212        MixEF Mix_init ( Mix );
213
214        // ============= MAIN LOOP ==================
215        bool converged = false;
216        int niter = 0;
217        char dbg_str[100];
218
219        emix* Mpred = Mix.epredictor ( );
220        vec Mix_pdf ( Npoints );
221        while ( !converged ) {
222                //Re-estimate Mix
223                //Re-Initialize Mixture model
224                Mix.flatten ( &Mix_init );
225                Mix.bayes_batch ( Smp_ex, empty_vec, w*Npoints );
226                delete Mpred;
227                Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
228                Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
229
230                // This will be active only later in iterations!!!
231                if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
232                        // Generate new samples
233                        eSmp.set_samples ( Mpred );
234                        for ( int i = 0; i < Npoints; i++ ) {
235                                //////////// !!!!!!!!!!!!!
236                                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
237                                set_col_part ( Smp_ex, i, Smp ( i ) );
238                                //Importance of the mixture
239                                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
240                                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
241                        }
242                        if ( DBG ) {
243                                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
244                                cout << Mix.posterior().mean() << endl;
245                                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
246                                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
247                        }
248                }
249                if ( DBG ) {
250                        sprintf ( dbg_str, "Mpred_mean%d", niter );
251                        *dbg_file << Name ( dbg_str ) << Mpred->mean();
252                        sprintf ( dbg_str, "Mpred_var%d", niter );
253                        *dbg_file << Name ( dbg_str ) << Mpred->variance();
254                        sprintf ( dbg_str, "Mpred_cov%d", niter );
255                        *dbg_file << Name ( dbg_str ) << covariance();
256                       
257
258                        sprintf ( dbg_str, "pdf%d", niter );
259                        for ( int i = 0; i < Npoints; i++ ) {
260                                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );
261                        }
262                        *dbg_file << Name ( dbg_str ) << Mix_pdf;
263
264                        sprintf ( dbg_str, "Smp%d", niter );
265                        *dbg_file << Name ( dbg_str ) << Smp_ex;
266
267                }
268                //Importace weighting
269                for ( int i = 0; i < pdfs.length(); i++ ) {
270                        lw_src = 0.0;
271                        //======== Same RVs ===========
272                        //Split according to dependency in rvs
273                        if ( pdfs ( i )->dimension() == dim ) {
274                                // no need for conditioning or marginalization
275                                lw_src = pdfs ( i )->evallogcond_mat ( Smp , vec ( 0 ) );
276                        } else {
277                                // compute likelihood of marginal on the conditional variable
278                                if ( pdfs ( i )->dimensionc() > 0 ) {
279                                        // Make marginal on rvc_i
280                                        shared_ptr<epdf> tmp_marg = Mpred->marginal ( pdfs ( i )->_rvc() );
281                                        //compute vector of lw_src
282                                        for ( int k = 0; k < Npoints; k++ ) {
283                                                // Here val of tmp_marg = cond of pdfs(i) ==> calling dls->get_cond
284                                                lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
285                                        }
286
287//                                      sprintf ( str,"marg%d",niter );
288//                                      *dbg << Name ( str ) << lw_src;
289
290                                }
291                                // Compute likelihood of the missing variable
292                                if ( dim > ( pdfs ( i )->dimension() + pdfs ( i )->dimensionc() ) ) {
293                                        ///////////////
294                                        // There are variales unknown to pdfs(i) : rvzs
295                                        shared_ptr<pdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
296                                        // Compute likelihood
297                                        vec lw_dbg = lw_src;
298                                        for ( int k = 0; k < Npoints; k++ ) {
299                                                lw_src ( k ) += log (
300                                                                    tmp_cond->evallogcond (
301                                                                        zdls ( i )->pushdown ( Smp ( k ) ),
302                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
303                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
304                                                        lw_src ( k ) = -1e16;
305                                                        cout << "!";
306                                                }
307                                        }
308                                }
309                                // Compute likelihood of the partial source
310                                for ( int k = 0; k < Npoints; k++ ) {
311                                        lw_src ( k ) += pdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ),
312                                                        dls ( i )->get_cond ( Smp ( k ) ) );
313                                }
314
315                        }
316
317                        lW.set_row ( i, lw_src ); // do not divide by mix
318                }
319                lw = merger_base::merge_points ( lW ); //merge
320
321                //Importance weighting
322                lw -=  lw_mix; // hoping that it is not numerically sensitive...
323                w = exp ( lw - max ( lw ) );
324
325                //renormalize
326                double sumw = sum ( w );
327                if ( std::isfinite ( sumw ) ) {
328                        w = w / sumw;
329                } else {
330                        it_file itf ( "merg_err.it" );
331                        itf << Name ( "w" ) << w;
332                }
333
334                if ( DBG ) {
335                        sprintf ( dbg_str, "lW%d", niter );
336                        *dbg_file << Name ( dbg_str ) << lW;
337                        sprintf ( dbg_str, "w%d", niter );
338                        *dbg_file << Name ( dbg_str ) << w;
339                        sprintf ( dbg_str, "lw_m%d", niter );
340                        *dbg_file << Name ( dbg_str ) << lw_mix;
341                }
342                // ==== stopping rule ===
343                niter++;
344                converged = ( niter > stop_niter );
345        }
346        delete Mpred;
347//              cout << endl;
348
349}
350
351void merger_mix::from_setting ( const Setting& set ) {
352        merger_base::from_setting ( set );
353        Ncoms=DFLT_Ncoms;
354        UI::get( Ncoms, set, "ncoms", UI::optional );
355        effss_coef=DFLT_effss_coef;
356        UI::get (effss_coef , set,  "effss_coef", UI::optional);
357        stop_niter=10;
358        UI::get ( stop_niter, set,"stop_niter", UI::optional );         
359}
360
361void    merger_mix::to_setting  (Setting  &set) const  {
362        merger_base::to_setting(set);
363        UI::save( Ncoms, set, "ncoms");
364        UI::save (effss_coef , set,  "effss_coef");
365        UI::save ( stop_niter, set,"stop_niter");
366}
367
368void merger_mix::validate() {
369        merger_base::validate();
370        bdm_assert(Ncoms>0,"Ncoms too small");
371}
372
373void merger_base::from_setting ( const Setting& set ) {
374        // get support
375        // find which method to use
376        epdf::from_setting (set);
377        string meth_str;
378        UI::get( meth_str, set, "method", UI::compulsory );
379        if ( meth_str == "arithmetic" ) 
380                set_method ( ARITHMETIC );
381        else if ( meth_str == "geometric" )             
382                set_method ( GEOMETRIC );
383        else if ( meth_str ==  "lognormal" ) { 
384                set_method ( LOGNORMAL );
385                UI::get(beta, set, "beta", UI::compulsory ); 
386        }
387       
388
389        string dbg_filename;
390        if ( UI::get ( dbg_filename, set, "dbg_file" ) )
391                set_debug_file( dbg_filename );
392
393}
394
395void merger_base::to_setting  (Setting  &set) const {
396        epdf::to_setting(set);
397               
398        UI::save( METHOD, set, "method");
399
400        if( METHOD == LOGNORMAL )
401                UI::save (beta, set, "beta" );
402
403        if( DBG ) 
404                UI::save ( dbg_file->get_fname(), set, "dbg_file" );
405}
406
407void merger_base::validate() {
408//              bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
409//              bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
410        epdf::validate();
411        bdm_assert ( isnamed(), "mergers must be named" );
412}
413
414// DEFAULTS FOR MERGER_BASE
415const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL;
416const double merger_base::DFLT_beta = 1.2;
417// DEFAULTS FOR MERGER_MIX
418const int merger_mix::DFLT_Ncoms = 10;
419const double merger_mix::DFLT_effss_coef = 0.9;
420
421}
Note: See TracBrowser for help on using the browser.