root/library/bdm/stat/emix.cpp @ 460

Revision 460, 6.2 kB (checked in by smidl, 15 years ago)

Correct handling of named Coms in emix::set_parameters - fixes crash of emix_test

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm{
4
5void emix::set_parameters ( const vec &w0, const Array<epdf*> &Coms0, bool copy ) {
6        w = w0/sum ( w0 );
7        dim = Coms0(0)->dimension();
8        int i;
9        RV tmp_rv=Coms0(0)->_rv();
10        for ( i=0;i<w.length();i++ ) {
11                it_assert_debug ( dim== ( Coms0 ( i )->dimension() ),"Component sizes do not match!" ); 
12                it_assert_debug ( tmp_rv.equal( Coms0 ( i )->_rv() ),"Component RVs do not match!" );
13        }
14        if ( copy ) {
15                Coms.set_length(Coms0.length());
16                for ( i=0;i<w.length();i++ ) {it_error("Not imp...");
17                        *Coms ( i ) =*Coms0 ( i );}
18                destroyComs=true;
19        }
20        else {
21                Coms = Coms0;
22                destroyComs=false;
23        }
24        if (tmp_rv._dsize()==dim) epdf::set_rv(tmp_rv); //coms aer already OK, no need for set_rv
25}
26
27vec emix::sample() const {
28        //Sample which component
29        vec cumDist = cumsum ( w );
30        double u0;
31        #pragma omp critical
32        u0 = UniRNG.sample();
33
34        int i=0;
35        while ( ( cumDist ( i ) <u0 ) && ( i< ( w.length()-1 ) ) ) {i++;}
36
37        return Coms ( i )->sample();
38}
39
40emix* emix::marginal(const RV &rv) const{
41        it_assert_debug(isnamed(), "rvs are not assigned");
42                       
43        Array<epdf*> Cn(Coms.length());
44        for(int i=0;i<Coms.length();i++){Cn(i)=Coms(i)->marginal(rv);}
45        emix* tmp = new emix();
46        tmp->set_parameters(w,Cn,false);
47        tmp->ownComs();
48        return tmp;
49}
50
51mratio* emix::condition(const RV &rv) const{
52        it_assert_debug(isnamed(), "rvs are not assigned");
53        return new mratio(this,rv);
54};
55
56void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
57        w = w0/sum ( w0 );
58        dim = Coms0(0)->dimension();
59        int i;
60        for ( i=0;i<w.length();i++ ) {
61                it_assert_debug ( dim== ( Coms0 ( i )->dimension() ),"Component sizes do not match!" );
62        }
63        if ( copy ) {
64                Coms.set_length(Coms0.length());
65                for ( i=0;i<w.length();i++ ) {it_error("Not imp...");
66                        *Coms ( i ) =*Coms0 ( i );}
67                destroyComs=true;
68        }
69        else {
70                Coms = Coms0;
71                destroyComs=false;
72        }
73}
74
75vec egiwmix::sample() const {
76        //Sample which component
77        vec cumDist = cumsum ( w );
78        double u0;
79        #pragma omp critical
80        u0 = UniRNG.sample();
81
82        int i=0;
83        while ( ( cumDist ( i ) <u0 ) && ( i< ( w.length()-1 ) ) ) {i++;}
84
85        return Coms ( i )->sample();
86}
87
88vec egiwmix::mean() const {
89        int i; vec mu = zeros ( dim );
90        for ( i = 0;i < w.length();i++ ) {mu += w ( i ) * Coms ( i )->mean(); }
91        return mu;
92}
93
94vec egiwmix::variance() const {
95        // non-central moment
96        vec mom2 = zeros ( dim );
97        for ( int i = 0;i < w.length();i++ ) {
98                // pow is overloaded, we have to use another approach
99                mom2 += w ( i ) * (Coms(i)->variance() + elem_mult ( Coms(i)->mean(), Coms(i)->mean() )); 
100        }
101        // central moment
102        // pow is overloaded, we have to use another approach
103        return mom2 - elem_mult( mean(), mean() );
104}
105
106emix* egiwmix::marginal(const RV &rv) const{
107        it_assert_debug(isnamed(), "rvs are not assigned");
108                       
109        Array<epdf*> Cn(Coms.length());
110        for(int i=0;i<Coms.length();i++){Cn(i)=Coms(i)->marginal(rv);}
111        emix* tmp = new emix();
112        tmp->set_parameters(w,Cn,false);
113        tmp->ownComs();
114        return tmp;
115}
116
117egiw*   egiwmix::approx() {
118        // NB: dimx == 1 !!!
119        // The following code might look a bit spaghetti-like,
120        // consult Dedecius, K. et al.: Partial forgetting in AR models.
121
122        double sumVecCommon;                            // common part for many terms in eq.
123        int len = w.length();                           // no. of mix components       
124        int dimLS = Coms(1)->_V()._D().length() - 1;    // dim of LS
125        vec vecNu(len);                                 // vector of dfms of components
126        vec vecD(len);                                  // vector of LS reminders of comps.
127        vec vecCommon(len);                             // vector of common parts
128        mat matVecsTheta;                               // matrix which rows are theta vects.
129
130        // fill in the vectors vecNu, vecD and matVecsTheta
131        for ( int i=0; i<len; i++ ) {
132                vecNu.shift_left( Coms(i)->_nu() );
133                vecD.shift_left( Coms(i)->_V()._D()(0) );
134                matVecsTheta.append_row( Coms(i)->est_theta()  );
135        }
136
137        // calculate the common parts and their sum
138        vecCommon = elem_mult ( w, elem_div(vecNu, vecD) );
139        sumVecCommon = sum(vecCommon);
140
141        // LS estimator of theta
142        vec aprEstTheta(dimLS);  aprEstTheta.zeros();
143        for ( int i=0; i<len; i++ ) {
144                aprEstTheta +=  matVecsTheta.get_row( i ) * vecCommon ( i );
145        }
146        aprEstTheta /= sumVecCommon;
147       
148       
149        // LS estimator of dfm
150        double aprNu;
151        double A = log( sumVecCommon );         // Term 'A' in equation
152
153        for ( int i=0; i<len; i++ ) {
154                A += w(i) * ( log( vecD(i) ) - psi( 0.5 * vecNu(i) ) );
155        }
156
157        aprNu = ( 1 + sqrt(1 + 2 * (A - LOG2)/3 ) ) / ( 2 * (A - LOG2) );
158
159
160        // LS reminder (term D(0,0) in C-syntax)
161        double aprD = aprNu / sumVecCommon;
162
163        // Aproximation of cov
164        // the following code is very numerically sensitive, thus
165        // we have to eliminate decompositions etc. as much as possible
166        mat aprC = zeros(dimLS, dimLS);
167        for ( int i=0; i<len; i++ ) {
168                aprC += Coms(i)->est_theta_cov().to_mat() * w(i); 
169                vec tmp = ( matVecsTheta.get_row(i) - aprEstTheta );
170                aprC += vecCommon(i) * outer_product( tmp, tmp);
171        }
172
173        // Construct GiW pdf :: BEGIN
174        ldmat aprCinv ( inv(aprC) );
175        vec D = concat( aprD, aprCinv._D() );
176        mat L = eye(dimLS+1);
177        L.set_submatrix(1,0, aprCinv._L() * aprEstTheta);
178        L.set_submatrix(1,1, aprCinv._L());
179        ldmat aprLD (L, D);
180
181        egiw* aprgiw = new egiw(1, aprLD, aprNu);
182        return aprgiw;
183};
184
185}
186// mprod::mprod ( Array<mpdf*> mFacs, bool overlap) : mpdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), mpdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
187//              int i;
188//              bool rvaddok;
189//              // Create rv
190//              for ( i = 0;i < n;i++ ) {
191//                      rvaddok=rv.add ( mpdfs ( i )->_rv() ); //add rv to common rvs.
192//                      // If rvaddok==false, mpdfs overlap => assert error.
193//                      it_assert_debug(rvaddok||overlap,"mprod::mprod() input mpdfs overlap in rv!");
194//                      epdfs ( i ) = & ( mpdfs ( i )->posterior() ); // add pointer to epdf
195//              };
196//              // Create rvc
197//              for ( i = 0;i < n;i++ ) {
198//                      rvc.add ( mpdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
199//              };
200//
201// //           independent = true;
202//              //test rvc of mpdfs and fill rvinds
203//              for ( i = 0;i < n;i++ ) {
204//                      // find ith rv in common rv
205//                      rvsinrv ( i ) = mpdfs ( i )->_rv().dataind ( rv );
206//                      // find ith rvc in common rv
207//                      rvcinrv ( i ) = mpdfs ( i )->_rvc().dataind ( rv );
208//                      // find ith rvc in common rv
209//                      irvcs_rvc ( i ) = mpdfs ( i )->_rvc().dataind ( rvc );
210//                      //
211// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
212//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
213//              }
214//      };
Note: See TracBrowser for help on using the browser.