root/library/bdm/stat/emix.cpp @ 739

Revision 739, 10.2 kB (checked in by mido, 14 years ago)

the rest of h to cpp movements, with exception of from_setting and validate to avoid conflicts with Sarka

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix::set_parameters ( const vec &w0, const Array<shared_ptr<epdf> > &Coms0 ) {
6        w = w0 / sum ( w0 );
7        dim = Coms0 ( 0 )->dimension();
8        bool isnamed = Coms0 ( 0 )->isnamed();
9        int i;
10        RV tmp_rv;
11        if ( isnamed ) tmp_rv = Coms0 ( 0 )->_rv();
12
13        for ( i = 0; i < w.length(); i++ ) {
14                bdm_assert ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
15                bdm_assert ( !isnamed || tmp_rv.equal ( Coms0 ( i )->_rv() ), "Component RVs do not match!" );
16        }
17
18        Coms = Coms0;
19
20        if ( isnamed ) epdf::set_rv ( tmp_rv ); //coms aer already OK, no need for set_rv
21}
22
23vec emix::sample() const {
24        //Sample which component
25        vec cumDist = cumsum ( w );
26        double u0;
27#pragma omp critical
28        u0 = UniRNG.sample();
29
30        int i = 0;
31        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
32                i++;
33        }
34
35        return Coms ( i )->sample();
36}
37
38vec emix::mean() const {
39        int i;
40        vec mu = zeros ( dim );
41        for ( i = 0; i < w.length(); i++ ) {
42                mu += w ( i ) * Coms ( i )->mean();
43        }
44        return mu;
45}
46
47vec emix::variance() const {
48        //non-central moment
49        vec mom2 = zeros ( dim );
50        for ( int i = 0; i < w.length(); i++ ) {
51                mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
52        }
53        //central moment
54        return mom2 - pow ( mean(), 2 );
55}
56
57double emix::evallog ( const vec &val ) const {
58        int i;
59        double sum = 0.0;
60        for ( i = 0; i < w.length(); i++ ) {
61                sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
62        }
63        if ( sum == 0.0 ) {
64                sum = std::numeric_limits<double>::epsilon();
65        }
66        double tmp = log ( sum );
67        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
68        return tmp;
69}
70
71vec emix::evallog_mat ( const mat &Val ) const {
72        vec x = zeros ( Val.cols() );
73        for ( int i = 0; i < w.length(); i++ ) {
74                x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
75        }
76        return log ( x );
77};
78
79mat emix::evallog_coms ( const mat &Val ) const {
80        mat X ( w.length(), Val.cols() );
81        for ( int i = 0; i < w.length(); i++ ) {
82                X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
83        }
84        return X;
85}
86
87shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
88        emix *tmp = new emix();
89        shared_ptr<epdf> narrow ( tmp );
90        marginal ( rv, *tmp );
91        return narrow;
92}
93
94void emix::marginal ( const RV &rv, emix &target ) const {
95        bdm_assert ( isnamed(), "rvs are not assigned" );
96
97        Array<shared_ptr<epdf> > Cn ( Coms.length() );
98        for ( int i = 0; i < Coms.length(); i++ ) {
99                Cn ( i ) = Coms ( i )->marginal ( rv );
100        }
101
102        target.set_parameters ( w, Cn );
103}
104
105shared_ptr<pdf> emix::condition ( const RV &rv ) const {
106        bdm_assert ( isnamed(), "rvs are not assigned" );
107        mratio *tmp = new mratio ( this, rv );
108        return shared_ptr<pdf> ( tmp );
109}
110
111void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
112        w = w0 / sum ( w0 );
113        dim = Coms0 ( 0 )->dimension();
114        int i;
115        for ( i = 0; i < w.length(); i++ ) {
116                bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
117        }
118        if ( copy ) {
119                Coms.set_length ( Coms0.length() );
120                for ( i = 0; i < w.length(); i++ ) {
121                        bdm_error ( "Not implemented" );
122                        // *Coms ( i ) = *Coms0 ( i );
123                }
124                destroyComs = true;
125        } else {
126                Coms = Coms0;
127                destroyComs = false;
128        }
129}
130
131vec egiwmix::sample() const {
132        //Sample which component
133        vec cumDist = cumsum ( w );
134        double u0;
135#pragma omp critical
136        u0 = UniRNG.sample();
137
138        int i = 0;
139        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
140                i++;
141        }
142
143        return Coms ( i )->sample();
144}
145
146vec egiwmix::mean() const {
147        int i;
148        vec mu = zeros ( dim );
149        for ( i = 0; i < w.length(); i++ ) {
150                mu += w ( i ) * Coms ( i )->mean();
151        }
152        return mu;
153}
154
155vec egiwmix::variance() const {
156        // non-central moment
157        vec mom2 = zeros ( dim );
158        for ( int i = 0; i < w.length(); i++ ) {
159                // pow is overloaded, we have to use another approach
160                mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
161        }
162        // central moment
163        // pow is overloaded, we have to use another approach
164        return mom2 - elem_mult ( mean(), mean() );
165}
166
167shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
168        emix *tmp = new emix();
169        shared_ptr<epdf> narrow ( tmp );
170        marginal ( rv, *tmp );
171        return narrow;
172}
173
174void egiwmix::marginal ( const RV &rv, emix &target ) const {
175        bdm_assert_debug ( isnamed(), "rvs are not assigned" );
176
177        Array<shared_ptr<epdf> > Cn ( Coms.length() );
178        for ( int i = 0; i < Coms.length(); i++ ) {
179                Cn ( i ) = Coms ( i )->marginal ( rv );
180        }
181
182        target.set_parameters ( w, Cn );
183}
184
185egiw*   egiwmix::approx() {
186        // NB: dimx == 1 !!!
187        // The following code might look a bit spaghetti-like,
188        // consult Dedecius, K. et al.: Partial forgetting in AR models.
189
190        double sumVecCommon;                            // common part for many terms in eq.
191        int len = w.length();                           // no. of mix components
192        int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS
193        vec vecNu ( len );                                      // vector of dfms of components
194        vec vecD ( len );                                       // vector of LS reminders of comps.
195        vec vecCommon ( len );                          // vector of common parts
196        mat matVecsTheta;                               // matrix which rows are theta vects.
197
198        // fill in the vectors vecNu, vecD and matVecsTheta
199        for ( int i = 0; i < len; i++ ) {
200                vecNu.shift_left ( Coms ( i )->_nu() );
201                vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
202                matVecsTheta.append_row ( Coms ( i )->est_theta() );
203        }
204
205        // calculate the common parts and their sum
206        vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
207        sumVecCommon = sum ( vecCommon );
208
209        // LS estimator of theta
210        vec aprEstTheta ( dimLS );
211        aprEstTheta.zeros();
212        for ( int i = 0; i < len; i++ ) {
213                aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
214        }
215        aprEstTheta /= sumVecCommon;
216
217
218        // LS estimator of dfm
219        double aprNu;
220        double A = log ( sumVecCommon );                // Term 'A' in equation
221
222        for ( int i = 0; i < len; i++ ) {
223                A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
224        }
225
226        aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );
227
228
229        // LS reminder (term D(0,0) in C-syntax)
230        double aprD = aprNu / sumVecCommon;
231
232        // Aproximation of cov
233        // the following code is very numerically sensitive, thus
234        // we have to eliminate decompositions etc. as much as possible
235        mat aprC = zeros ( dimLS, dimLS );
236        for ( int i = 0; i < len; i++ ) {
237                aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
238                vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
239                aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
240        }
241
242        // Construct GiW pdf :: BEGIN
243        ldmat aprCinv ( inv ( aprC ) );
244        vec D = concat ( aprD, aprCinv._D() );
245        mat L = eye ( dimLS + 1 );
246        L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
247        L.set_submatrix ( 1, 1, aprCinv._L() );
248        ldmat aprLD ( L, D );
249
250        egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
251        return aprgiw;
252};
253
254double mprod::evallogcond ( const vec &val, const vec &cond ) {
255        int i;
256        double res = 0.0;
257        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
258                /*                      if ( pdfs(i)->_rvc().count() >0) {
259                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
260                                        }
261                                        // add logarithms
262                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
263                res += pdfs ( i )->evallogcond (
264                           dls ( i )->pushdown ( val ),
265                           dls ( i )->get_cond ( val, cond )
266                       );
267        }
268        return res;
269}
270
271vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
272        vec tmp ( Dt.cols() );
273        for ( int i = 0; i < Dt.cols(); i++ ) {
274                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
275        }
276        return tmp;
277}
278
279vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
280        vec tmp ( Dt.length() );
281        for ( int i = 0; i < Dt.length(); i++ ) {
282                tmp ( i ) = evallogcond ( Dt ( i ), cond );
283        }
284        return tmp;
285}
286
287void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
288        pdfs = mFacs;
289        dls.set_size ( mFacs.length() );
290
291        rv = get_composite_rv ( pdfs, true );
292        dim = rv._dsize();
293
294        for ( int i = 0; i < pdfs.length(); i++ ) {
295                RV rvx = pdfs ( i )->_rvc().subt ( rv );
296                rvc.add ( rvx ); // add rv to common rvc
297        }
298        dimc = rvc._dsize();
299
300        // rv and rvc established = > we can link them with pdfs
301        for ( int i = 0; i < pdfs.length(); i++ ) {
302                dls ( i ) = new datalink_m2m;
303                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
304        }
305}
306
307vec mmix::samplecond ( const vec &cond ) {
308        //Sample which component
309        vec cumDist = cumsum ( w );
310        double u0;
311#pragma omp critical
312        u0 = UniRNG.sample();
313
314        int i = 0;
315        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
316                i++;
317        }
318
319        return Coms ( i )->samplecond ( cond );
320}
321
322vec eprod::mean() const {
323        vec tmp ( dim );
324        for ( int i = 0; i < epdfs.length(); i++ ) {
325                vec pom = epdfs ( i )->mean();
326                dls ( i )->pushup ( tmp, pom );
327        }
328        return tmp;
329}
330
331vec eprod::variance() const {
332        vec tmp ( dim ); //second moment
333        for ( int i = 0; i < epdfs.length(); i++ ) {
334                vec pom = epdfs ( i )->mean();
335                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
336        }
337        return tmp - pow ( mean(), 2 );
338}
339vec eprod::sample() const {
340        vec tmp ( dim );
341        for ( int i = 0; i < epdfs.length(); i++ ) {
342                vec pom = epdfs ( i )->sample();
343                dls ( i )->pushup ( tmp, pom );
344        }
345        return tmp;
346}
347double eprod::evallog ( const vec &val ) const {
348        double tmp = 0;
349        for ( int i = 0; i < epdfs.length(); i++ ) {
350                tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
351        }
352        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
353        return tmp;
354}
355
356}
357// mprod::mprod ( Array<pdf*> mFacs, bool overlap) : pdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), pdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
358//              int i;
359//              bool rvaddok;
360//              // Create rv
361//              for ( i = 0;i < n;i++ ) {
362//                      rvaddok=rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
363//                      // If rvaddok==false, pdfs overlap => assert error.
364//                      epdfs ( i ) = & ( pdfs ( i )->posterior() ); // add pointer to epdf
365//              };
366//              // Create rvc
367//              for ( i = 0;i < n;i++ ) {
368//                      rvc.add ( pdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
369//              };
370//
371// //           independent = true;
372//              //test rvc of pdfs and fill rvinds
373//              for ( i = 0;i < n;i++ ) {
374//                      // find ith rv in common rv
375//                      rvsinrv ( i ) = pdfs ( i )->_rv().dataind ( rv );
376//                      // find ith rvc in common rv
377//                      rvcinrv ( i ) = pdfs ( i )->_rvc().dataind ( rv );
378//                      // find ith rvc in common rv
379//                      irvcs_rvc ( i ) = pdfs ( i )->_rvc().dataind ( rvc );
380//                      //
381// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
382//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
383//              }
384//      };
Note: See TracBrowser for help on using the browser.