Changeset 170 for bdm/stat/libEF.h

Show
Ignore:
Timestamp:
09/24/08 13:07:50 (16 years ago)
Author:
smidl
Message:

Mixtures of EF and related changes to libEF and BM

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • bdm/stat/libEF.h

    r168 r170  
    4141        //! default constructor 
    4242        eEF ( const RV &rv ) :epdf ( rv ) {}; 
    43         //! logarithm of the normalizing constant, \f$\mathcal{I}\f$  
    44         virtual double lognc()const =0; 
     43        //! logarithm of the normalizing constant, \f$\mathcal{I}\f$ 
     44        virtual double lognc() const =0; 
    4545        //!TODO decide if it is really needed 
    46         virtual void tupdate ( double phi, mat &vbar, double nubar ) {}; 
    47         //!TODO decide if it is really needed 
    48         virtual void dupdate ( mat &v,double nu=1.0 ) {}; 
     46        virtual void dupdate ( mat &v ) {it_error ( "Not implemneted" );}; 
     47        //!Evaluate normalized log-probability 
     48        virtual double evalpdflog_nn ( const vec &val ) const{it_error ( "Not implemneted" );return 0.0;}; 
     49        //!Evaluate normalized log-probability 
     50        virtual double evalpdflog ( const vec &val ) const {return evalpdflog_nn ( val )-lognc();} 
     51        //!Evaluate normalized log-probability for many samples 
     52        virtual vec evalpdflog ( const mat &Val ) const { 
     53                vec x ( Val.cols() ); 
     54                for ( int i=0;i<Val.cols();i++ ) {x ( i ) =evalpdflog_nn ( Val.get_col ( i ) ) ;} 
     55                return x-lognc(); 
     56        } 
     57        //!Power of the density, used e.g. to flatten the density 
     58        virtual void pow ( double p ) {it_error ( "Not implemented" );}; 
    4959}; 
    5060 
     
    6070        //! Default constructor 
    6171        mEF ( const RV &rv0, const RV &rvc0 ) :mpdf ( rv0,rvc0 ) {}; 
     72}; 
     73 
     74//! Estimator for Exponential family 
     75class BMEF : public BM { 
     76protected: 
     77        //! forgetting factor 
     78        double frg; 
     79        //! cached value of lognc() in the previous step (used in evaluation of \c ll ) 
     80        double last_lognc; 
     81public: 
     82        //! Default constructor 
     83        BMEF ( const RV &rv, double frg0=1.0 ) :BM ( rv ), frg ( frg0 ) {} 
     84        //! Copy constructor 
     85        BMEF ( const BMEF &B ) :BM ( B ), frg ( B.frg ), last_lognc ( B.last_lognc ) {} 
     86        //!get statistics from another model 
     87        virtual void set_statistics ( const BMEF* BM0 ) {it_error ( "Not implemented" );}; 
     88        //! Weighted update of sufficient statistics (Bayes rule) 
     89        virtual void bayes ( const vec &data, const double w ) {}; 
     90        //original Bayes 
     91        void bayes ( const vec &dt ); 
     92        //!Flatten the posterior 
     93        virtual void flatten ( BMEF * B) {it_error ( "Not implemented" );} 
    6294}; 
    6395 
     
    99131        //! returns a pointer to the internal mean value. Use with Care! 
    100132        vec& _mu() {return mu;} 
    101          
     133 
    102134        //! access function 
    103         void set_mu(const vec mu0) { mu=mu0;} 
     135        void set_mu ( const vec mu0 ) { mu=mu0;} 
    104136 
    105137        //! returns pointers to the internal variance and its inverse. Use with Care! 
     
    128160public: 
    129161        //!Default constructor, assuming 
    130         egiw(RV rv, mat V0, double nu0): eEF(rv), V(V0), nu(nu0) { 
    131                 xdim = rv.count()/V.rows(); 
    132                 it_assert_debug(rv.count()==xdim*V.rows(),"Incompatible V0."); 
     162        egiw ( RV rv, mat V0, double nu0 ) : eEF ( rv ), V ( V0 ), nu ( nu0 ) { 
     163                xdim = rv.count() /V.rows(); 
     164                it_assert_debug ( rv.count() ==xdim*V.rows(),"Incompatible V0." ); 
     165                nPsi = V.rows()-xdim; 
     166        } 
     167        //!Full constructor for V in ldmat form 
     168        egiw ( RV rv, ldmat V0, double nu0 ) : eEF ( rv ), V ( V0 ), nu ( nu0 ) { 
     169                xdim = rv.count() /V.rows(); 
     170                it_assert_debug ( rv.count() ==xdim*V.rows(),"Incompatible V0." ); 
    133171                nPsi = V.rows()-xdim; 
    134172        } 
     
    136174        vec sample() const; 
    137175        vec mean() const; 
     176        void mean_mat ( mat &M, mat&R ) const; 
    138177        //! In this instance, val= [theta, r]. For multivariate instances, it is stored columnwise val = [theta_1 theta_2 ... r_1 r_2 ] 
    139         double evalpdflog ( const vec &val ) const; 
     178        double evalpdflog_nn ( const vec &val ) const; 
    140179        double lognc () const; 
    141180 
     
    145184        //! returns a pointer to the internal statistics. Use with Care! 
    146185        double& _nu() {return nu;} 
    147  
     186        void pow ( double p ); 
     187}; 
     188 
     189/*! \brief Dirichlet posterior density 
     190Continuous Dirichlet density of \f$n\f$-dimensional variable \f$x\f$ 
     191\f[ 
     192f(x|\beta) = \frac{\Gamma[\gamma]}{\prod_{i=1}^{n}\Gamma(\beta_i)} \prod_{i=1}^{n]x_i^(\beta_i-1) 
     193\f] 
     194where \f$\gamma=\sum_i beta_i\f$. 
     195*/ 
     196class eDirich: public eEF { 
     197protected: 
     198        //!sufficient statistics 
     199        vec beta; 
     200public: 
     201        //!Default constructor 
     202        eDirich ( const RV &rv, const vec &beta0 ) : eEF ( rv ),beta ( beta0 ) {it_assert_debug ( rv.count() ==beta.length(),"Incompatible statistics" ); }; 
     203        //! Copy constructor 
     204        eDirich ( const eDirich &D0 ) : eEF ( D0.rv ),beta ( D0.beta ) {}; 
     205        vec sample() const {it_error ( "Not implemented" );return vec_1 ( 0.0 );}; 
     206        vec mean() const {return beta/sum ( beta );}; 
     207        //! In this instance, val= [theta, r]. For multivariate instances, it is stored columnwise val = [theta_1 theta_2 ... r_1 r_2 ] 
     208        double evalpdflog_nn ( const vec &val ) const {return ( beta-1 ) *log ( val );}; 
     209        double lognc () const { 
     210                double gam=sum ( beta ); 
     211                double lgb=0.0; 
     212                for ( int i=0;i<beta.length();i++ ) {lgb+=lgamma ( beta ( i ) );} 
     213                return lgb-lgamma ( gam ); 
     214        }; 
     215        //!access function 
     216        vec& _beta() {return beta;} 
     217}; 
     218 
     219//! Estimator for Multinomial density 
     220class multiBM : public BMEF { 
     221protected: 
     222        //! Conjugate prior and posterior 
     223        eDirich est; 
     224        vec &beta; 
     225public: 
     226        //!Default constructor 
     227        multiBM ( const RV &rv, const vec beta0 ) : BMEF ( rv ),est ( rv,beta0 ),beta ( est._beta() ) {last_lognc=est.lognc();} 
     228        //!Copy constructor 
     229        multiBM ( const multiBM &B ) : BMEF ( B ),est ( rv,B.beta ),beta ( est._beta() ) {} 
     230 
     231        void set_statistics ( const BM* mB0 ) {const multiBM* mB=dynamic_cast<const multiBM*> ( mB0 ); beta=mB->beta;} 
     232        void bayes ( const vec &dt ) { 
     233                if ( frg<1.0 ) {beta*=frg;last_lognc=est.lognc();} 
     234                beta+=dt; 
     235                if ( evalll ) {ll=est.lognc()-last_lognc;} 
     236        } 
     237        double logpred ( const vec &dt ) const { 
     238                eDirich pred ( est ); 
     239                vec &beta = pred._beta(); 
     240                 
     241                double lll; 
     242                if ( frg<1.0 ) 
     243                        {beta*=frg;lll=pred.lognc();} 
     244                else 
     245                        if ( evalll ) {lll=last_lognc;} 
     246                        else{lll=pred.lognc();} 
     247 
     248                beta+=dt; 
     249                return pred.lognc()-lll; 
     250        } 
     251        void flatten (BMEF* B ) { 
     252                eDirich* E=dynamic_cast<eDirich*>(B); 
     253                // sum(beta) should be equal to sum(B.beta) 
     254                const vec &Eb=E->_beta(); 
     255                est.pow ( sum(beta)/sum(Eb) ); 
     256                if(evalll){last_lognc=est.lognc();} 
     257        } 
     258        const epdf& _epdf() const {return est;}; 
     259        //!access funct 
    148260}; 
    149261 
     
    151263 \brief Gamma posterior density 
    152264 
    153  Multvariate Gamma density as product of independent univariate densities. 
     265 Multivariate Gamma density as product of independent univariate densities. 
    154266 \f[ 
    155267 f(x|\alpha,\beta) = \prod f(x_i|\alpha_i,\beta_i) 
     
    213325        double evalpdflog ( const vec &val ) const  {return lnk;} 
    214326        vec sample() const { 
    215                 vec smp ( rv.count() );  
    216                 #pragma omp critical 
     327                vec smp ( rv.count() ); 
     328#pragma omp critical 
    217329                UniRNG.sample_vector ( rv.count(),smp ); 
    218                 return low+elem_mult(distance,smp); 
     330                return low+elem_mult ( distance,smp ); 
    219331        } 
    220332        //! set values of \c low and \c high 
     
    408520double enorm<sq_T>::evalpdflog ( const vec &val ) const { 
    409521        // 1.83787706640935 = log(2pi) 
    410         return  -0.5* (  +R.invqform ( mu-val ) ) - lognc(); 
     522        return  -0.5* ( +R.invqform ( mu-val ) ) - lognc(); 
    411523}; 
    412524 
     
    414526inline double enorm<sq_T>::lognc () const { 
    415527        // 1.83787706640935 = log(2pi) 
    416         return -0.5* ( R.cols() * 1.83787706640935 +R.logdet()); 
    417 }; 
    418  
    419 template<class sq_T> 
    420 mlnorm<sq_T>::mlnorm ( RV &rv0,RV &rvc0 ) :mEF ( rv0,rvc0 ),epdf ( rv0 ),A ( rv0.count(),rv0.count() ),_mu(epdf._mu()) { ep =&epdf; 
     528        return -0.5* ( R.cols() * 1.83787706640935 +R.logdet() ); 
     529}; 
     530 
     531template<class sq_T> 
     532mlnorm<sq_T>::mlnorm ( RV &rv0,RV &rvc0 ) :mEF ( rv0,rvc0 ),epdf ( rv0 ),A ( rv0.count(),rv0.count() ),_mu ( epdf._mu() ) { 
     533        ep =&epdf; 
    421534} 
    422535