Show
Ignore:
Timestamp:
06/10/10 21:54:57 (14 years ago)
Author:
smidl
Message:

Changes in merger + change in loading ARX

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • library/bdm/stat/merger.cpp

    r1068 r1079  
    55namespace bdm { 
    66 
    7 merger_base::merger_base ( const Array<shared_ptr<pdf> > &S ) : 
     7MergerDiscrete::MergerDiscrete ( const Array<shared_ptr<pdf> > &S ) : 
    88    Npoints ( 0 ), DBG ( false ), dbg_file ( 0 ) { 
    99    set_sources ( S ); 
     
    1111 
    1212 
    13 void merger_base::set_sources ( const Array<shared_ptr<pdf> > &Sources ) { 
    14     pdfs = Sources; 
    15     Nsources = pdfs.length(); 
     13void MergerDiscrete::set_sources ( const Array<shared_ptr<pdf> > &PartSources ) { 
     14    part_sources = PartSources; 
     15    Nsources = part_sources.length(); 
    1616    //set sizes 
    17     dls.set_size ( Sources.length() ); 
    18     rvzs.set_size ( Sources.length() ); 
    19     zdls.set_size ( Sources.length() ); 
    20  
    21     rv = get_composite_rv ( pdfs, /* checkoverlap = */ false ); 
     17    sources.set_size ( Nsources); 
     18        dls.set_size ( Nsources ); 
     19        rvzs.set_size ( Nsources ); 
     20        zdls.set_size ( Nsources ); 
     21 
     22    RV rv = get_composite_rv ( part_sources, /* checkoverlap = */ false ); 
    2223 
    2324    RV rvc; 
    2425    // Extend rv by rvc! 
    25     for ( int i = 0; i < pdfs.length(); i++ ) { 
    26         RV rvx = pdfs ( i )->_rvc().subt ( rv ); 
     26    for ( int i = 0; i < sources.length(); i++ ) { 
     27        RV rvx = part_sources ( i )->_rvc().subt ( rv ); 
    2728        rvc.add ( rvx ); // add rv to common rvc 
    2829    } 
     
    3031    // join rv and rvc - see descriprion 
    3132    rv.add ( rvc ); 
    32     // get dimension 
    33     dim = rv._dsize(); 
    34  
     33         
    3534    // create links between sources and common rv 
    3635    RV xytmp; 
    37     for ( int i = 0; i < pdfs.length(); i++ ) { 
     36    for ( int i = 0; i < part_sources.length(); i++ ) {          
    3837        //Establich connection between pdfs and merger 
    3938        dls ( i ) = new datalink_m2e; 
    40         dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), rv ); 
     39                dls ( i )->set_connection ( part_sources ( i )->_rv(), part_sources( i )->_rvc(), rv ); 
    4140 
    4241        // find out what is missing in each pdf 
    43         xytmp = pdfs ( i )->_rv(); 
    44         xytmp.add ( pdfs ( i )->_rvc() ); 
     42        xytmp = part_sources ( i )->_rv(); 
     43                xytmp.add ( part_sources ( i )->_rvc() ); 
    4544        // z_i = common_rv-xy 
    4645        rvzs ( i ) = rv.subt ( xytmp ); 
     
    4948        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ; 
    5049    }; 
    51 } 
    52  
    53 void merger_base::set_support ( rectangular_support &Sup ) { 
     50         
     51        //  
     52        merger().set_rv(rv); 
     53} 
     54 
     55void MergerDiscrete::set_support ( rectangular_support &Sup ) { 
    5456    Npoints = Sup.points(); 
    5557    eSmp.set_parameters ( Npoints, false ); 
     
    6163        samples ( j ) = Sup.next_vec(); 
    6264    } 
    63 } 
    64  
    65 void merger_base::merge () { 
     65    eSmp.validate(); 
     66} 
     67 
     68void MergerDiscrete::merge () { 
    6669    validate(); 
    6770 
    6871    //check if sources overlap: 
    6972    bool OK = true; 
    70     for ( int i = 0; i < pdfs.length(); i++ ) { 
     73    for ( int i = 0; i < part_sources.length(); i++ ) { 
    7174        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty 
    72         OK &= ( pdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty 
     75        OK &= ( part_sources ( i )->_rvc()._dsize() == 0 ); // y_i is empty 
    7376    } 
    7477 
    7578    if ( OK ) { 
    76         mat lW = zeros ( pdfs.length(), eSmp._w().length() ); 
     79        mat lW = zeros ( part_sources.length(), eSmp._w().length() ); 
    7780 
    7881        vec emptyvec ( 0 ); 
    79         for ( int i = 0; i < pdfs.length(); i++ ) { 
     82        for ( int i = 0; i < part_sources.length(); i++ ) { 
    8083            for ( int j = 0; j < eSmp._w().length(); j++ ) { 
    81                 lW ( i, j ) = pdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec ); 
     84                lW ( i, j ) = part_sources ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec ); 
    8285            } 
    8386        } 
     
    9295} 
    9396 
    94 vec merger_base::merge_points ( mat &lW ) { 
     97vec MergerDiscrete::merge_points ( mat &lW ) { 
    9598    int nu = lW.rows(); 
    9699    vec result; 
     
    144147} 
    145148 
    146 vec merger_base::mean() const { 
    147     const Vec<double> &w = eSmp._w(); 
    148     const Array<vec> &S = eSmp._samples(); 
    149     vec tmp = zeros ( dim ); 
    150     for ( int i = 0; i < Npoints; i++ ) { 
    151         tmp += w ( i ) * S ( i ); 
    152     } 
    153     return tmp; 
    154 } 
    155  
    156 mat merger_base::covariance() const { 
    157     const vec &w = eSmp._w(); 
    158     const Array<vec> &S = eSmp._samples(); 
    159  
    160     vec mea = mean(); 
    161  
    162 //             cout << sum (w) << "," << w*w << endl; 
    163  
    164     mat Tmp = zeros ( dim, dim ); 
    165     for ( int i = 0; i < Npoints; i++ ) { 
    166         vec tmp=S ( i )-mea; //inefficient but numerically stable 
    167         Tmp += w ( i ) * outer_product (tmp , tmp ); 
    168     } 
    169     return Tmp; 
    170 } 
    171  
    172 vec merger_base::variance() const { 
    173     return eSmp.variance(); 
    174 } 
    175  
    176  
    177 void merger_base::from_setting ( const Setting& set ) { 
     149void MergerDiscrete::from_setting ( const Setting& set ) { 
    178150    // get support 
    179151    // find which method to use 
    180     epdf::from_setting (set); 
    181152    string meth_str; 
    182153    UI::get( meth_str, set, "method", UI::compulsory ); 
     
    197168} 
    198169 
    199 void merger_base::to_setting  (Setting  &set) const { 
    200     epdf::to_setting(set); 
    201  
     170void MergerDiscrete::to_setting  (Setting  &set) const { 
    202171    UI::save( METHOD, set, "method"); 
    203172 
     
    209178} 
    210179 
    211 void merger_base::validate() { 
     180void MergerDiscrete::validate() { 
    212181//        bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." ); 
    213182//        bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" ); 
    214     epdf::validate(); 
    215     bdm_assert ( isnamed(), "mergers must be named" ); 
    216 } 
    217  
    218 // DEFAULTS FOR MERGER_BASE 
    219 const MERGER_METHOD merger_base::DFLT_METHOD = LOGNORMAL; 
    220 const double merger_base::DFLT_beta = 1.2; 
     183    MergerBase::validate(); 
     184    //bdm_assert ( merger().isnamed(), "mergers must be named" ); 
     185} 
    221186 
    222187void merger_mix::merge ( ) { 
     188        int dim = eSmp.dimension(); 
    223189    if(Npoints<1) { 
    224190        set_support(enorm<fsqmat>(zeros(dim), eye(dim)), 1000); 
     
    231197    vec &w = eSmp._w(); //aux 
    232198 
    233     mat Smp_ex = ones ( dim + 1, Npoints ); // Extended samples for the ARX model - the last row is ones 
     199    mat Smp_ex = ones ( dim, Npoints ); // Extended samples for the ARX model 
    234200    for ( int i = 0; i < Npoints; i++ ) { 
    235201        set_col_part ( Smp_ex, i, Smp ( i ) ); 
     
    242208    vec lw_mix ( Npoints );        // weights of the approximating mixture 
    243209    vec lw ( Npoints );            // tmp 
    244     mat lW = zeros ( Nsources, Npoints ); // array of weights of all sources 
     210    mat lW = zeros ( Nsources, Npoints ); // array of weights of all part_sources 
    245211    vec vec0 ( 0 ); 
    246212 
     
    263229    char dbg_str[100]; 
    264230 
    265     emix* Mpred = Mix.epredictor ( ); 
     231        emix *Mpred = Mix.epredictor ( ); 
    266232    vec Mix_pdf ( Npoints ); 
    267233    while ( !converged ) { 
     
    272238        delete Mpred; 
    273239        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!! 
    274         Mpred->set_rv ( rv ); //the predictor predicts rv of this merger 
     240        Mpred->set_rv ( merger()._rv() ); //the predictor predicts rv of this merger 
    275241 
    276242        // This will be active only later in iterations!!! 
     
    299265            *dbg_file << Name ( dbg_str ) << Mpred->variance(); 
    300266            sprintf ( dbg_str, "Mpred_cov%d", niter ); 
    301             *dbg_file << Name ( dbg_str ) << covariance(); 
     267            *dbg_file << Name ( dbg_str ) << Mpred->covariance(); 
    302268 
    303269 
     
    313279        } 
    314280        //Importace weighting 
    315         for ( int i = 0; i < pdfs.length(); i++ ) { 
     281        for ( int i = 0; i < part_sources.length(); i++ ) { 
    316282            lw_src = 0.0; 
    317283            //======== Same RVs =========== 
    318284            //Split according to dependency in rvs 
    319             if ( pdfs ( i )->dimension() == dim ) { 
     285            if ( part_sources ( i )->dimension() == dim ) { 
    320286                // no need for conditioning or marginalization 
    321                 lw_src = pdfs ( i )->evallogcond_mat ( Smp , vec ( 0 ) ); 
     287                lw_src = part_sources ( i )->evallogcond_mat ( Smp , vec ( 0 ) ); 
    322288            } else { 
    323289                // compute likelihood of marginal on the conditional variable 
    324                 if ( pdfs ( i )->dimensionc() > 0 ) { 
     290                if ( part_sources ( i )->dimensionc() > 0 ) { 
    325291                    // Make marginal on rvc_i 
    326                     shared_ptr<epdf> tmp_marg = Mpred->marginal ( pdfs ( i )->_rvc() ); 
     292                    shared_ptr<epdf> tmp_marg = Mpred->marginal ( part_sources ( i )->_rvc() ); 
    327293                    //compute vector of lw_src 
    328294                    for ( int k = 0; k < Npoints; k++ ) { 
     
    336302                } 
    337303                // Compute likelihood of the missing variable 
    338                 if ( dim > ( pdfs ( i )->dimension() + pdfs ( i )->dimensionc() ) ) { 
     304                if ( dim > ( part_sources ( i )->dimension() + part_sources ( i )->dimensionc() ) ) { 
    339305                    /////////////// 
    340306                    // There are variales unknown to pdfs(i) : rvzs 
     
    355321                // Compute likelihood of the partial source 
    356322                for ( int k = 0; k < Npoints; k++ ) { 
    357                     lw_src ( k ) += pdfs ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ), 
     323                    lw_src ( k ) += part_sources ( i )->evallogcond ( dls ( i )->pushdown ( Smp ( k ) ), 
    358324                                    dls ( i )->get_cond ( Smp ( k ) ) ); 
    359325                } 
     
    363329            lW.set_row ( i, lw_src ); // do not divide by mix 
    364330        } 
    365         lw = merger_base::merge_points ( lW ); //merge 
     331        lw = MergerDiscrete::merge_points ( lW ); //merge 
    366332 
    367333        //Importance weighting 
     
    391357    } 
    392358    delete Mpred; 
    393 //        cout << endl; 
    394  
    395359} 
    396360 
    397361void merger_mix::from_setting ( const Setting& set ) { 
    398     merger_base::from_setting ( set ); 
     362    MergerDiscrete::from_setting ( set ); 
    399363    Ncoms=DFLT_Ncoms; 
    400364    UI::get( Ncoms, set, "ncoms", UI::optional ); 
     
    406370 
    407371void merger_mix::to_setting  (Setting  &set) const  { 
    408     merger_base::to_setting(set); 
     372    MergerDiscrete::to_setting(set); 
    409373    UI::save( Ncoms, set, "ncoms"); 
    410374    UI::save (effss_coef , set,  "effss_coef"); 
     
    413377 
    414378void merger_mix::validate() { 
    415     merger_base::validate(); 
     379    MergerDiscrete::validate(); 
    416380    bdm_assert(Ncoms>0,"Ncoms too small"); 
    417381}