Show
Ignore:
Timestamp:
03/29/10 23:01:49 (14 years ago)
Author:
smidl
Message:

epdf and emix now have _base classes

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • library/bdm/stat/emix.cpp

    r878 r886  
    33namespace bdm { 
    44 
    5 void emix::validate (){ 
    6         bdm_assert ( Coms.length() > 0, "There has to be at least one component." ); 
    7  
    8         bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." ); 
     5void emix_base::validate (){ 
     6        bdm_assert ( no_coms() > 0, "There has to be at least one component." ); 
     7 
     8        bdm_assert ( no_coms() == w.length(), "It is obligatory to define weights of all the components." ); 
    99 
    1010        double sum_w = sum ( w ); 
     
    1212        w = w / sum_w; 
    1313 
    14         dim = Coms ( 0 )->dimension(); 
    15         RV rv_tmp = Coms ( 0 )->_rv() ; 
    16         bool isnamed = Coms ( 0 )->isnamed(); 
    17         for ( int i = 1; i < Coms.length(); i++ ) { 
    18                 bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" ); 
    19                 isnamed &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp); 
     14        dim = component ( 0 )->dimension(); 
     15        RV rv_tmp = component ( 0 )->_rv() ; 
     16        bool isnamed = component( 0 )->isnamed(); 
     17        for ( int i = 1; i < no_coms(); i++ ) { 
     18                bdm_assert ( dim == ( component ( i )->dimension() ), "Component sizes do not match!" ); 
     19                isnamed &= component(i)->isnamed() & component(i)->_rv().equal(rv_tmp); 
    2020        } 
    2121        if (isnamed) 
     
    2323} 
    2424 
    25 void emix::from_setting ( const Setting &set ) { 
    26         UI::get ( Coms, set, "pdfs", UI::compulsory ); 
    27  
    28         if ( !UI::get ( w, set, "weights", UI::optional ) ) { 
    29                 int len = Coms.length(); 
    30                 w.set_length ( len ); 
    31                 w = 1.0 / len; 
    32         } 
    33 } 
    34  
    35  
    36 vec emix::sample() const { 
     25 
     26 
     27vec emix_base::sample() const { 
    3728        //Sample which component 
    3829        vec cumDist = cumsum ( w ); 
     
    4637        } 
    4738 
    48         return Coms ( i )->sample(); 
    49 } 
    50  
    51 vec emix::mean() const { 
     39        return component ( i )->sample(); 
     40} 
     41 
     42vec emix_base::mean() const { 
    5243        int i; 
    5344        vec mu = zeros ( dim ); 
    5445        for ( i = 0; i < w.length(); i++ ) { 
    55                 mu += w ( i ) * Coms ( i )->mean(); 
     46                mu += w ( i ) * component ( i )->mean(); 
    5647        } 
    5748        return mu; 
    5849} 
    5950 
    60 vec emix::variance() const { 
     51vec emix_base::variance() const { 
    6152        //non-central moment 
    6253        vec mom2 = zeros ( dim ); 
    6354        for ( int i = 0; i < w.length(); i++ ) { 
    64                 mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) ); 
     55                mom2 += w ( i ) * ( component( i )->variance() + pow ( component ( i )->mean(), 2 ) ); 
    6556        } 
    6657        //central moment 
     
    6859} 
    6960 
    70 double emix::evallog ( const vec &val ) const { 
     61double emix_base::evallog ( const vec &val ) const { 
    7162        int i; 
    7263        double sum = 0.0; 
    7364        for ( i = 0; i < w.length(); i++ ) { 
    74                 sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) ); 
     65                sum += w ( i ) * exp ( component ( i )->evallog ( val ) ); 
    7566        } 
    7667        if ( sum == 0.0 ) { 
     
    8273} 
    8374 
    84 vec emix::evallog_mat ( const mat &Val ) const { 
     75vec emix_base::evallog_mat ( const mat &Val ) const { 
    8576        vec x = zeros ( Val.cols() ); 
    8677        for ( int i = 0; i < w.length(); i++ ) { 
    87                 x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) ); 
     78                x += w ( i ) * exp ( component( i )->evallog_mat ( Val ) ); 
    8879        } 
    8980        return log ( x ); 
    9081}; 
    9182 
    92 mat emix::evallog_coms ( const mat &Val ) const { 
     83mat emix_base::evallog_coms ( const mat &Val ) const { 
    9384        mat X ( w.length(), Val.cols() ); 
    9485        for ( int i = 0; i < w.length(); i++ ) { 
    95                 X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) ); 
     86                X.set_row ( i, w ( i ) *exp ( component( i )->evallog_mat ( Val ) ) ); 
    9687        } 
    9788        return X; 
    9889} 
    9990 
    100 shared_ptr<epdf> emix::marginal ( const RV &rv ) const { 
     91shared_ptr<epdf> emix_base::marginal ( const RV &rv ) const { 
    10192        emix *tmp = new emix(); 
    10293        shared_ptr<epdf> narrow ( tmp ); 
     
    10596} 
    10697 
    107 void emix::marginal ( const RV &rv, emix &target ) const { 
     98void emix_base::marginal ( const RV &rv, emix &target ) const { 
    10899        bdm_assert ( isnamed(), "rvs are not assigned" ); 
    109100 
    110         Array<shared_ptr<epdf> > Cn ( Coms.length() ); 
    111         for ( int i = 0; i < Coms.length(); i++ ) { 
    112                 Cn ( i ) = Coms ( i )->marginal ( rv ); 
     101        Array<shared_ptr<epdf> > Cn ( no_coms() ); 
     102        for ( int i = 0; i < no_coms(); i++ ) { 
     103                Cn ( i ) = component ( i )->marginal ( rv ); 
    113104        } 
    114105 
     
    118109} 
    119110 
    120 shared_ptr<pdf> emix::condition ( const RV &rv ) const { 
     111shared_ptr<pdf> emix_base::condition ( const RV &rv ) const { 
    121112        bdm_assert ( isnamed(), "rvs are not assigned" ); 
    122113        mratio *tmp = new mratio ( this, rv ); 
     
    124115} 
    125116 
    126 void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) { 
    127         w = w0 / sum ( w0 ); 
    128         int i; 
    129         for ( i = 0; i < w.length(); i++ ) { 
    130                 bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" ); 
    131         } 
    132         if ( copy ) { 
    133                 Coms.set_length ( Coms0.length() ); 
    134                 for ( i = 0; i < w.length(); i++ ) { 
    135                         bdm_error ( "Not implemented" ); 
    136                         // *Coms ( i ) = *Coms0 ( i ); 
    137                 } 
    138                 destroyComs = true; 
    139         } else { 
    140                 Coms = Coms0; 
    141                 destroyComs = false; 
    142         } 
    143 } 
    144  
    145 void    egiwmix::validate (){ 
    146    dim = Coms ( 0 )->dimension(); 
    147 } 
    148  
    149 vec egiwmix::sample() const { 
    150         //Sample which component 
    151         vec cumDist = cumsum ( w ); 
    152         double u0; 
    153 #pragma omp critical 
    154         u0 = UniRNG.sample(); 
    155  
    156         int i = 0; 
    157         while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) { 
    158                 i++; 
    159         } 
    160  
    161         return Coms ( i )->sample(); 
    162 } 
    163  
    164 vec egiwmix::mean() const { 
    165         int i; 
    166         vec mu = zeros ( dim ); 
    167         for ( i = 0; i < w.length(); i++ ) { 
    168                 mu += w ( i ) * Coms ( i )->mean(); 
    169         } 
    170         return mu; 
    171 } 
    172  
    173 vec egiwmix::variance() const { 
    174         // non-central moment 
    175         vec mom2 = zeros ( dim ); 
    176         for ( int i = 0; i < w.length(); i++ ) { 
    177                 // pow is overloaded, we have to use another approach 
    178                 mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) ); 
    179         } 
    180         // central moment 
    181         // pow is overloaded, we have to use another approach 
    182         return mom2 - elem_mult ( mean(), mean() ); 
    183 } 
    184  
    185 shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const { 
    186         emix *tmp = new emix(); 
    187         shared_ptr<epdf> narrow ( tmp ); 
    188         marginal ( rv, *tmp ); 
    189         return narrow; 
    190 } 
    191  
    192 void egiwmix::marginal ( const RV &rv, emix &target ) const { 
    193         bdm_assert_debug ( isnamed(), "rvs are not assigned" ); 
    194  
    195         Array<shared_ptr<epdf> > Cn ( Coms.length() ); 
    196         for ( int i = 0; i < Coms.length(); i++ ) { 
    197                 Cn ( i ) = Coms ( i )->marginal ( rv ); 
    198         } 
    199  
    200         target._w() = w; 
    201         target._Coms() = Cn; 
    202         target.validate(); 
    203 } 
    204  
    205 egiw*   egiwmix::approx() { 
    206         // NB: dimx == 1 !!! 
    207         // The following code might look a bit spaghetti-like, 
    208         // consult Dedecius, K. et al.: Partial forgetting in AR models. 
    209  
    210         double sumVecCommon;                            // common part for many terms in eq. 
    211         int len = w.length();                           // no. of mix components 
    212         int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS 
    213         vec vecNu ( len );                                      // vector of dfms of components 
    214         vec vecD ( len );                                       // vector of LS reminders of comps. 
    215         vec vecCommon ( len );                          // vector of common parts 
    216         mat matVecsTheta;                               // matrix which rows are theta vects. 
    217  
    218         // fill in the vectors vecNu, vecD and matVecsTheta 
    219         for ( int i = 0; i < len; i++ ) { 
    220                 vecNu.shift_left ( Coms ( i )->_nu() ); 
    221                 vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) ); 
    222                 matVecsTheta.append_row ( Coms ( i )->est_theta() ); 
    223         } 
    224  
    225         // calculate the common parts and their sum 
    226         vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) ); 
    227         sumVecCommon = sum ( vecCommon ); 
    228  
    229         // LS estimator of theta 
    230         vec aprEstTheta ( dimLS ); 
    231         aprEstTheta.zeros(); 
    232         for ( int i = 0; i < len; i++ ) { 
    233                 aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i ); 
    234         } 
    235         aprEstTheta /= sumVecCommon; 
    236  
    237  
    238         // LS estimator of dfm 
    239         double aprNu; 
    240         double A = log ( sumVecCommon );                // Term 'A' in equation 
    241  
    242         for ( int i = 0; i < len; i++ ) { 
    243                 A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) ); 
    244         } 
    245  
    246         aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) ); 
    247  
    248  
    249         // LS reminder (term D(0,0) in C-syntax) 
    250         double aprD = aprNu / sumVecCommon; 
    251  
    252         // Aproximation of cov 
    253         // the following code is very numerically sensitive, thus 
    254         // we have to eliminate decompositions etc. as much as possible 
    255         mat aprC = zeros ( dimLS, dimLS ); 
    256         for ( int i = 0; i < len; i++ ) { 
    257                 aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i ); 
    258                 vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta ); 
    259                 aprC += vecCommon ( i ) * outer_product ( tmp, tmp ); 
    260         } 
    261  
    262         // Construct GiW pdf :: BEGIN 
    263         ldmat aprCinv ( inv ( aprC ) ); 
    264         vec D = concat ( aprD, aprCinv._D() ); 
    265         mat L = eye ( dimLS + 1 ); 
    266         L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta ); 
    267         L.set_submatrix ( 1, 1, aprCinv._L() ); 
    268         ldmat aprLD ( L, D ); 
    269  
    270         egiw* aprgiw = new egiw ( 1, aprLD, aprNu ); 
    271         return aprgiw; 
    272 }; 
     117void emix::from_setting ( const Setting &set ) { 
     118        UI::get ( Coms, set, "pdfs", UI::compulsory ); 
     119        UI::get ( w, set, "weights", UI::compulsory ); 
     120} 
     121 
     122void    emix::validate (){ 
     123        emix_base::validate(); 
     124        dim = Coms ( 0 )->dimension(); 
     125} 
     126 
    273127 
    274128double mprod::evallogcond ( const vec &val, const vec &cond ) { 
     
    378232} 
    379233 
    380 vec eprod::mean() const { 
     234vec eprod_base::mean() const { 
    381235        vec tmp ( dim ); 
    382         for ( int i = 0; i < epdfs.length(); i++ ) { 
    383                 vec pom = epdfs ( i )->mean(); 
     236        for ( int i = 0; i < no_factors(); i++ ) { 
     237                vec pom = factor( i )->mean(); 
    384238                dls ( i )->pushup ( tmp, pom ); 
    385239        } 
     
    387241} 
    388242 
    389 vec eprod::variance() const { 
     243vec eprod_base::variance() const { 
    390244        vec tmp ( dim ); //second moment 
    391         for ( int i = 0; i < epdfs.length(); i++ ) { 
    392                 vec pom = epdfs ( i )->mean(); 
     245        for ( int i = 0; i < no_factors(); i++ ) { 
     246                vec pom = factor ( i )->mean(); 
    393247                dls ( i )->pushup ( tmp, pow ( pom, 2 ) ); 
    394248        } 
    395249        return tmp - pow ( mean(), 2 ); 
    396250} 
    397 vec eprod::sample() const { 
     251vec eprod_base::sample() const { 
    398252        vec tmp ( dim ); 
    399         for ( int i = 0; i < epdfs.length(); i++ ) { 
    400                 vec pom = epdfs ( i )->sample(); 
     253        for ( int i = 0; i < no_factors(); i++ ) { 
     254                vec pom = factor ( i )->sample(); 
    401255                dls ( i )->pushup ( tmp, pom ); 
    402256        } 
    403257        return tmp; 
    404258} 
    405 double eprod::evallog ( const vec &val ) const { 
     259double eprod_base::evallog ( const vec &val ) const { 
    406260        double tmp = 0; 
    407         for ( int i = 0; i < epdfs.length(); i++ ) { 
    408                 tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) ); 
    409         } 
    410         bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" ); 
    411         return tmp; 
    412 } 
    413  
    414 } 
    415 // mprod::mprod ( Array<pdf*> mFacs, bool overlap) : pdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), pdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) { 
    416 //              int i; 
    417 //              bool rvaddok; 
    418 //              // Create rv 
    419 //              for ( i = 0;i < n;i++ ) { 
    420 //                      rvaddok=rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs. 
    421 //                      // If rvaddok==false, pdfs overlap => assert error. 
    422 //                      epdfs ( i ) = & ( pdfs ( i )->posterior() ); // add pointer to epdf 
    423 //              }; 
    424 //              // Create rvc 
    425 //              for ( i = 0;i < n;i++ ) { 
    426 //                      rvc.add ( pdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs. 
    427 //              }; 
    428 // 
    429 // //           independent = true; 
    430 //              //test rvc of pdfs and fill rvinds 
    431 //              for ( i = 0;i < n;i++ ) { 
    432 //                      // find ith rv in common rv 
    433 //                      rvsinrv ( i ) = pdfs ( i )->_rv().dataind ( rv ); 
    434 //                      // find ith rvc in common rv 
    435 //                      rvcinrv ( i ) = pdfs ( i )->_rvc().dataind ( rv ); 
    436 //                      // find ith rvc in common rv 
    437 //                      irvcs_rvc ( i ) = pdfs ( i )->_rvc().dataind ( rvc ); 
    438 //                      // 
    439 // /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;} 
    440 //                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/ 
    441 //              } 
    442 //      }; 
    443  
     261        for ( int i = 0; i < no_factors(); i++ ) { 
     262                tmp += factor ( i )->evallog ( dls ( i )->pushdown ( val ) ); 
     263        } 
     264        //bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" ); 
     265        return tmp; 
     266} 
     267 
     268} 
     269