root/library/bdm/stat/merger.cpp @ 404

Revision 404, 7.0 kB (checked in by smidl, 15 years ago)

Change in epdf: evallog returns -inf for points out of support. Merger is aware of it now.

  • Property svn:eol-style set to native
RevLine 
[262]1
[176]2#include "merger.h"
[384]3#include "../estim/arx.h"
[176]4
[299]5namespace bdm
6{
[404]7        vec merger_base::merge_points ( mat &lW ) {
[299]8                int nu=lW.rows();
[404]9                vec result;
10                ivec indW;
11                bool infexist;
12                switch ( METHOD ) {
[384]13                        case ARITHMETIC:
[404]14                                result= log ( sum ( exp ( lW ) ) ); //ugly!
[299]15                                break;
[384]16                        case GEOMETRIC:
[404]17                                result= sum ( lW ) /nu;
[299]18                                break;
[384]19                        case LOGNORMAL:
[404]20                                vec sumlW=sum ( lW ) ;
21                                indW=find((sumlW<inf) & (sumlW>-inf));
22                                infexist=(indW.size()<lW.cols());
23                                vec mu;
24                                vec lam;
25                                if (infexist){
26                                        mu = sumlW(indW) /nu; //mean of logs
27                                        //
28                                        mat validlW=lW.get_cols(indW);
29                                        lam = sum ( pow ( validlW-outer_product ( ones ( validlW.rows() ),mu ),2 ) );
30                                }
31                                else {
32                                        mu = sum ( lW ) /nu; //mean of logs
33                                        lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
34                                }
35                                //
[384]36                                double coef=0.0;
37                                vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
[404]38                                switch ( nu ) {
[384]39                                        case 2:
40                                                coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
[404]41                                                result =coef*sq2bl + mu ;
[384]42                                                break;
[404]43                                        // case 4: == can be done similar to case 2 - is it worth it???
44                                        default: // see accompanying document merge_lognorm_derivation.lyx
45                                                coef = sqrt(1-(nu+1)/(2*beta*nu));
46                                                result =log(besselk((nu-3)/2, sq2bl*coef))-log(besselk((nu-3)/2, sq2bl)) + mu;
[384]47                                                break;
48                                }
[299]49                                break;
50                }
[404]51                if (infexist){
52                        vec tmp =-inf*ones(lW.cols());
53                        set_subvector(tmp, indW, result);
54                        return tmp;
55                }
56                else {return result;}
[176]57        }
58
[384]59        void merger_mix::merge ( )
[299]60        {
61                Array<vec> &Smp = eSmp._samples(); //aux
62                vec &w = eSmp._w(); //aux
[180]63
[384]64                mat Smp_ex =ones ( dim +1,Npoints ); // Extended samples for the ARX model - the last row is ones
65                for ( int i=0;i<Npoints;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );}
[180]66
[384]67                if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
[180]68
[299]69                // Stuff for merging
[384]70                vec lw_src ( Npoints );         // weights of the ith source
71                vec lw_mix ( Npoints );         // weights of the approximating mixture
72                vec lw ( Npoints );                     // tmp
73                mat lW=zeros ( Nsources,Npoints ); // array of weights of all sources
[299]74                vec vec0 ( 0 );
[176]75
[300]76                //initialize importance weights
[384]77                lw_mix = 1.0; // assuming uniform grid density -- otherwise
[300]78
[299]79                // Initial component in the mixture model
80                mat V0=1e-8*eye ( dim +1 );
81                ARX A0; A0.set_statistics ( dim, V0 ); //initial guess of Mix:
[176]82
[384]83                Mix.init ( &A0, Smp_ex, Ncoms );
[299]84                //Preserve initial mixture for repetitive estimation via flattening
85                MixEF Mix_init ( Mix );
[197]86
[299]87                // ============= MAIN LOOP ==================
88                bool converged=false;
89                int niter = 0;
[384]90                char dbg_str[100];
[182]91
[299]92                emix* Mpred=Mix.epredictor ( );
[384]93                vec Mix_pdf ( Npoints );
[299]94                while ( !converged )
[311]95                {
96                        //Re-estimate Mix
97                        //Re-Initialize Mixture model
98                        Mix.flatten ( &Mix_init );
[384]99                        Mix.bayesB ( Smp_ex, w*Npoints );
[299]100                        delete Mpred;
101                        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
102                        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
[213]103
[299]104                        // This will be active only later in iterations!!!
[384]105                        if (  1./sum_sqr ( w ) <effss_coef*Npoints ) 
[299]106                        {
107                                // Generate new samples
108                                eSmp.set_samples ( Mpred );
[384]109                                for ( int i=0;i<Npoints;i++ )
[299]110                                {
111                                        //////////// !!!!!!!!!!!!!
[384]112                                        //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
[299]113                                        set_col_part ( Smp_ex,i,Smp ( i ) );
[300]114                                        //Importance of the mixture
[310]115                                        //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
[311]116                                        lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
[299]117                                }
[384]118                                if ( DBG )
[299]119                                {
120                                        cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
[310]121                                        cout << Mix._e()->mean() <<endl;
[384]122                                        cout << sum ( Smp_ex,2 ) /Npoints <<endl;
123                                        cout << Smp_ex*Smp_ex.T() /Npoints << endl;
[299]124                                }
[204]125                        }
[299]126                        if ( DBG )
127                        {
[384]128                                sprintf ( dbg_str,"Mpred_mean%d",niter );
129                                *dbg_file << Name ( dbg_str ) << Mpred->mean();
130                                sprintf ( dbg_str,"Mpred_var%d",niter );
131                                *dbg_file << Name ( dbg_str ) << Mpred->variance();
[197]132
[205]133
[384]134                                sprintf ( dbg_str,"Mpdf%d",niter );
135                                for ( int i=0;i<Npoints;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
136                                *dbg_file << Name ( dbg_str ) << Mix_pdf;
[180]137
[384]138                                sprintf ( dbg_str,"Smp%d",niter );
139                                *dbg_file << Name ( dbg_str ) << Smp_ex;
[180]140
[176]141                        }
[299]142                        //Importace weighting
[384]143                        for ( int i=0;i<mpdfs.length();i++ )
[299]144                        {
145                                lw_src=0.0;
146                                //======== Same RVs ===========
147                                //Split according to dependency in rvs
148                                if ( mpdfs ( i )->dimension() ==dim )
149                                {
150                                        // no need for conditioning or marginalization
[395]151                                        lw_src = mpdfs ( i )->_epdf().evallog_m ( Smp  );
[299]152                                }
153                                else
154                                {
155                                        // compute likelihood of marginal on the conditional variable
156                                        if ( mpdfs ( i )->dimensionc() >0 )
157                                        {
158                                                // Make marginal on rvc_i
159                                                epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
160                                                //compute vector of lw_src
[384]161                                                for ( int k=0;k<Npoints;k++ )
[299]162                                                {
163                                                        // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
164                                                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
165                                                }
166                                                delete tmp_marg;
[198]167
168//                                      sprintf ( str,"marg%d",niter );
[299]169//                                      *dbg << Name ( str ) << lw_src;
[198]170
[299]171                                        }
172                                        // Compute likelihood of the missing variable
173                                        if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) )
174                                        {
175                                                ///////////////
176                                                // There are variales unknown to mpdfs(i) : rvzs
177                                                mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
178                                                // Compute likelihood
179                                                vec lw_dbg=lw_src;
[384]180                                                for ( int k= 0; k<Npoints; k++ )
[299]181                                                {
182                                                        lw_src ( k ) += log (
183                                                                            tmp_cond->evallogcond (
184                                                                                zdls ( i )->pushdown ( Smp ( k ) ),
185                                                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
186                                                        if ( !std::isfinite ( lw_src ( k ) ) )
187                                                        {
188                                                                lw_src ( k ) = -1e16; cout << "!";
189                                                        }
[204]190                                                }
[299]191                                                delete tmp_cond;
[182]192                                        }
[299]193                                        // Compute likelihood of the partial source
[384]194                                        for ( int k= 0; k<Npoints; k++ )
[299]195                                        {
196                                                mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
197                                                lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
198                                        }
199
[182]200                                }
[384]201                //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
[299]202                                lW.set_row ( i, lw_src ); // do not divide by mix
203                        }
[384]204                        lw = merger_base::merge_points ( lW ); //merge
[197]205
[299]206                        //Importance weighting
[384]207                        lw -=  lw_mix; // hoping that it is not numerically sensitive...
[299]208                        w = exp ( lw-max ( lw ) );
[300]209
[299]210                        //renormalize
211                        double sumw =sum ( w );
212                        if ( std::isfinite ( sumw ) )
213                        {
214                                w = w/sumw;
215                        }
216                        else
217                        {
218                                it_file itf ( "merg_err.it" );
219                                itf << Name ( "w" ) << w;
220                        }
[180]221
[299]222                        if ( DBG )
223                        {
[384]224                                sprintf ( dbg_str,"lW%d",niter );
225                                *dbg_file << Name ( dbg_str ) << lW;
226                                sprintf ( dbg_str,"w%d",niter );
227                                *dbg_file << Name ( dbg_str ) << w;
228                                sprintf ( dbg_str,"lw_m%d",niter );
229                                *dbg_file << Name ( dbg_str ) << lw_mix;
[299]230                        }
231                        // ==== stopping rule ===
232                        niter++;
[395]233                        converged = ( niter>stop_niter );
[204]234                }
[299]235                delete Mpred;
236//              cout << endl;
[205]237
[176]238        }
239
[399]240        // DEFAULTS FOR MERGER_BASE
241        const MERGER_METHOD merger_base::DFLT_METHOD=LOGNORMAL;
242        const double merger_base::DFLT_beta=1.2;
243        // DEFAULTS FOR MERGER_MIX
244        const int merger_mix::DFLT_Ncoms=10;
245        const double merger_mix::DFLT_effss_coef=0.5;
246
[176]247}
Note: See TracBrowser for help on using the browser.