| 1 | #include "stat/exp_family.h" |
|---|
| 2 | #include "stat/emix.h" |
|---|
| 3 | #include "../mat_checks.h" |
|---|
| 4 | #include "UnitTest++.h" |
|---|
| 5 | #include "../test_util.h" |
|---|
| 6 | #include "../pdf_harness.h" |
|---|
| 7 | |
|---|
| 8 | |
|---|
| 9 | const double epsilon = 0.00001; |
|---|
| 10 | |
|---|
| 11 | using namespace bdm; |
|---|
| 12 | |
|---|
| 13 | static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance ); |
|---|
| 14 | |
|---|
| 15 | static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance ); |
|---|
| 16 | |
|---|
| 17 | TEST ( emix_test ) { |
|---|
| 18 | pdf_harness::test_config ( "emix.cfg" ); |
|---|
| 19 | } |
|---|
| 20 | |
|---|
| 21 | TEST ( emix_1_test ) { |
|---|
| 22 | RV x ( "{emixx }" ); |
|---|
| 23 | RV y ( "{emixy }" ); |
|---|
| 24 | RV xy = concat ( x, y ); |
|---|
| 25 | vec mu0 ( "1.00054 1.0455" ); |
|---|
| 26 | |
|---|
| 27 | enorm_ldmat_ptr E1; |
|---|
| 28 | E1->set_rv ( xy ); |
|---|
| 29 | E1->set_parameters ( mu0 , mat ( "0.740142 -0.259015; -0.259015 1.0302" ) ); |
|---|
| 30 | |
|---|
| 31 | enorm_ldmat_ptr E2; |
|---|
| 32 | E2->set_rv ( xy ); |
|---|
| 33 | E2->set_parameters ( "-1.2 -0.1" , mat ( "1 0.4; 0.4 0.5" ) ); |
|---|
| 34 | |
|---|
| 35 | epdf_array A1 ( 1 ); |
|---|
| 36 | A1 ( 0 ) = E1; |
|---|
| 37 | |
|---|
| 38 | emix M1; |
|---|
| 39 | M1.set_rv ( xy ); |
|---|
| 40 | M1._Coms() = A1; |
|---|
| 41 | M1._w() = vec_1(1.0); |
|---|
| 42 | M1.validate(); |
|---|
| 43 | |
|---|
| 44 | // test if ARX and emix with one ARX are the same |
|---|
| 45 | epdf_ptr Mm = M1.marginal ( y ); |
|---|
| 46 | epdf_ptr Am = E1->marginal ( y ); |
|---|
| 47 | pdf_ptr Mc = M1.condition ( y ); |
|---|
| 48 | pdf_ptr Ac = E1->condition ( y ); |
|---|
| 49 | |
|---|
| 50 | mlnorm<ldmat> *wacnd = dynamic_cast<mlnorm<ldmat> *> ( Ac.get() ); |
|---|
| 51 | CHECK ( wacnd ); |
|---|
| 52 | if ( wacnd ) { |
|---|
| 53 | CHECK_CLOSE ( mat ( "-0.349953" ), wacnd->_A(), epsilon ); |
|---|
| 54 | CHECK_CLOSE ( vec ( "1.39564" ), wacnd->_mu_const(), epsilon ); |
|---|
| 55 | CHECK_CLOSE ( mat ( "0.939557" ), wacnd->_R(), epsilon ); |
|---|
| 56 | } |
|---|
| 57 | |
|---|
| 58 | double same = -1.46433; |
|---|
| 59 | CHECK_CLOSE ( same, Mm->evallog ( vec_1 ( 0.0 ) ), epsilon ); |
|---|
| 60 | CHECK_CLOSE ( same, Am->evallog ( vec_1 ( 0.0 ) ), epsilon ); |
|---|
| 61 | CHECK_CLOSE ( 0.145974, Mc->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon ); |
|---|
| 62 | CHECK_CLOSE ( -1.92433, Ac->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon ); |
|---|
| 63 | |
|---|
| 64 | // mixture with two components |
|---|
| 65 | epdf_array A2 ( 2 ); |
|---|
| 66 | A2 ( 0 ) = E1; |
|---|
| 67 | A2 ( 1 ) = E2; |
|---|
| 68 | |
|---|
| 69 | emix M2; |
|---|
| 70 | M2.set_rv ( xy ); |
|---|
| 71 | M2._Coms() = A2; |
|---|
| 72 | M2._w() = vec_2(.5,.5); |
|---|
| 73 | M2.validate(); |
|---|
| 74 | |
|---|
| 75 | // mixture normalization |
|---|
| 76 | CHECK_CLOSE ( 1.0, normcoef ( &M2, vec ( "-3 3 " ), vec ( "-3 3 " ) ), 0.1 ); |
|---|
| 77 | |
|---|
| 78 | int N = 6; |
|---|
| 79 | mat Smp = M2.sample_mat ( N ); |
|---|
| 80 | |
|---|
| 81 | vec exp_ll ( "-5.0 -2.53563 -2.62171 -5.0 -2.53563 -2.62171" ); |
|---|
| 82 | vec ll = M2.evallog_mat ( Smp ); |
|---|
| 83 | CHECK_CLOSE ( exp_ll, ll, 5.0 ); |
|---|
| 84 | |
|---|
| 85 | check_mean ( M2, N, 0.5*mu0+0.5*vec("-1.2 -0.1"), 1.0 ); |
|---|
| 86 | |
|---|
| 87 | mat observedR ( "0.740142 -0.259015; -0.259015 1.0302" ); |
|---|
| 88 | check_covariance ( M2, N, observedR, 2.0 ); |
|---|
| 89 | |
|---|
| 90 | epdf_ptr Mg = M2.marginal ( y ); |
|---|
| 91 | CHECK ( Mg.get() ); |
|---|
| 92 | pdf_ptr Cn = M2.condition ( x ); |
|---|
| 93 | CHECK ( Cn.get() ); |
|---|
| 94 | |
|---|
| 95 | // marginal mean |
|---|
| 96 | CHECK_CLOSE ( vec ( "0.5" ), Mg->mean(), 0.1 ); |
|---|
| 97 | } |
|---|
| 98 | |
|---|
| 99 | |
|---|
| 100 | static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance ) { |
|---|
| 101 | int tc = 0; |
|---|
| 102 | Array<vec> actual ( CurrentContext::max_trial_count ); |
|---|
| 103 | do { |
|---|
| 104 | mat smp = distrib_obj.sample_mat ( nsamples ); |
|---|
| 105 | vec emu = sum ( smp, 2 ) / nsamples; |
|---|
| 106 | actual ( tc ) = emu; |
|---|
| 107 | ++tc; |
|---|
| 108 | } while ( ( tc < CurrentContext::max_trial_count ) && |
|---|
| 109 | !UnitTest::AreClose ( mean, actual ( tc - 1 ), tolerance ) ); |
|---|
| 110 | if ( ( tc == CurrentContext::max_trial_count ) && |
|---|
| 111 | ( !UnitTest::AreClose ( mean, actual ( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) { |
|---|
| 112 | UnitTest::MemoryOutStream stream; |
|---|
| 113 | UnitTest::TestDetails details ( *UnitTest::CurrentTest::Details(), __LINE__ ); |
|---|
| 114 | stream << "Expected " << mean << " +/- " << tolerance << " but was " << actual; |
|---|
| 115 | |
|---|
| 116 | UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() ); |
|---|
| 117 | } |
|---|
| 118 | } |
|---|
| 119 | |
|---|
| 120 | static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance ) { |
|---|
| 121 | int tc = 0; |
|---|
| 122 | Array<mat> actual ( CurrentContext::max_trial_count ); |
|---|
| 123 | do { |
|---|
| 124 | mat smp = distrib_obj.sample_mat ( nsamples ); |
|---|
| 125 | vec emu = sum ( smp, 2 ) / nsamples; |
|---|
| 126 | mat er = ( smp * smp.T() ) / nsamples - outer_product ( emu, emu ); |
|---|
| 127 | actual ( tc ) = er; |
|---|
| 128 | ++tc; |
|---|
| 129 | } while ( ( tc < CurrentContext::max_trial_count ) && |
|---|
| 130 | !UnitTest::AreClose ( R, actual ( tc - 1 ), tolerance ) ); |
|---|
| 131 | if ( ( tc == CurrentContext::max_trial_count ) && |
|---|
| 132 | ( !UnitTest::AreClose ( R, actual ( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) { |
|---|
| 133 | UnitTest::MemoryOutStream stream; |
|---|
| 134 | UnitTest::TestDetails details ( *UnitTest::CurrentTest::Details(), __LINE__ ); |
|---|
| 135 | stream << "Expected " << R << " +/- " << tolerance << " but was " << actual; |
|---|
| 136 | |
|---|
| 137 | UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() ); |
|---|
| 138 | } |
|---|
| 139 | } |
|---|