root/library/bdm/stat/merger.cpp @ 679

Revision 679, 6.9 kB (checked in by smidl, 15 years ago)

Major changes in BM -- OK is only test suite and tests/tutorial -- the rest is broken!!!

  • Property svn:eol-style set to native
RevLine 
[262]1
[176]2#include "merger.h"
[384]3#include "../estim/arx.h"
[176]4
[477]5namespace bdm {
[423]6
[507]7merger_base::merger_base ( const Array<shared_ptr<mpdf> > &S ):
8        Npoints(0), DBG(false), dbg_file(0) {
9        set_sources ( S );
[477]10}
[423]11
[477]12vec merger_base::merge_points ( mat &lW ) {
13        int nu = lW.rows();
14        vec result;
15        ivec indW;
[561]16        bool infexist = false;
[477]17        switch ( METHOD ) {
18        case ARITHMETIC:
19                result = log ( sum ( exp ( lW ) ) ); //ugly!
20                break;
21        case GEOMETRIC:
22                result = sum ( lW ) / nu;
23                break;
24        case LOGNORMAL:
25                vec sumlW = sum ( lW ) ;
26                indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
27                infexist = ( indW.size() < lW.cols() );
28                vec mu;
29                vec lam;
30                if ( infexist ) {
31                        mu = sumlW ( indW ) / nu; //mean of logs
32                        //
33                        mat validlW = lW.get_cols ( indW );
34                        lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
35                } else {
36                        mu = sum ( lW ) / nu; //mean of logs
37                        lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
[299]38                }
[477]39                //
40                double coef = 0.0;
41                vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
42                switch ( nu ) {
43                case 2:
44                        coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
45                        result = coef * sq2bl + mu ;
46                        break;
47                        // case 4: == can be done similar to case 2 - is it worth it???
48                default: // see accompanying document merge_lognorm_derivation.lyx
49                        coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
50                        result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
51                        break;
[404]52                }
[477]53                break;
[176]54        }
[477]55        if ( infexist ) {
56                vec tmp = -inf * ones ( lW.cols() );
57                set_subvector ( tmp, indW, result );
58                return tmp;
59        } else {
60                return result;
61        }
62}
[176]63
[477]64void merger_mix::merge ( ) {
65        Array<vec> &Smp = eSmp._samples(); //aux
66        vec &w = eSmp._w(); //aux
[180]67
[477]68        mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones
69        for ( int i = 0; i < Npoints; i++ ) {
70                set_col_part ( Smp_ex, i, Smp ( i ) );
71        }
[180]72
[477]73        if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
[180]74
[477]75        // Stuff for merging
76        vec lw_src ( Npoints );         // weights of the ith source
77        vec lw_mix ( Npoints );         // weights of the approximating mixture
78        vec lw ( Npoints );                     // tmp
79        mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources
80        vec vec0 ( 0 );
[176]81
[477]82        //initialize importance weights
83        lw_mix = 1.0; // assuming uniform grid density -- otherwise
[300]84
[477]85        // Initial component in the mixture model
86        mat V0 = 1e-8 * eye ( dim + 1 );
87        ARX A0;
88        A0.set_statistics ( dim, V0 ); //initial guess of Mix:
[176]89
[477]90        Mix.init ( &A0, Smp_ex, Ncoms );
91        //Preserve initial mixture for repetitive estimation via flattening
92        MixEF Mix_init ( Mix );
[197]93
[477]94        // ============= MAIN LOOP ==================
95        bool converged = false;
96        int niter = 0;
97        char dbg_str[100];
[182]98
[477]99        emix* Mpred = Mix.epredictor ( );
100        vec Mix_pdf ( Npoints );
101        while ( !converged ) {
102                //Re-estimate Mix
103                //Re-Initialize Mixture model
104                Mix.flatten ( &Mix_init );
[679]105                Mix.bayesB ( Smp_ex,empty_vec, w*Npoints );
[477]106                delete Mpred;
107                Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
108                Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
[213]109
[477]110                // This will be active only later in iterations!!!
111                if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
112                        // Generate new samples
113                        eSmp.set_samples ( Mpred );
114                        for ( int i = 0; i < Npoints; i++ ) {
115                                //////////// !!!!!!!!!!!!!
116                                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
117                                set_col_part ( Smp_ex, i, Smp ( i ) );
118                                //Importance of the mixture
119                                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
120                                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
[204]121                        }
[477]122                        if ( DBG ) {
123                                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
[536]124                                cout << Mix.posterior().mean() << endl;
[477]125                                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
126                                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
127                        }
128                }
129                if ( DBG ) {
130                        sprintf ( dbg_str, "Mpred_mean%d", niter );
131                        *dbg_file << Name ( dbg_str ) << Mpred->mean();
132                        sprintf ( dbg_str, "Mpred_var%d", niter );
133                        *dbg_file << Name ( dbg_str ) << Mpred->variance();
[197]134
[205]135
[477]136                        sprintf ( dbg_str, "Mpdf%d", niter );
137                        for ( int i = 0; i < Npoints; i++ ) {
138                                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );
139                        }
140                        *dbg_file << Name ( dbg_str ) << Mix_pdf;
[180]141
[477]142                        sprintf ( dbg_str, "Smp%d", niter );
143                        *dbg_file << Name ( dbg_str ) << Smp_ex;
[180]144
[477]145                }
146                //Importace weighting
147                for ( int i = 0; i < mpdfs.length(); i++ ) {
148                        lw_src = 0.0;
149                        //======== Same RVs ===========
150                        //Split according to dependency in rvs
151                        if ( mpdfs ( i )->dimension() == dim ) {
152                                // no need for conditioning or marginalization
[487]153                                lw_src = mpdfs ( i )->evallogcond_m ( Smp , vec(0));
[477]154                        } else {
155                                // compute likelihood of marginal on the conditional variable
156                                if ( mpdfs ( i )->dimensionc() > 0 ) {
157                                        // Make marginal on rvc_i
[504]158                                        shared_ptr<epdf> tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
[477]159                                        //compute vector of lw_src
160                                        for ( int k = 0; k < Npoints; k++ ) {
161                                                // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
162                                                lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
163                                        }
[198]164
165//                                      sprintf ( str,"marg%d",niter );
[299]166//                                      *dbg << Name ( str ) << lw_src;
[198]167
[477]168                                }
169                                // Compute likelihood of the missing variable
170                                if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) ) {
171                                        ///////////////
172                                        // There are variales unknown to mpdfs(i) : rvzs
[504]173                                        shared_ptr<mpdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
[477]174                                        // Compute likelihood
175                                        vec lw_dbg = lw_src;
176                                        for ( int k = 0; k < Npoints; k++ ) {
177                                                lw_src ( k ) += log (
178                                                                    tmp_cond->evallogcond (
179                                                                        zdls ( i )->pushdown ( Smp ( k ) ),
180                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
181                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
182                                                        lw_src ( k ) = -1e16;
183                                                        cout << "!";
[204]184                                                }
[182]185                                        }
186                                }
[477]187                                // Compute likelihood of the partial source
188                                for ( int k = 0; k < Npoints; k++ ) {
[487]189                                        lw_src ( k ) += mpdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ), 
190                                                         dls ( i )->get_cond ( Smp ( k ) ));
[477]191                                }
192
[299]193                        }
[565]194
[477]195                        lW.set_row ( i, lw_src ); // do not divide by mix
196                }
197                lw = merger_base::merge_points ( lW ); //merge
[197]198
[477]199                //Importance weighting
200                lw -=  lw_mix; // hoping that it is not numerically sensitive...
201                w = exp ( lw - max ( lw ) );
[300]202
[477]203                //renormalize
204                double sumw = sum ( w );
205                if ( std::isfinite ( sumw ) ) {
206                        w = w / sumw;
207                } else {
208                        it_file itf ( "merg_err.it" );
209                        itf << Name ( "w" ) << w;
210                }
[180]211
[477]212                if ( DBG ) {
213                        sprintf ( dbg_str, "lW%d", niter );
214                        *dbg_file << Name ( dbg_str ) << lW;
215                        sprintf ( dbg_str, "w%d", niter );
216                        *dbg_file << Name ( dbg_str ) << w;
217                        sprintf ( dbg_str, "lw_m%d", niter );
218                        *dbg_file << Name ( dbg_str ) << lw_mix;
[204]219                }
[477]220                // ==== stopping rule ===
221                niter++;
222                converged = ( niter > stop_niter );
223        }
224        delete Mpred;
[299]225//              cout << endl;
[205]226
[477]227}
[176]228
[477]229// DEFAULTS FOR MERGER_BASE
230const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL;
231const double merger_base::DFLT_beta = 1.2;
232// DEFAULTS FOR MERGER_MIX
233const int merger_mix::DFLT_Ncoms = 10;
234const double merger_mix::DFLT_effss_coef = 0.5;
[399]235
[176]236}
Note: See TracBrowser for help on using the browser.