Changeset 725

Show
Ignore:
Timestamp:
11/17/09 00:53:54 (14 years ago)
Author:
smidl
Message:

Sampling from egiw

Location:
library
Files:
7 modified

Legend:

Unmodified
Added
Removed
  • library/bdm/math/chmat.h

    r723 r725  
    5757        void inv ( chmat &Inv ) const   { 
    5858                ( Inv.Ch ) = itpp::inv ( Ch ).T(); 
     59                Inv.dim = dim; 
    5960        }; //Fixme: can be more efficient 
    6061        ; 
  • library/bdm/math/square_mat.cpp

    r565 r725  
    115115                } 
    116116        } 
    117         mat V2 = L.transpose() * diag ( D ) * L; 
    118         return V2; 
     117        //mat V2 = L.transpose() * diag ( D ) * L; 
     118        return V; 
    119119} 
    120120 
  • library/bdm/stat/exp_family.cpp

    r679 r725  
    3333 
    3434vec egiw::sample() const { 
    35         bdm_warning ( "Function not implemented" ); 
    36         return vec_1 ( 0.0 ); 
     35        mat M; 
     36        chmat R; 
     37        sample_mat(M,R); 
     38         
     39        return concat (cvectorize(M),cvectorize(R.to_mat())); 
     40} 
     41 
     42void egiw::sample_mat(mat &Mi, chmat &Ri)const{ 
     43         
     44        // TODO - correct approach - convert to product of norm * Wishart 
     45        mat M; 
     46        ldmat Vz; 
     47        ldmat Lam; 
     48        factorize(M,Vz,Lam); 
     49         
     50        chmat Ch; 
     51        Ch.setCh(Lam._L()*diag(sqrt(Lam._D()))); 
     52        chmat iCh; 
     53        Ch.inv(iCh); 
     54         
     55        eWishartCh Omega; //inverse Wishart, result is R,  
     56        Omega.set_parameters(iCh,nu-2*nPsi-dimx); // 2*nPsi is there to match numercial simulations - check if analytically correct 
     57         
     58        chmat Omi; 
     59        Omi.setCh(Omega.sample_mat()); 
     60         
     61        mat Z=randn(M.rows(), M.cols()); 
     62        Mi = M+ Omi._Ch() * Z * inv(Vz._L()*diag(sqrt(Vz._D()))); 
     63        Omi.inv(Ri); 
    3764} 
    3865 
     
    104131} 
    105132 
     133void egiw::factorize(mat &M, ldmat &Vz, ldmat &Lam) const{       
     134        const mat &L = V._L(); 
     135        const vec &D = V._D(); 
     136        int end = L.rows() - 1; 
     137         
     138        Vz=ldmat(L ( dimx, end, dimx, end ), D(dimx,end)); 
     139        mat iLsub = ltuinv ( Vz._L()); 
     140        // set mean value 
     141        mat Lpsi = L ( dimx, end, 0, dimx - 1 ); 
     142        M = iLsub * Lpsi; 
     143         
     144        Lam =ldmat ( L (0, dimx-1, 0, dimx-1 ), D (0, dimx-1 ) );  //exp val of R 
     145         if (1){ // test with Peterka 
     146                 mat VF=V.to_mat(); 
     147                 mat Vf=VF(0,dimx-1,0, dimx-1); 
     148                 mat Vzf = VF(dimx,end,0,dimx-1); 
     149                 mat VZ = VF(dimx,end,dimx,end); 
     150                  
     151                 mat Lam2 = Vf-Vzf.T()*inv(VZ)*Vzf; 
     152         } 
     153} 
     154 
    106155ldmat egiw::est_theta_cov() const { 
    107156        if ( dimx == 1 ) { 
     
    111160 
    112161                mat Lsub = L ( 1, end, 1, end ); 
    113                 mat Dsub = diag ( D ( 1, end ) ); 
    114  
    115                 return inv ( transpose ( Lsub ) * Dsub * Lsub ); 
     162//              mat Dsub = diag ( D ( 1, end ) ); 
     163 
     164                ldmat LD(inv(Lsub).T(), 1.0/D(1,end)); 
     165                return LD; 
    116166 
    117167        } else { 
  • library/bdm/stat/exp_family.h

    r723 r725  
    223223                vec mean() const; 
    224224                vec variance() const; 
    225  
     225                void sample_mat(mat &Mi, chmat &Ri)const; 
     226                         
     227                void factorize(mat &M, ldmat &Vz, ldmat &Lam) const; 
    226228                //! LS estimate of \f$\theta\f$ 
    227229                vec est_theta() const; 
     
    245247                const double& _nu() const {return nu;} 
    246248                const int & _dimx() const {return dimx;} 
     249                         
    247250                /*! Create Gauss-inverse-Wishart density  
    248251                \f[ f(rv) = GiW(V,\nu) \f] 
     
    259262                \endcode 
    260263                */ 
    261                  
     264                                 
    262265                void from_setting (const Setting &set) { 
    263266                        epdf::from_setting(set); 
     
    274277                        } 
    275278                } 
     279                 
     280                void to_setting ( Setting& set ) const{ 
     281                        epdf::to_setting(set); 
     282                        UI::save(dimx,set,"dimx"); 
     283                        UI::save(V.to_mat(),set,"V"); 
     284                        UI::save(nu,set,"nu");                   
     285                }; 
     286                 
    276287                void validate(){ 
    277288                        // check sizes, rvs etc. 
     289                } 
     290                void log_register( bdm::logger& L, const string& prefix ){ 
     291                        if (log_level==3){ 
     292                                root::log_register(L,prefix); 
     293                                logrec->ids.set_length(2); 
     294                                int th_dim=dimension()-dimx*(dimx+1)/2; 
     295                                logrec->ids(0)=L.add(RV("",th_dim), prefix + logrec->L.prefix_sep() +"mean"); 
     296                                logrec->ids(1)=L.add(RV("",th_dim*th_dim),prefix + logrec->L.prefix_sep() + "variance");  
     297                        } else { 
     298                                epdf::log_register(L,prefix); 
     299                        } 
     300                } 
     301                void log_write() const { 
     302                        if (log_level==3){ 
     303                                mat M; 
     304                                ldmat Lam; 
     305                                ldmat Vz; 
     306                                factorize(M,Vz,Lam); 
     307                                logrec->L.logit(logrec->ids(0), est_theta() ); 
     308                                logrec->L.logit(logrec->ids(1), cvectorize(est_theta_cov().to_mat())); 
     309                        } else { 
     310                                epdf::log_write(); 
     311                        } 
     312                         
    278313                } 
    279314                //!@} 
     
    11381173                //! Set internal structures 
    11391174                void set_parameters (const mat &Y0, const double delta0) {Y = chmat (Y0);delta = delta0; p = Y.rows(); dim = p * p; } 
     1175                //! Set internal structures 
     1176                void set_parameters (const chmat &Y0, const double delta0) {Y = Y0;delta = delta0; p = Y.rows(); dim = p * p; } 
    11401177                //! Sample matrix argument 
    11411178                mat sample_mat() const { 
  • library/tests/epdf_harness.cpp

    r722 r725  
    5959        } 
    6060 
    61         if ( R.rows() > 0 ) { 
     61        if ( variance.length() > 0 ) { 
    6262                check_sample_mean(); 
     63        } 
     64        if (R.rows()>0){ 
    6365                check_covariance(); 
    6466        } 
     
    114116        vec yb = support.get_row ( 1 ); 
    115117 
    116         int tc = 0; 
    117         Array<vec> actual ( CurrentContext::max_trial_count ); 
    118         do { 
    119                 vec emu = num_mean2 ( hepdf.get(), xb, yb, nbins ( 0 ), nbins ( 1 ) ); 
    120                 actual ( tc ) = emu; 
    121                 ++tc; 
    122         } while ( ( tc < CurrentContext::max_trial_count ) && 
    123                   !UnitTest::AreClose ( mean, actual ( tc - 1 ), tolerance ) ); 
    124         if ( ( tc == CurrentContext::max_trial_count ) && 
    125                 ( !UnitTest::AreClose ( mean, actual ( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) { 
    126                 UnitTest::MemoryOutStream stream; 
    127                 stream << CurrentContext::format_context ( __LINE__ ) << "expected " << mean << " +/- " << tolerance << " but was " << actual; 
    128  
    129                 UnitTest::TestDetails details ( *UnitTest::CurrentTest::Details(), 0, false ); 
    130  
    131                 UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() ); 
    132         } 
     118        vec actual; 
     119        actual = num_mean2 ( hepdf.get(), xb, yb, nbins ( 0 ), nbins ( 1 ) ); 
     120 
     121        CHECK_CLOSE(mean, actual, tolerance); 
    133122} 
    134123 
     
    137126        vec yb = support.get_row ( 1 ); 
    138127 
    139         int tc = 0; 
    140         Array<double> actual ( CurrentContext::max_trial_count ); 
    141         do { 
    142                 double nc = normcoef ( hepdf.get(), xb, yb, nbins ( 0 ), nbins ( 1 ) ); 
    143                 actual ( tc ) = nc; 
    144                 ++tc; 
    145         } while ( ( tc < CurrentContext::max_trial_count ) && 
    146                   !UnitTest::AreClose ( 1.0, actual ( tc - 1 ), tolerance ) ); 
    147         if ( ( tc == CurrentContext::max_trial_count ) && 
    148                 ( !UnitTest::AreClose ( 1.0, actual ( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) { 
    149                 UnitTest::MemoryOutStream stream; 
    150                 stream << CurrentContext::format_context ( __LINE__ ) << "expected " << mean << " +/- " << tolerance << " but was " << actual; 
    151  
    152                 UnitTest::TestDetails details ( *UnitTest::CurrentTest::Details(), 0, false ); 
    153  
    154                 UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() ); 
    155         } 
     128        double nc = normcoef ( hepdf.get(), xb, yb, nbins ( 0 ), nbins ( 1 ) ); 
     129        CHECK_CLOSE(1.0,nc,0.01); 
    156130} 
    157131 
     
    168142        } while ( ( tc < CurrentContext::max_trial_count ) && 
    169143                  !UnitTest::AreClose ( mean, actual ( tc - 1 ), delta ) ); 
     144                           
    170145        if ( ( tc == CurrentContext::max_trial_count ) && 
    171146                ( !UnitTest::AreClose ( mean, actual ( CurrentContext::max_trial_count - 1 ), delta ) ) ) { 
  • library/tests/testsuite/egiw.cfg

    r717 r725  
    1212    }; 
    1313  }; 
    14   mean = [ 1.1, 0.1 ]; 
    15   variance = [ 0.01, 8e-05 ]; 
     14  mean = [ 1.1, 0.2 ]; 
     15  variance = [ 0.02, 0.01333 ]; 
    1616  support = ( "matrix", 2, 2, [ -2.0, 4.0, 0.01, 2.0 ] ); 
    1717  nbins = [ 100, 200 ]; 
    18   tolerance = 0.2; 
     18  tolerance = 0.01; 
    1919}, 
    2020{ 
     
    3333  mean = [2.0, 1.14286]; 
    3434  variance = [0.0285714, 0.0395795]; 
     35  support = ( "matrix", 2, 2, [ 0.0, 4.0, 0.01, 3.0 ] ); 
     36  nbins = [ 100, 200 ]; 
     37  tolerance = 0.01; 
    3538} ); 
    3639 
  • library/tests/testsuite/epdf_test.cpp

    r723 r725  
    4343 
    4444TEST ( ewishart_test ) { 
    45         mat wM = "1.0 0.9; 0.9 1.0"; 
     45        mat wM = "10.0 0.9; 0.9 1.0"; 
    4646        eWishartCh eW; 
    4747        eW.set_parameters ( wM / 100, 100 ); 
     
    5353        } 
    5454 
    55         mat observed ( "0.978486 0.88637; 0.88637 0.992141" ); 
    5655        mat actual = mea / 100; 
    57         CHECK_CLOSE ( observed, actual, 0.1 ); 
     56        CHECK_CLOSE ( wM, actual, 0.1 ); 
    5857} 
    5958