Changeset 379 for bdm/estim/merger.cpp

Show
Ignore:
Timestamp:
06/17/09 23:53:11 (15 years ago)
Author:
smidl
Message:

merger restructured

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • bdm/estim/merger.cpp

    r311 r379  
    55namespace bdm 
    66{ 
    7         vec merger::lognorm_merge ( mat &lW ) 
     7        vec merger_base::merge_points ( mat &lW ) 
    88        { 
    99                int nu=lW.rows(); 
    10                 vec mu = sum ( lW ) /nu; //mean of logs 
    11                 // vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe! 
    12                 vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) ); 
    13                 double coef=0.0; 
    14                 vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere 
    15                 switch ( nu ) 
     10                switch ( METHOD ) 
    1611                { 
    17                         case 2: 
    18                                 coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) ); 
    19                                 return  coef*sq2bl + mu ; 
     12                        case ARITHMETIC: 
     13                                return log ( sum ( exp ( lW ) ) ); //ugly! 
    2014                                break; 
    21                         case 3://Ratio of Bessel 
    22                                 coef = sqrt ( ( 3*beta-2 ) /3*beta ); 
    23                                 return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu; 
     15                        case GEOMETRIC: 
     16                                return sum ( lW ) /nu; 
    2417                                break; 
    25                         case 4: 
    26                                 break; 
    27                         default: // Approximate conditional density 
     18                        case LOGNORMAL: 
     19                                vec mu = sum ( lW ) /nu; //mean of logs 
     20                                // vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe! 
     21                                vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) ); 
     22                                double coef=0.0; 
     23                                vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere 
     24                                switch ( nu ) 
     25                                { 
     26                                        case 2: 
     27                                                coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) ); 
     28                                                return  coef*sq2bl + mu ; 
     29                                                break; 
     30                                        case 3://Ratio of Bessel 
     31                                                coef = sqrt ( ( 3*beta-2 ) /3*beta ); 
     32                                                return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu; 
     33                                                break; 
     34                                        case 4: 
     35                                                break; 
     36                                        default: // Approximate conditional density 
     37                                                break; 
     38                                } 
    2839                                break; 
    2940                } 
     
    3142        } 
    3243 
    33         void merger::merge ( const epdf* g0 ) 
     44        void merger_mix::merge ( ) 
    3445        { 
    35  
    36                 it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" ); 
    37                 //Empirical density - samples 
    38                 if ( !fix_smp ) 
    39                 { 
    40                         eSmp.set_statistics ( ones ( Ns ), g0 ); 
    41                 } 
    42  
    4346                Array<vec> &Smp = eSmp._samples(); //aux 
    4447                vec &w = eSmp._w(); //aux 
    4548 
    46                 mat Smp_ex =ones ( dim +1,Ns ); // Extended samples for the ARX model - the last row is ones 
    47                 for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );} 
    48  
    49                 if ( DBG )      *dbg << Name ( "Smp_0" ) << Smp_ex; 
     49                mat Smp_ex =ones ( dim +1,Npoints ); // Extended samples for the ARX model - the last row is ones 
     50                for ( int i=0;i<Npoints;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );} 
     51 
     52                if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex; 
    5053 
    5154                // Stuff for merging 
    52                 vec lw_src ( Ns ); 
    53                 vec lw_mix ( Ns ); 
    54                 vec lw ( Ns ); 
    55                 mat lW=zeros ( n,Ns ); 
     55                vec lw_src ( Npoints );         // weights of the ith source 
     56                vec lw_mix ( Npoints );         // weights of the approximating mixture 
     57                vec lw ( Npoints );                     // tmp 
     58                mat lW=zeros ( Nsources,Npoints ); // array of weights of all sources 
    5659                vec vec0 ( 0 ); 
    5760 
    5861                //initialize importance weights 
    59                 if ( !fix_smp ) 
    60                         for ( int i=0;i<Ns;i++ ) 
    61                         { 
    62                                 lw_mix ( i ) =g0->evallog ( Smp ( i ) ); 
    63                         } 
     62                lw_mix = 1.0; // assuming uniform grid density -- otherwise  
    6463 
    6564                // Initial component in the mixture model 
     
    6766                ARX A0; A0.set_statistics ( dim, V0 ); //initial guess of Mix: 
    6867 
    69                 Mix.init ( &A0, Smp_ex, Nc ); 
     68                Mix.init ( &A0, Smp_ex, Ncoms ); 
    7069                //Preserve initial mixture for repetitive estimation via flattening 
    7170                MixEF Mix_init ( Mix ); 
     
    7473                bool converged=false; 
    7574                int niter = 0; 
    76                 char str[100]; 
     75                char dbg_str[100]; 
    7776 
    7877                emix* Mpred=Mix.epredictor ( ); 
    79                 vec Mix_pdf ( Ns ); 
     78                vec Mix_pdf ( Npoints ); 
    8079                while ( !converged ) 
    8180                { 
     
    8382                        //Re-Initialize Mixture model 
    8483                        Mix.flatten ( &Mix_init ); 
    85                         Mix.bayesB ( Smp_ex, w*Ns ); 
     84                        Mix.bayesB ( Smp_ex, w*Npoints ); 
    8685                        delete Mpred; 
    8786                        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!! 
     
    8988 
    9089                        // This will be active only later in iterations!!! 
    91                         if ( ( !fix_smp ) & ( 1./sum_sqr ( w ) <effss_coef*Ns ) ) 
     90                        if (  1./sum_sqr ( w ) <effss_coef*Npoints )  
    9291                        { 
    9392                                // Generate new samples 
    9493                                eSmp.set_samples ( Mpred ); 
    95                                 for ( int i=0;i<Ns;i++ ) 
     94                                for ( int i=0;i<Npoints;i++ ) 
    9695                                { 
    9796                                        //////////// !!!!!!!!!!!!! 
    98                                         if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; } 
     97                                        //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; } 
    9998                                        set_col_part ( Smp_ex,i,Smp ( i ) ); 
    10099                                        //Importance of the mixture 
     
    102101                                        lw_mix ( i ) = Mpred->evallog ( Smp ( i ) ); 
    103102                                } 
    104                                 if ( 0 ) 
     103                                if ( DBG ) 
    105104                                { 
    106105                                        cout<<"Resampling =" << 1./sum_sqr ( w ) << endl; 
    107106                                        cout << Mix._e()->mean() <<endl; 
    108                                         cout << sum ( Smp_ex,2 ) /Ns <<endl; 
    109                                         cout << Smp_ex*Smp_ex.T() /Ns << endl; 
     107                                        cout << sum ( Smp_ex,2 ) /Npoints <<endl; 
     108                                        cout << Smp_ex*Smp_ex.T() /Npoints << endl; 
    110109                                } 
    111110                        } 
    112111                        if ( DBG ) 
    113112                        { 
    114                                 sprintf ( str,"Mpred_mean%d",niter ); 
    115                                 *dbg << Name ( str ) << Mpred->mean(); 
    116                                 sprintf ( str,"Mpred_var%d",niter ); 
    117                                 *dbg << Name ( str ) << Mpred->variance(); 
    118  
    119  
    120                                 sprintf ( str,"Mpdf%d",niter ); 
    121                                 for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );} 
    122                                 *dbg << Name ( str ) << Mix_pdf; 
    123  
    124                                 sprintf ( str,"Smp%d",niter ); 
    125                                 *dbg << Name ( str ) << Smp_ex; 
     113                                sprintf ( dbg_str,"Mpred_mean%d",niter ); 
     114                                *dbg_file << Name ( dbg_str ) << Mpred->mean(); 
     115                                sprintf ( dbg_str,"Mpred_var%d",niter ); 
     116                                *dbg_file << Name ( dbg_str ) << Mpred->variance(); 
     117 
     118 
     119                                sprintf ( dbg_str,"Mpdf%d",niter ); 
     120                                for ( int i=0;i<Npoints;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );} 
     121                                *dbg_file << Name ( dbg_str ) << Mix_pdf; 
     122 
     123                                sprintf ( dbg_str,"Smp%d",niter ); 
     124                                *dbg_file << Name ( dbg_str ) << Smp_ex; 
    126125 
    127126                        } 
    128127                        //Importace weighting 
    129                         for ( int i=0;i<n;i++ ) 
     128                        for ( int i=0;i<mpdfs.length();i++ ) 
    130129                        { 
    131130                                lw_src=0.0; 
     
    135134                                { 
    136135                                        // no need for conditioning or marginalization 
    137                                         for ( int j=0;j<Ns; j++ )   // Smp is Array<> => for cycle 
     136                                        for ( int j=0;j<Npoints; j++ )   // Smp is Array<> => for cycle 
    138137                                        { 
    139138                                                lw_src ( j ) =mpdfs ( i )->_epdf().evallog ( Smp ( j ) ); 
     
    148147                                                epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() ); 
    149148                                                //compute vector of lw_src 
    150                                                 for ( int k=0;k<Ns;k++ ) 
     149                                                for ( int k=0;k<Npoints;k++ ) 
    151150                                                { 
    152151                                                        // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond 
     
    167166                                                // Compute likelihood 
    168167                                                vec lw_dbg=lw_src; 
    169                                                 for ( int k= 0; k<Ns; k++ ) 
     168                                                for ( int k= 0; k<Npoints; k++ ) 
    170169                                                { 
    171170                                                        lw_src ( k ) += log ( 
     
    181180                                        } 
    182181                                        // Compute likelihood of the partial source 
    183                                         for ( int k= 0; k<Ns; k++ ) 
     182                                        for ( int k= 0; k<Npoints; k++ ) 
    184183                                        { 
    185184                                                mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) ); 
     
    188187 
    189188                                } 
    190 //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad"); 
     189                //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad"); 
    191190                                lW.set_row ( i, lw_src ); // do not divide by mix 
    192191                        } 
    193                         lw = lognorm_merge ( lW ); //merge 
     192                        lw = merger_base::merge_points ( lW ); //merge 
    194193 
    195194                        //Importance weighting 
    196                         if ( !fix_smp ) 
    197                                 lw -=  lw_mix; // hoping that it is not numerically sensitive... 
     195                        lw -=  lw_mix; // hoping that it is not numerically sensitive... 
    198196                        w = exp ( lw-max ( lw ) ); 
    199197 
     
    212210                        if ( DBG ) 
    213211                        { 
    214                                 sprintf ( str,"lW%d",niter ); 
    215                                 *dbg << Name ( str ) << lW; 
    216                                 sprintf ( str,"w%d",niter ); 
    217                                 *dbg << Name ( str ) << w; 
    218                                 sprintf ( str,"lw_m%d",niter ); 
    219                                 *dbg << Name ( str ) << lw_mix; 
     212                                sprintf ( dbg_str,"lW%d",niter ); 
     213                                *dbg_file << Name ( dbg_str ) << lW; 
     214                                sprintf ( dbg_str,"w%d",niter ); 
     215                                *dbg_file << Name ( dbg_str ) << w; 
     216                                sprintf ( dbg_str,"lw_m%d",niter ); 
     217                                *dbg_file << Name ( dbg_str ) << lw_mix; 
    220218                        } 
    221219                        // ==== stopping rule ===