root/library/bdm/stat/merger.cpp @ 504

Revision 504, 7.0 kB (checked in by vbarta, 15 years ago)

returning shared pointers from epdf::marginal & epdf::condition; testsuite run leaks down from 8402 to 6510 bytes

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm {
6
7merger_base::merger_base ( const Array<mpdf*> &S, bool own ) {
8        DBG = false;
9        dbg_file = NULL;
10        set_sources ( S, own );
11}
12
13vec merger_base::merge_points ( mat &lW ) {
14        int nu = lW.rows();
15        vec result;
16        ivec indW;
17        bool infexist;
18        switch ( METHOD ) {
19        case ARITHMETIC:
20                result = log ( sum ( exp ( lW ) ) ); //ugly!
21                break;
22        case GEOMETRIC:
23                result = sum ( lW ) / nu;
24                break;
25        case LOGNORMAL:
26                vec sumlW = sum ( lW ) ;
27                indW = find ( ( sumlW < inf ) & ( sumlW > -inf ) );
28                infexist = ( indW.size() < lW.cols() );
29                vec mu;
30                vec lam;
31                if ( infexist ) {
32                        mu = sumlW ( indW ) / nu; //mean of logs
33                        //
34                        mat validlW = lW.get_cols ( indW );
35                        lam = sum ( pow ( validlW - outer_product ( ones ( validlW.rows() ), mu ), 2 ) );
36                } else {
37                        mu = sum ( lW ) / nu; //mean of logs
38                        lam = sum ( pow ( lW - outer_product ( ones ( lW.rows() ), mu ), 2 ) );
39                }
40                //
41                double coef = 0.0;
42                vec sq2bl = sqrt ( 2 * beta * lam ); //this term is everywhere
43                switch ( nu ) {
44                case 2:
45                        coef = ( 1 - 0.5 * sqrt ( ( 4.0 * beta - 3.0 ) / beta ) );
46                        result = coef * sq2bl + mu ;
47                        break;
48                        // case 4: == can be done similar to case 2 - is it worth it???
49                default: // see accompanying document merge_lognorm_derivation.lyx
50                        coef = sqrt ( 1 - ( nu + 1 ) / ( 2 * beta * nu ) );
51                        result = log ( besselk ( ( nu - 3 ) / 2, sq2bl * coef ) ) - log ( besselk ( ( nu - 3 ) / 2, sq2bl ) ) + mu;
52                        break;
53                }
54                break;
55        }
56        if ( infexist ) {
57                vec tmp = -inf * ones ( lW.cols() );
58                set_subvector ( tmp, indW, result );
59                return tmp;
60        } else {
61                return result;
62        }
63}
64
65void merger_mix::merge ( ) {
66        Array<vec> &Smp = eSmp._samples(); //aux
67        vec &w = eSmp._w(); //aux
68
69        mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones
70        for ( int i = 0; i < Npoints; i++ ) {
71                set_col_part ( Smp_ex, i, Smp ( i ) );
72        }
73
74        if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
75
76        // Stuff for merging
77        vec lw_src ( Npoints );         // weights of the ith source
78        vec lw_mix ( Npoints );         // weights of the approximating mixture
79        vec lw ( Npoints );                     // tmp
80        mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources
81        vec vec0 ( 0 );
82
83        //initialize importance weights
84        lw_mix = 1.0; // assuming uniform grid density -- otherwise
85
86        // Initial component in the mixture model
87        mat V0 = 1e-8 * eye ( dim + 1 );
88        ARX A0;
89        A0.set_statistics ( dim, V0 ); //initial guess of Mix:
90
91        Mix.init ( &A0, Smp_ex, Ncoms );
92        //Preserve initial mixture for repetitive estimation via flattening
93        MixEF Mix_init ( Mix );
94
95        // ============= MAIN LOOP ==================
96        bool converged = false;
97        int niter = 0;
98        char dbg_str[100];
99
100        emix* Mpred = Mix.epredictor ( );
101        vec Mix_pdf ( Npoints );
102        while ( !converged ) {
103                //Re-estimate Mix
104                //Re-Initialize Mixture model
105                Mix.flatten ( &Mix_init );
106                Mix.bayesB ( Smp_ex, w*Npoints );
107                delete Mpred;
108                Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
109                Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
110
111                // This will be active only later in iterations!!!
112                if ( 1. / sum_sqr ( w ) < effss_coef*Npoints ) {
113                        // Generate new samples
114                        eSmp.set_samples ( Mpred );
115                        for ( int i = 0; i < Npoints; i++ ) {
116                                //////////// !!!!!!!!!!!!!
117                                //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
118                                set_col_part ( Smp_ex, i, Smp ( i ) );
119                                //Importance of the mixture
120                                //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
121                                lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
122                        }
123                        if ( DBG ) {
124                                cout << "Resampling =" << 1. / sum_sqr ( w ) << endl;
125                                cout << Mix._e()->mean() << endl;
126                                cout << sum ( Smp_ex, 2 ) / Npoints << endl;
127                                cout << Smp_ex*Smp_ex.T() / Npoints << endl;
128                        }
129                }
130                if ( DBG ) {
131                        sprintf ( dbg_str, "Mpred_mean%d", niter );
132                        *dbg_file << Name ( dbg_str ) << Mpred->mean();
133                        sprintf ( dbg_str, "Mpred_var%d", niter );
134                        *dbg_file << Name ( dbg_str ) << Mpred->variance();
135
136
137                        sprintf ( dbg_str, "Mpdf%d", niter );
138                        for ( int i = 0; i < Npoints; i++ ) {
139                                Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );
140                        }
141                        *dbg_file << Name ( dbg_str ) << Mix_pdf;
142
143                        sprintf ( dbg_str, "Smp%d", niter );
144                        *dbg_file << Name ( dbg_str ) << Smp_ex;
145
146                }
147                //Importace weighting
148                for ( int i = 0; i < mpdfs.length(); i++ ) {
149                        lw_src = 0.0;
150                        //======== Same RVs ===========
151                        //Split according to dependency in rvs
152                        if ( mpdfs ( i )->dimension() == dim ) {
153                                // no need for conditioning or marginalization
154                                lw_src = mpdfs ( i )->evallogcond_m ( Smp , vec(0));
155                        } else {
156                                // compute likelihood of marginal on the conditional variable
157                                if ( mpdfs ( i )->dimensionc() > 0 ) {
158                                        // Make marginal on rvc_i
159                                        shared_ptr<epdf> tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
160                                        //compute vector of lw_src
161                                        for ( int k = 0; k < Npoints; k++ ) {
162                                                // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
163                                                lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
164                                        }
165
166//                                      sprintf ( str,"marg%d",niter );
167//                                      *dbg << Name ( str ) << lw_src;
168
169                                }
170                                // Compute likelihood of the missing variable
171                                if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) ) {
172                                        ///////////////
173                                        // There are variales unknown to mpdfs(i) : rvzs
174                                        shared_ptr<mpdf> tmp_cond = Mpred->condition ( rvzs ( i ) );
175                                        // Compute likelihood
176                                        vec lw_dbg = lw_src;
177                                        for ( int k = 0; k < Npoints; k++ ) {
178                                                lw_src ( k ) += log (
179                                                                    tmp_cond->evallogcond (
180                                                                        zdls ( i )->pushdown ( Smp ( k ) ),
181                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
182                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
183                                                        lw_src ( k ) = -1e16;
184                                                        cout << "!";
185                                                }
186                                        }
187                                }
188                                // Compute likelihood of the partial source
189                                for ( int k = 0; k < Npoints; k++ ) {
190                                        lw_src ( k ) += mpdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ), 
191                                                         dls ( i )->get_cond ( Smp ( k ) ));
192                                }
193
194                        }
195                        //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
196                        lW.set_row ( i, lw_src ); // do not divide by mix
197                }
198                lw = merger_base::merge_points ( lW ); //merge
199
200                //Importance weighting
201                lw -=  lw_mix; // hoping that it is not numerically sensitive...
202                w = exp ( lw - max ( lw ) );
203
204                //renormalize
205                double sumw = sum ( w );
206                if ( std::isfinite ( sumw ) ) {
207                        w = w / sumw;
208                } else {
209                        it_file itf ( "merg_err.it" );
210                        itf << Name ( "w" ) << w;
211                }
212
213                if ( DBG ) {
214                        sprintf ( dbg_str, "lW%d", niter );
215                        *dbg_file << Name ( dbg_str ) << lW;
216                        sprintf ( dbg_str, "w%d", niter );
217                        *dbg_file << Name ( dbg_str ) << w;
218                        sprintf ( dbg_str, "lw_m%d", niter );
219                        *dbg_file << Name ( dbg_str ) << lw_mix;
220                }
221                // ==== stopping rule ===
222                niter++;
223                converged = ( niter > stop_niter );
224        }
225        delete Mpred;
226//              cout << endl;
227
228}
229
230// DEFAULTS FOR MERGER_BASE
231const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL;
232const double merger_base::DFLT_beta = 1.2;
233// DEFAULTS FOR MERGER_MIX
234const int merger_mix::DFLT_Ncoms = 10;
235const double merger_mix::DFLT_effss_coef = 0.5;
236
237}
Note: See TracBrowser for help on using the browser.