root/library/bdm/stat/emix.cpp @ 773

Revision 766, 11.5 kB (checked in by mido, 14 years ago)

abstract methods restored wherever they are meaningful
macros NOT_IMPLEMENTED and NOT_IMPLEMENTED_VOID defined to make sources shorter
emix::set_parameters and mmix::set_parameters removed, corresponding acces methods created and the corresponding validate methods improved appropriately
some compilator warnings were avoided
and also a few other things cleaned up

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix::validate (){
6        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
7
8        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
9
10        double sum_w = sum ( w );
11        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
12        w = w / sum_w;
13
14        int dim = Coms ( 0 )->dimension();
15        for ( int i = 1; i < Coms.length(); i++ ) {
16                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
17                bdm_assert ( Coms(i)->isnamed() , "An unnamed component is forbidden here!" );
18        }
19
20        set_rv ( Coms ( 0 )->_rv() ); 
21}
22
23void emix::from_setting ( const Setting &set ) {
24        UI::get ( Coms, set, "pdfs", UI::compulsory );
25
26        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
27                int len = Coms.length();
28                w.set_length ( len );
29                w = 1.0 / len;
30        }
31
32        validate();
33}
34
35
36vec emix::sample() const {
37        //Sample which component
38        vec cumDist = cumsum ( w );
39        double u0;
40#pragma omp critical
41        u0 = UniRNG.sample();
42
43        int i = 0;
44        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
45                i++;
46        }
47
48        return Coms ( i )->sample();
49}
50
51vec emix::mean() const {
52        int i;
53        vec mu = zeros ( dim );
54        for ( i = 0; i < w.length(); i++ ) {
55                mu += w ( i ) * Coms ( i )->mean();
56        }
57        return mu;
58}
59
60vec emix::variance() const {
61        //non-central moment
62        vec mom2 = zeros ( dim );
63        for ( int i = 0; i < w.length(); i++ ) {
64                mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
65        }
66        //central moment
67        return mom2 - pow ( mean(), 2 );
68}
69
70double emix::evallog ( const vec &val ) const {
71        int i;
72        double sum = 0.0;
73        for ( i = 0; i < w.length(); i++ ) {
74                sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
75        }
76        if ( sum == 0.0 ) {
77                sum = std::numeric_limits<double>::epsilon();
78        }
79        double tmp = log ( sum );
80        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
81        return tmp;
82}
83
84vec emix::evallog_mat ( const mat &Val ) const {
85        vec x = zeros ( Val.cols() );
86        for ( int i = 0; i < w.length(); i++ ) {
87                x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
88        }
89        return log ( x );
90};
91
92mat emix::evallog_coms ( const mat &Val ) const {
93        mat X ( w.length(), Val.cols() );
94        for ( int i = 0; i < w.length(); i++ ) {
95                X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
96        }
97        return X;
98}
99
100shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
101        emix *tmp = new emix();
102        shared_ptr<epdf> narrow ( tmp );
103        marginal ( rv, *tmp );
104        return narrow;
105}
106
107void emix::marginal ( const RV &rv, emix &target ) const {
108        bdm_assert ( isnamed(), "rvs are not assigned" );
109
110        Array<shared_ptr<epdf> > Cn ( Coms.length() );
111        for ( int i = 0; i < Coms.length(); i++ ) {
112                Cn ( i ) = Coms ( i )->marginal ( rv );
113        }
114
115        target._w() = w;
116        target._Coms() = Cn;
117        target.validate();
118}
119
120shared_ptr<pdf> emix::condition ( const RV &rv ) const {
121        bdm_assert ( isnamed(), "rvs are not assigned" );
122        mratio *tmp = new mratio ( this, rv );
123        return shared_ptr<pdf> ( tmp );
124}
125
126void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
127        w = w0 / sum ( w0 );
128        dim = Coms0 ( 0 )->dimension();
129        int i;
130        for ( i = 0; i < w.length(); i++ ) {
131                bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
132        }
133        if ( copy ) {
134                Coms.set_length ( Coms0.length() );
135                for ( i = 0; i < w.length(); i++ ) {
136                        bdm_error ( "Not implemented" );
137                        // *Coms ( i ) = *Coms0 ( i );
138                }
139                destroyComs = true;
140        } else {
141                Coms = Coms0;
142                destroyComs = false;
143        }
144}
145
146vec egiwmix::sample() const {
147        //Sample which component
148        vec cumDist = cumsum ( w );
149        double u0;
150#pragma omp critical
151        u0 = UniRNG.sample();
152
153        int i = 0;
154        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
155                i++;
156        }
157
158        return Coms ( i )->sample();
159}
160
161vec egiwmix::mean() const {
162        int i;
163        vec mu = zeros ( dim );
164        for ( i = 0; i < w.length(); i++ ) {
165                mu += w ( i ) * Coms ( i )->mean();
166        }
167        return mu;
168}
169
170vec egiwmix::variance() const {
171        // non-central moment
172        vec mom2 = zeros ( dim );
173        for ( int i = 0; i < w.length(); i++ ) {
174                // pow is overloaded, we have to use another approach
175                mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
176        }
177        // central moment
178        // pow is overloaded, we have to use another approach
179        return mom2 - elem_mult ( mean(), mean() );
180}
181
182shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
183        emix *tmp = new emix();
184        shared_ptr<epdf> narrow ( tmp );
185        marginal ( rv, *tmp );
186        return narrow;
187}
188
189void egiwmix::marginal ( const RV &rv, emix &target ) const {
190        bdm_assert_debug ( isnamed(), "rvs are not assigned" );
191
192        Array<shared_ptr<epdf> > Cn ( Coms.length() );
193        for ( int i = 0; i < Coms.length(); i++ ) {
194                Cn ( i ) = Coms ( i )->marginal ( rv );
195        }
196
197        target._w() = w;
198        target._Coms() = Cn;
199        target.validate();
200}
201
202egiw*   egiwmix::approx() {
203        // NB: dimx == 1 !!!
204        // The following code might look a bit spaghetti-like,
205        // consult Dedecius, K. et al.: Partial forgetting in AR models.
206
207        double sumVecCommon;                            // common part for many terms in eq.
208        int len = w.length();                           // no. of mix components
209        int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS
210        vec vecNu ( len );                                      // vector of dfms of components
211        vec vecD ( len );                                       // vector of LS reminders of comps.
212        vec vecCommon ( len );                          // vector of common parts
213        mat matVecsTheta;                               // matrix which rows are theta vects.
214
215        // fill in the vectors vecNu, vecD and matVecsTheta
216        for ( int i = 0; i < len; i++ ) {
217                vecNu.shift_left ( Coms ( i )->_nu() );
218                vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
219                matVecsTheta.append_row ( Coms ( i )->est_theta() );
220        }
221
222        // calculate the common parts and their sum
223        vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
224        sumVecCommon = sum ( vecCommon );
225
226        // LS estimator of theta
227        vec aprEstTheta ( dimLS );
228        aprEstTheta.zeros();
229        for ( int i = 0; i < len; i++ ) {
230                aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
231        }
232        aprEstTheta /= sumVecCommon;
233
234
235        // LS estimator of dfm
236        double aprNu;
237        double A = log ( sumVecCommon );                // Term 'A' in equation
238
239        for ( int i = 0; i < len; i++ ) {
240                A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
241        }
242
243        aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );
244
245
246        // LS reminder (term D(0,0) in C-syntax)
247        double aprD = aprNu / sumVecCommon;
248
249        // Aproximation of cov
250        // the following code is very numerically sensitive, thus
251        // we have to eliminate decompositions etc. as much as possible
252        mat aprC = zeros ( dimLS, dimLS );
253        for ( int i = 0; i < len; i++ ) {
254                aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
255                vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
256                aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
257        }
258
259        // Construct GiW pdf :: BEGIN
260        ldmat aprCinv ( inv ( aprC ) );
261        vec D = concat ( aprD, aprCinv._D() );
262        mat L = eye ( dimLS + 1 );
263        L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
264        L.set_submatrix ( 1, 1, aprCinv._L() );
265        ldmat aprLD ( L, D );
266
267        egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
268        return aprgiw;
269};
270
271double mprod::evallogcond ( const vec &val, const vec &cond ) {
272        int i;
273        double res = 0.0;
274        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
275                /*                      if ( pdfs(i)->_rvc().count() >0) {
276                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
277                                        }
278                                        // add logarithms
279                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
280                res += pdfs ( i )->evallogcond (
281                           dls ( i )->pushdown ( val ),
282                           dls ( i )->get_cond ( val, cond )
283                       );
284        }
285        return res;
286}
287
288vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
289        vec tmp ( Dt.cols() );
290        for ( int i = 0; i < Dt.cols(); i++ ) {
291                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
292        }
293        return tmp;
294}
295
296vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
297        vec tmp ( Dt.length() );
298        for ( int i = 0; i < Dt.length(); i++ ) {
299                tmp ( i ) = evallogcond ( Dt ( i ), cond );
300        }
301        return tmp;
302}
303
304void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
305        pdfs = mFacs;
306        dls.set_size ( mFacs.length() );
307
308        rv = get_composite_rv ( pdfs, true );
309        dim = rv._dsize();
310
311        for ( int i = 0; i < pdfs.length(); i++ ) {
312                RV rvx = pdfs ( i )->_rvc().subt ( rv );
313                rvc.add ( rvx ); // add rv to common rvc
314        }
315        dimc = rvc._dsize();
316
317        // rv and rvc established = > we can link them with pdfs
318        for ( int i = 0; i < pdfs.length(); i++ ) {
319                dls ( i ) = new datalink_m2m;
320                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
321        }
322}
323
324void mmix::validate()
325{       
326        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
327
328        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
329
330        double sum_w = sum ( w );
331        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
332        w = w / sum_w;
333
334        int dim = Coms ( 0 )->dimension();
335        int dimc = Coms ( 0 )->dimensionc();
336        for ( int i = 1; i < Coms.length(); i++ ) {
337                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
338                bdm_assert ( dimc == ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
339                bdm_assert ( Coms(i)->isnamed() , "An unnamed component is forbidden here!" );
340        }
341
342        set_rv ( Coms ( 0 )->_rv() );
343        set_rvc ( Coms ( 0 )->_rvc() );
344}
345
346void mmix::from_setting ( const Setting &set ) {
347        UI::get ( Coms, set, "pdfs", UI::compulsory );
348
349        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
350                int len = Coms.length();
351                w.set_length ( len );
352                w = 1.0 / len;
353        }
354}
355
356vec mmix::samplecond ( const vec &cond ) {
357        //Sample which component
358        vec cumDist = cumsum ( w );
359        double u0;
360#pragma omp critical
361        u0 = UniRNG.sample();
362
363        int i = 0;
364        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
365                i++;
366        }
367
368        return Coms ( i )->samplecond ( cond );
369}
370
371vec eprod::mean() const {
372        vec tmp ( dim );
373        for ( int i = 0; i < epdfs.length(); i++ ) {
374                vec pom = epdfs ( i )->mean();
375                dls ( i )->pushup ( tmp, pom );
376        }
377        return tmp;
378}
379
380vec eprod::variance() const {
381        vec tmp ( dim ); //second moment
382        for ( int i = 0; i < epdfs.length(); i++ ) {
383                vec pom = epdfs ( i )->mean();
384                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
385        }
386        return tmp - pow ( mean(), 2 );
387}
388vec eprod::sample() const {
389        vec tmp ( dim );
390        for ( int i = 0; i < epdfs.length(); i++ ) {
391                vec pom = epdfs ( i )->sample();
392                dls ( i )->pushup ( tmp, pom );
393        }
394        return tmp;
395}
396double eprod::evallog ( const vec &val ) const {
397        double tmp = 0;
398        for ( int i = 0; i < epdfs.length(); i++ ) {
399                tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
400        }
401        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
402        return tmp;
403}
404
405}
406// mprod::mprod ( Array<pdf*> mFacs, bool overlap) : pdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), pdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
407//              int i;
408//              bool rvaddok;
409//              // Create rv
410//              for ( i = 0;i < n;i++ ) {
411//                      rvaddok=rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
412//                      // If rvaddok==false, pdfs overlap => assert error.
413//                      epdfs ( i ) = & ( pdfs ( i )->posterior() ); // add pointer to epdf
414//              };
415//              // Create rvc
416//              for ( i = 0;i < n;i++ ) {
417//                      rvc.add ( pdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
418//              };
419//
420// //           independent = true;
421//              //test rvc of pdfs and fill rvinds
422//              for ( i = 0;i < n;i++ ) {
423//                      // find ith rv in common rv
424//                      rvsinrv ( i ) = pdfs ( i )->_rv().dataind ( rv );
425//                      // find ith rvc in common rv
426//                      rvcinrv ( i ) = pdfs ( i )->_rvc().dataind ( rv );
427//                      // find ith rvc in common rv
428//                      irvcs_rvc ( i ) = pdfs ( i )->_rvc().dataind ( rvc );
429//                      //
430// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
431//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
432//              }
433//      };
Note: See TracBrowser for help on using the browser.