1 | #include "stat/exp_family.h" |
---|
2 | #include "stat/emix.h" |
---|
3 | #include "../mat_checks.h" |
---|
4 | #include "UnitTest++.h" |
---|
5 | #include "../test_util.h" |
---|
6 | |
---|
7 | const double epsilon = 0.00001; |
---|
8 | |
---|
9 | using namespace bdm; |
---|
10 | |
---|
11 | static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance ); |
---|
12 | |
---|
13 | static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance); |
---|
14 | |
---|
15 | TEST ( emix_1_test ) { |
---|
16 | RV x ( "{emixx }" ); |
---|
17 | RV y ( "{emixy }" ); |
---|
18 | RV xy = concat ( x, y ); |
---|
19 | vec mu0 ( "1.00054 1.0455" ); |
---|
20 | |
---|
21 | enorm_ldmat_ptr E1; |
---|
22 | E1->set_rv ( xy ); |
---|
23 | E1->set_parameters ( mu0 , mat ( "0.740142 -0.259015; -0.259015 1.0302" ) ); |
---|
24 | |
---|
25 | enorm_ldmat_ptr E2; |
---|
26 | E2->set_rv ( xy ); |
---|
27 | E2->set_parameters ( "-1.2 -0.1" , mat ( "1 0.4; 0.4 0.5" ) ); |
---|
28 | |
---|
29 | epdf_array A1 ( 1 ); |
---|
30 | A1 ( 0 ) = E1; |
---|
31 | |
---|
32 | emix M1; |
---|
33 | M1.set_rv ( xy ); |
---|
34 | M1.set_parameters ( vec ( "1" ), A1 ); |
---|
35 | |
---|
36 | // test if ARX and emix with one ARX are the same |
---|
37 | epdf_ptr Mm = M1.marginal ( y ); |
---|
38 | epdf_ptr Am = E1->marginal ( y ); |
---|
39 | pdf_ptr Mc = M1.condition ( y ); |
---|
40 | pdf_ptr Ac = E1->condition ( y ); |
---|
41 | |
---|
42 | mlnorm<ldmat> *wacnd = dynamic_cast<mlnorm<ldmat> *>( Ac.get() ); |
---|
43 | CHECK(wacnd); |
---|
44 | if ( wacnd ) { |
---|
45 | CHECK_CLOSE ( mat ( "-0.349953" ), wacnd->_A(), epsilon ); |
---|
46 | CHECK_CLOSE ( vec ( "1.39564" ), wacnd->_mu_const(), epsilon ); |
---|
47 | CHECK_CLOSE ( mat ( "0.939557" ), wacnd->_R(), epsilon ); |
---|
48 | } |
---|
49 | |
---|
50 | double same = -1.46433; |
---|
51 | CHECK_CLOSE ( same, Mm->evallog ( vec_1 ( 0.0 ) ), epsilon ); |
---|
52 | CHECK_CLOSE ( same, Am->evallog ( vec_1 ( 0.0 ) ), epsilon ); |
---|
53 | CHECK_CLOSE ( 0.145974, Mc->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon ); |
---|
54 | CHECK_CLOSE ( -1.92433, Ac->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon ); |
---|
55 | |
---|
56 | // mixture with two components |
---|
57 | epdf_array A2 ( 2 ); |
---|
58 | A2 ( 0 ) = E1; |
---|
59 | A2 ( 1 ) = E2; |
---|
60 | |
---|
61 | emix M2; |
---|
62 | M2.set_rv ( xy ); |
---|
63 | M2.set_parameters ( vec ( "1" ), A2 ); |
---|
64 | |
---|
65 | |
---|
66 | // mixture normalization |
---|
67 | CHECK_CLOSE ( 1.0, normcoef ( &M2, vec ( "-3 3 " ), vec ( "-3 3 " ) ), 0.1 ); |
---|
68 | |
---|
69 | int N = 3; |
---|
70 | mat Smp = M2.sample_mat ( N ); |
---|
71 | |
---|
72 | vec exp_ll ( "-5.0 -2.53563 -2.62171" ); |
---|
73 | vec ll = M2.evallog_mat ( Smp ); |
---|
74 | CHECK_CLOSE ( exp_ll, ll, 5.0 ); |
---|
75 | |
---|
76 | check_mean ( M2, N, mu0, 1.0 ); |
---|
77 | |
---|
78 | mat observedR ( "0.740142 -0.259015; -0.259015 1.0302" ); |
---|
79 | check_covariance ( M2, N, observedR, 2.0); |
---|
80 | |
---|
81 | epdf_ptr Mg = M2.marginal ( y ); |
---|
82 | CHECK ( Mg.get() ); |
---|
83 | pdf_ptr Cn = M2.condition ( x ); |
---|
84 | CHECK ( Cn.get() ); |
---|
85 | |
---|
86 | // marginal mean |
---|
87 | CHECK_CLOSE ( vec ( "1.0" ), Mg->mean(), 0.1 ); |
---|
88 | } |
---|
89 | |
---|
90 | TEST ( emix_2_test ) { |
---|
91 | int N = 10000; // number of samples |
---|
92 | vec mu0 ( "1.5 1.7" ); |
---|
93 | mat V0 ( "1.2 0.3; 0.3 5" ); |
---|
94 | ldmat R = ldmat ( V0 ); |
---|
95 | |
---|
96 | enorm_ldmat_ptr eN; |
---|
97 | eN->set_parameters ( mu0, R ); |
---|
98 | |
---|
99 | vec a = "100000,10000"; |
---|
100 | vec b = a / 10.0; |
---|
101 | egamma_ptr eG; |
---|
102 | eG->set_parameters ( a, b ); |
---|
103 | |
---|
104 | emix eMix; |
---|
105 | epdf_array Coms ( 2 ); |
---|
106 | Coms ( 0 ) = eG; |
---|
107 | Coms ( 1 ) = eN; |
---|
108 | |
---|
109 | eMix.set_parameters ( vec_2 ( 0.5, 0.5 ), Coms ); |
---|
110 | check_mean ( eMix, N, eMix.mean(), 0.1 ); |
---|
111 | } |
---|
112 | |
---|
113 | static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance ) { |
---|
114 | int tc = 0; |
---|
115 | Array<vec> actual(CurrentContext::max_trial_count); |
---|
116 | do { |
---|
117 | mat smp = distrib_obj.sample_mat ( nsamples ); |
---|
118 | vec emu = sum ( smp, 2 ) / nsamples; |
---|
119 | actual( tc ) = emu; |
---|
120 | ++tc; |
---|
121 | } while ( ( tc < CurrentContext::max_trial_count ) && |
---|
122 | !UnitTest::AreClose ( mean, actual( tc - 1 ), tolerance ) ); |
---|
123 | if ( ( tc == CurrentContext::max_trial_count ) && |
---|
124 | ( !UnitTest::AreClose ( mean, actual( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) { |
---|
125 | UnitTest::MemoryOutStream stream; |
---|
126 | UnitTest::TestDetails details(*UnitTest::CurrentTest::Details(), __LINE__); |
---|
127 | stream << "Expected " << mean << " +/- " << tolerance << " but was " << actual; |
---|
128 | |
---|
129 | UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() ); |
---|
130 | } |
---|
131 | } |
---|
132 | |
---|
133 | static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance) { |
---|
134 | int tc = 0; |
---|
135 | Array<mat> actual(CurrentContext::max_trial_count); |
---|
136 | do { |
---|
137 | mat smp = distrib_obj.sample_mat ( nsamples ); |
---|
138 | vec emu = sum ( smp, 2 ) / nsamples; |
---|
139 | mat er = ( smp * smp.T() ) / nsamples - outer_product ( emu, emu ); |
---|
140 | actual( tc ) = er; |
---|
141 | ++tc; |
---|
142 | } while ( ( tc < CurrentContext::max_trial_count ) && |
---|
143 | !UnitTest::AreClose ( R, actual( tc - 1 ), tolerance ) ); |
---|
144 | if ( ( tc == CurrentContext::max_trial_count ) && |
---|
145 | ( !UnitTest::AreClose ( R, actual( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) { |
---|
146 | UnitTest::MemoryOutStream stream; |
---|
147 | UnitTest::TestDetails details(*UnitTest::CurrentTest::Details(), __LINE__); |
---|
148 | stream << "Expected " << R << " +/- " << tolerance << " but was " << actual; |
---|
149 | |
---|
150 | UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() ); |
---|
151 | } |
---|
152 | } |
---|