Changeset 189 for bdm/estim/mixef.cpp

Show
Ignore:
Timestamp:
10/22/08 10:46:36 (16 years ago)
Author:
smidl
Message:

extend MixEF to allow for EM algorithm and alow estimation of weighted empirical density

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • bdm/estim/mixef.cpp

    r180 r189  
    2121        Coms ( 0 )->flatten ( Com0 ); 
    2222 
    23         //Copy it to the rest  
     23        //Copy it to the rest 
    2424        for ( i=1;i<n;i++ ) { 
    2525                //copy Com0 and create new rvs for them 
     
    3838 
    3939} 
    40 void MixEF::bayesB ( const mat &Data ) { 
    41         this->bayes ( Data ); 
    42 } 
    4340 
    44 void MixEF::bayes ( const vec &data ) { 
    45  
    46 }; 
    47  
    48 void MixEF::bayes ( const mat &data ) { 
     41void MixEF::bayesB ( const mat &data , const vec &wData ) { 
    4942        int ndat=data.cols(); 
    5043        int t,i,niter; 
     
    6356        // tmp for weights 
    6457        vec wtmp = zeros ( n ); 
     58        int maxi; 
     59        double maxll; 
    6560        //Estim 
    6661        while ( !converged ) { 
     
    7570                                ll ( i ) += weights.logpred ( wtmp ); 
    7671                        } 
    77                         w = exp ( ll-max ( ll ) ); 
    78                         W.set_col ( t, w/sum ( w ) ); 
     72                         
     73                        maxll = max(ll,maxi); 
     74                        switch (method) { 
     75                                case QB: 
     76                                        w = exp ( ll-maxll ); 
     77                                        w/=sum(w); 
     78                                        break; 
     79                                case EM: 
     80                                        w = 0.0; 
     81                                        w(maxi) = 1.0; 
     82                                        break; 
     83                        } 
     84                         
     85                        W.set_col ( t, w ); 
    7986                } 
    8087 
     88                // copy initial statistics 
    8189                for ( i=0;i<n;i++ ) { 
    8290                        Coms ( i )-> set_statistics ( Coms0 ( i ) ); 
     
    8492                weights.set_statistics ( &weights0 ); 
    8593 
     94                // Update statistics 
     95                // !!!!    note  wData ==> this is extra weight of the data record 
     96                // !!!!    For typical cases wData=1. 
    8697                for ( t=0;t<ndat;t++ ) { 
    8798                        for ( i=0;i<n;i++ ) { 
    88                                 Coms ( i )-> bayes ( data.get_col ( t ),W ( i,t ) ); 
     99                                Coms ( i )-> bayes ( data.get_col ( t ),W ( i,t ) * wData ( t ) ); 
    89100                        } 
    90                         weights.bayes ( W.get_col ( t ) ); 
     101                        weights.bayes ( W.get_col ( t ) * wData ( t ) ); 
    91102                } 
    92103 
     
    98109        //Clean Coms0 
    99110        for ( i=0;i<n;i++ ) {delete Coms0 ( i );} 
     111} 
     112 
     113void MixEF::bayes ( const vec &data ) { 
     114 
     115}; 
     116 
     117void MixEF::bayes ( const mat &data ) { 
     118        this->bayesB ( data, ones ( data.cols() ) ); 
    100119}; 
    101120 
     
    111130} 
    112131 
    113 emix* MixEF::predictor(const RV &rv){ 
    114         Array<epdf*> pC(n); 
    115         for(int i=0;i<n;i++){pC(i)=Coms(i)->predictor(rv);} 
     132emix* MixEF::predictor ( const RV &rv ) { 
     133        Array<epdf*> pC ( n ); 
     134        for ( int i=0;i<n;i++ ) {pC ( i ) =Coms ( i )->predictor ( rv );} 
    116135        emix* tmp; 
    117         tmp = new emix(rv); 
    118         tmp->set_parameters(weights._epdf().mean(), pC, false); 
     136        tmp = new emix ( rv ); 
     137        tmp->set_parameters ( weights._epdf().mean(), pC, false ); 
    119138        tmp->ownComs(); 
    120139        return tmp;