root/library/tests/emix_test.cpp @ 529

Revision 529, 4.5 kB (checked in by vbarta, 15 years ago)

defined *_ptr wrappers of shared pointers

  • Property svn:eol-style set to native
RevLine 
[386]1#include "stat/exp_family.h"
[394]2#include "stat/emix.h"
[500]3#include "mat_checks.h"
4#include "UnitTest++.h"
5#include "test_util.h"
[461]6
[500]7const double epsilon = 0.00001;
8
[254]9using namespace bdm;
[193]10
[520]11static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance );
12
13static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance);
14
[522]15TEST ( test_emix_1 ) {
16        RV x ( "{emixx }" );
17        RV y ( "{emixy }" );
[477]18        RV xy = concat ( x, y );
[520]19        vec mu0 ( "1.00054 1.0455" );
[193]20
[529]21        enorm_ldmat_ptr E1;
[504]22        E1->set_rv ( xy );
[520]23        E1->set_parameters ( mu0 , mat ( "0.740142 -0.259015; -0.259015 1.0302" ) );
[477]24
[529]25        enorm_ldmat_ptr E2;
[504]26        E2->set_rv ( xy );
27        E2->set_parameters ( "-1.2 -0.1" , mat ( "1 0.4; 0.4 0.5" ) );
[477]28
[529]29        epdf_array A1 ( 1 );
[504]30        A1 ( 0 ) = E1;
[477]31
32        emix M1;
33        M1.set_rv ( xy );
[504]34        M1.set_parameters ( vec ( "1" ), A1 );
[477]35
[500]36        // test if ARX and emix with one ARX are the same
[529]37        epdf_ptr Mm = M1.marginal ( y );
38        epdf_ptr Am = E1->marginal ( y );
39        mpdf_ptr Mc = M1.condition ( y );
40        mpdf_ptr Ac = E1->condition ( y );
[477]41
[504]42        mlnorm<ldmat> *wacnd = dynamic_cast<mlnorm<ldmat> *>( Ac.get() );
[500]43        CHECK(wacnd);
44        if ( wacnd ) {
45                CHECK_CLOSE ( mat ( "-0.349953" ), wacnd->_A(), epsilon );
46                CHECK_CLOSE ( vec ( "1.39564" ), wacnd->_mu_const(), epsilon );
47                CHECK_CLOSE ( mat ( "0.939557" ), wacnd->_R(), epsilon );
48        }
[477]49
[500]50        double same = -1.46433;
51        CHECK_CLOSE ( same, Mm->evallog ( vec_1 ( 0.0 ) ), epsilon );
52        CHECK_CLOSE ( same, Am->evallog ( vec_1 ( 0.0 ) ), epsilon );
53        CHECK_CLOSE ( 0.145974, Mc->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon );
54        CHECK_CLOSE ( -1.92433, Ac->evallogcond ( vec_1 ( 0.0 ), vec_1 ( 0.0 ) ), epsilon );
[477]55
[500]56        // mixture with two components
[529]57        epdf_array A2 ( 2 );
[504]58        A2 ( 0 ) = E1;
59        A2 ( 1 ) = E2;
[477]60
61        emix M2;
62        M2.set_rv ( xy );
[504]63        M2.set_parameters ( vec ( "1" ), A2 );
[477]64
65
[500]66        // mixture normalization
67        CHECK_CLOSE ( 1.0, normcoef ( &M2, vec ( "-3 3 " ), vec ( "-3 3 " ) ), 0.1 );
[477]68
69        int N = 3;
70        mat Smp = M2.sample_m ( N );
[520]71
72        vec exp_ll ( "-5.0 -2.53563 -2.62171" );
[500]73        vec ll = M2.evallog_m ( Smp );
[520]74        CHECK_CLOSE ( exp_ll, ll, 5.0 );
[193]75
[520]76        check_mean ( M2, N, mu0, 1.0 );
[500]77
[520]78        mat observedR ( "0.740142 -0.259015; -0.259015 1.0302" );
79        check_covariance ( M2, N, observedR, 2.0);
[193]80
[529]81        epdf_ptr Mg = M2.marginal ( y );
[504]82        CHECK ( Mg.get() );
[529]83        mpdf_ptr Cn = M2.condition ( x );
[504]84        CHECK ( Cn.get() );
[193]85
[500]86        // marginal mean
87        CHECK_CLOSE ( vec ( "1.0" ), Mg->mean(), 0.1 );
[193]88}
[520]89
[522]90TEST ( test_emix_2 ) {
[529]91        int N = 10000; // number of samples
[522]92        vec mu0 ( "1.5 1.7" );
93        mat V0 ( "1.2 0.3; 0.3 5" );
94        ldmat R = ldmat ( V0 );
95
[529]96        enorm_ldmat_ptr eN;
[522]97        eN->set_parameters ( mu0, R );
98
99        vec a = "100000,10000";
100        vec b = a / 10.0;
[529]101        egamma_ptr eG;
[522]102        eG->set_parameters ( a, b );
103
104        emix eMix;
[529]105        epdf_array Coms ( 2 );
[522]106        Coms ( 0 ) = eG;
107        Coms ( 1 ) = eN;
108
109        eMix.set_parameters ( vec_2 ( 0.5, 0.5 ), Coms );
110        check_mean ( eMix, N, eMix.mean(), 0.1 );
111}
112
[520]113static void check_mean ( emix &distrib_obj, int nsamples, const vec &mean, double tolerance ) {
114        int tc = 0;
115        Array<vec> actual(CurrentContext::max_trial_count);
116        do {
117                mat smp = distrib_obj.sample_m ( nsamples );
118                vec emu = sum ( smp, 2 ) / nsamples;
119                actual( tc ) = emu;
120                ++tc;
121        } while ( ( tc < CurrentContext::max_trial_count ) &&
122                  !UnitTest::AreClose ( mean, actual( tc - 1 ), tolerance ) );
123        if ( ( tc == CurrentContext::max_trial_count ) &&
124             ( !UnitTest::AreClose ( mean, actual( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) {
125                UnitTest::MemoryOutStream stream;
126                UnitTest::TestDetails details(*UnitTest::CurrentTest::Details(), __LINE__);
127                stream << "Expected " << mean << " +/- " << tolerance << " but was " << actual;
128
129                UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() );
130        }
131}
132
133static void check_covariance ( emix &distrib_obj, int nsamples, const mat &R, double tolerance) {
134        int tc = 0;
135        Array<mat> actual(CurrentContext::max_trial_count);
136        do {
137                mat smp = distrib_obj.sample_m ( nsamples );
138                vec emu = sum ( smp, 2 ) / nsamples;
139                mat er = ( smp * smp.T() ) / nsamples - outer_product ( emu, emu );
140                actual( tc ) = er;
141                ++tc;
142        } while ( ( tc < CurrentContext::max_trial_count ) &&
143                  !UnitTest::AreClose ( R, actual( tc - 1 ), tolerance ) );
144        if ( ( tc == CurrentContext::max_trial_count ) &&
145             ( !UnitTest::AreClose ( R, actual( CurrentContext::max_trial_count - 1 ), tolerance ) ) ) {
146                UnitTest::MemoryOutStream stream;
147                UnitTest::TestDetails details(*UnitTest::CurrentTest::Details(), __LINE__);
148                stream << "Expected " << R << " +/- " << tolerance << " but was " << actual;
149
150                UnitTest::CurrentTest::Results()->OnTestFailure ( details, stream.GetText() );
151       }
152}
Note: See TracBrowser for help on using the browser.