root/library/bdm/stat/emix.cpp @ 780

Revision 780, 11.5 kB (checked in by smidl, 14 years ago)

fix broken mixtures - test_suite was OK with this?

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix::validate (){
6        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
7
8        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
9
10        double sum_w = sum ( w );
11        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
12        w = w / sum_w;
13
14        dim = Coms ( 0 )->dimension();
15        RV rv_tmp= Coms ( 0 )->_rv() ;
16        bool isnam=true;
17        for ( int i = 0; i < Coms.length(); i++ ) {
18                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
19                isnam &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp);
20        }
21        if (isnam)
22                set_rv ( rv_tmp); 
23}
24
25void emix::from_setting ( const Setting &set ) {
26        UI::get ( Coms, set, "pdfs", UI::compulsory );
27
28        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
29                int len = Coms.length();
30                w.set_length ( len );
31                w = 1.0 / len;
32        }
33
34        validate();
35}
36
37
38vec emix::sample() const {
39        //Sample which component
40        vec cumDist = cumsum ( w );
41        double u0;
42#pragma omp critical
43        u0 = UniRNG.sample();
44
45        int i = 0;
46        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
47                i++;
48        }
49
50        return Coms ( i )->sample();
51}
52
53vec emix::mean() const {
54        int i;
55        vec mu = zeros ( dim );
56        for ( i = 0; i < w.length(); i++ ) {
57                mu += w ( i ) * Coms ( i )->mean();
58        }
59        return mu;
60}
61
62vec emix::variance() const {
63        //non-central moment
64        vec mom2 = zeros ( dim );
65        for ( int i = 0; i < w.length(); i++ ) {
66                mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
67        }
68        //central moment
69        return mom2 - pow ( mean(), 2 );
70}
71
72double emix::evallog ( const vec &val ) const {
73        int i;
74        double sum = 0.0;
75        for ( i = 0; i < w.length(); i++ ) {
76                sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
77        }
78        if ( sum == 0.0 ) {
79                sum = std::numeric_limits<double>::epsilon();
80        }
81        double tmp = log ( sum );
82        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
83        return tmp;
84}
85
86vec emix::evallog_mat ( const mat &Val ) const {
87        vec x = zeros ( Val.cols() );
88        for ( int i = 0; i < w.length(); i++ ) {
89                x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
90        }
91        return log ( x );
92};
93
94mat emix::evallog_coms ( const mat &Val ) const {
95        mat X ( w.length(), Val.cols() );
96        for ( int i = 0; i < w.length(); i++ ) {
97                X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
98        }
99        return X;
100}
101
102shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
103        emix *tmp = new emix();
104        shared_ptr<epdf> narrow ( tmp );
105        marginal ( rv, *tmp );
106        return narrow;
107}
108
109void emix::marginal ( const RV &rv, emix &target ) const {
110        bdm_assert ( isnamed(), "rvs are not assigned" );
111
112        Array<shared_ptr<epdf> > Cn ( Coms.length() );
113        for ( int i = 0; i < Coms.length(); i++ ) {
114                Cn ( i ) = Coms ( i )->marginal ( rv );
115        }
116
117        target._w() = w;
118        target._Coms() = Cn;
119        target.validate();
120}
121
122shared_ptr<pdf> emix::condition ( const RV &rv ) const {
123        bdm_assert ( isnamed(), "rvs are not assigned" );
124        mratio *tmp = new mratio ( this, rv );
125        return shared_ptr<pdf> ( tmp );
126}
127
128void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
129        w = w0 / sum ( w0 );
130        dim = Coms0 ( 0 )->dimension();
131        int i;
132        for ( i = 0; i < w.length(); i++ ) {
133                bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
134        }
135        if ( copy ) {
136                Coms.set_length ( Coms0.length() );
137                for ( i = 0; i < w.length(); i++ ) {
138                        bdm_error ( "Not implemented" );
139                        // *Coms ( i ) = *Coms0 ( i );
140                }
141                destroyComs = true;
142        } else {
143                Coms = Coms0;
144                destroyComs = false;
145        }
146}
147
148vec egiwmix::sample() const {
149        //Sample which component
150        vec cumDist = cumsum ( w );
151        double u0;
152#pragma omp critical
153        u0 = UniRNG.sample();
154
155        int i = 0;
156        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
157                i++;
158        }
159
160        return Coms ( i )->sample();
161}
162
163vec egiwmix::mean() const {
164        int i;
165        vec mu = zeros ( dim );
166        for ( i = 0; i < w.length(); i++ ) {
167                mu += w ( i ) * Coms ( i )->mean();
168        }
169        return mu;
170}
171
172vec egiwmix::variance() const {
173        // non-central moment
174        vec mom2 = zeros ( dim );
175        for ( int i = 0; i < w.length(); i++ ) {
176                // pow is overloaded, we have to use another approach
177                mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
178        }
179        // central moment
180        // pow is overloaded, we have to use another approach
181        return mom2 - elem_mult ( mean(), mean() );
182}
183
184shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
185        emix *tmp = new emix();
186        shared_ptr<epdf> narrow ( tmp );
187        marginal ( rv, *tmp );
188        return narrow;
189}
190
191void egiwmix::marginal ( const RV &rv, emix &target ) const {
192        bdm_assert_debug ( isnamed(), "rvs are not assigned" );
193
194        Array<shared_ptr<epdf> > Cn ( Coms.length() );
195        for ( int i = 0; i < Coms.length(); i++ ) {
196                Cn ( i ) = Coms ( i )->marginal ( rv );
197        }
198
199        target._w() = w;
200        target._Coms() = Cn;
201        target.validate();
202}
203
204egiw*   egiwmix::approx() {
205        // NB: dimx == 1 !!!
206        // The following code might look a bit spaghetti-like,
207        // consult Dedecius, K. et al.: Partial forgetting in AR models.
208
209        double sumVecCommon;                            // common part for many terms in eq.
210        int len = w.length();                           // no. of mix components
211        int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS
212        vec vecNu ( len );                                      // vector of dfms of components
213        vec vecD ( len );                                       // vector of LS reminders of comps.
214        vec vecCommon ( len );                          // vector of common parts
215        mat matVecsTheta;                               // matrix which rows are theta vects.
216
217        // fill in the vectors vecNu, vecD and matVecsTheta
218        for ( int i = 0; i < len; i++ ) {
219                vecNu.shift_left ( Coms ( i )->_nu() );
220                vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
221                matVecsTheta.append_row ( Coms ( i )->est_theta() );
222        }
223
224        // calculate the common parts and their sum
225        vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
226        sumVecCommon = sum ( vecCommon );
227
228        // LS estimator of theta
229        vec aprEstTheta ( dimLS );
230        aprEstTheta.zeros();
231        for ( int i = 0; i < len; i++ ) {
232                aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
233        }
234        aprEstTheta /= sumVecCommon;
235
236
237        // LS estimator of dfm
238        double aprNu;
239        double A = log ( sumVecCommon );                // Term 'A' in equation
240
241        for ( int i = 0; i < len; i++ ) {
242                A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
243        }
244
245        aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );
246
247
248        // LS reminder (term D(0,0) in C-syntax)
249        double aprD = aprNu / sumVecCommon;
250
251        // Aproximation of cov
252        // the following code is very numerically sensitive, thus
253        // we have to eliminate decompositions etc. as much as possible
254        mat aprC = zeros ( dimLS, dimLS );
255        for ( int i = 0; i < len; i++ ) {
256                aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
257                vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
258                aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
259        }
260
261        // Construct GiW pdf :: BEGIN
262        ldmat aprCinv ( inv ( aprC ) );
263        vec D = concat ( aprD, aprCinv._D() );
264        mat L = eye ( dimLS + 1 );
265        L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
266        L.set_submatrix ( 1, 1, aprCinv._L() );
267        ldmat aprLD ( L, D );
268
269        egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
270        return aprgiw;
271};
272
273double mprod::evallogcond ( const vec &val, const vec &cond ) {
274        int i;
275        double res = 0.0;
276        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
277                /*                      if ( pdfs(i)->_rvc().count() >0) {
278                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
279                                        }
280                                        // add logarithms
281                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
282                res += pdfs ( i )->evallogcond (
283                           dls ( i )->pushdown ( val ),
284                           dls ( i )->get_cond ( val, cond )
285                       );
286        }
287        return res;
288}
289
290vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
291        vec tmp ( Dt.cols() );
292        for ( int i = 0; i < Dt.cols(); i++ ) {
293                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
294        }
295        return tmp;
296}
297
298vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
299        vec tmp ( Dt.length() );
300        for ( int i = 0; i < Dt.length(); i++ ) {
301                tmp ( i ) = evallogcond ( Dt ( i ), cond );
302        }
303        return tmp;
304}
305
306void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
307        pdfs = mFacs;
308        dls.set_size ( mFacs.length() );
309
310        rv = get_composite_rv ( pdfs, true );
311        dim = rv._dsize();
312
313        for ( int i = 0; i < pdfs.length(); i++ ) {
314                RV rvx = pdfs ( i )->_rvc().subt ( rv );
315                rvc.add ( rvx ); // add rv to common rvc
316        }
317        dimc = rvc._dsize();
318
319        // rv and rvc established = > we can link them with pdfs
320        for ( int i = 0; i < pdfs.length(); i++ ) {
321                dls ( i ) = new datalink_m2m;
322                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
323        }
324}
325
326void mmix::validate()
327{       
328        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
329
330        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
331
332        double sum_w = sum ( w );
333        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
334        w = w / sum_w;
335
336        int dim = Coms ( 0 )->dimension();
337        int dimc = Coms ( 0 )->dimensionc();
338        for ( int i = 1; i < Coms.length(); i++ ) {
339                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
340                bdm_assert ( dimc == ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
341                bdm_assert ( Coms(i)->isnamed() , "An unnamed component is forbidden here!" );
342        }
343
344        set_rv ( Coms ( 0 )->_rv() );
345        set_rvc ( Coms ( 0 )->_rvc() );
346}
347
348void mmix::from_setting ( const Setting &set ) {
349        UI::get ( Coms, set, "pdfs", UI::compulsory );
350
351        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
352                int len = Coms.length();
353                w.set_length ( len );
354                w = 1.0 / len;
355        }
356}
357
358vec mmix::samplecond ( const vec &cond ) {
359        //Sample which component
360        vec cumDist = cumsum ( w );
361        double u0;
362#pragma omp critical
363        u0 = UniRNG.sample();
364
365        int i = 0;
366        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
367                i++;
368        }
369
370        return Coms ( i )->samplecond ( cond );
371}
372
373vec eprod::mean() const {
374        vec tmp ( dim );
375        for ( int i = 0; i < epdfs.length(); i++ ) {
376                vec pom = epdfs ( i )->mean();
377                dls ( i )->pushup ( tmp, pom );
378        }
379        return tmp;
380}
381
382vec eprod::variance() const {
383        vec tmp ( dim ); //second moment
384        for ( int i = 0; i < epdfs.length(); i++ ) {
385                vec pom = epdfs ( i )->mean();
386                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
387        }
388        return tmp - pow ( mean(), 2 );
389}
390vec eprod::sample() const {
391        vec tmp ( dim );
392        for ( int i = 0; i < epdfs.length(); i++ ) {
393                vec pom = epdfs ( i )->sample();
394                dls ( i )->pushup ( tmp, pom );
395        }
396        return tmp;
397}
398double eprod::evallog ( const vec &val ) const {
399        double tmp = 0;
400        for ( int i = 0; i < epdfs.length(); i++ ) {
401                tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
402        }
403        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
404        return tmp;
405}
406
407}
408// mprod::mprod ( Array<pdf*> mFacs, bool overlap) : pdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), pdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
409//              int i;
410//              bool rvaddok;
411//              // Create rv
412//              for ( i = 0;i < n;i++ ) {
413//                      rvaddok=rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
414//                      // If rvaddok==false, pdfs overlap => assert error.
415//                      epdfs ( i ) = & ( pdfs ( i )->posterior() ); // add pointer to epdf
416//              };
417//              // Create rvc
418//              for ( i = 0;i < n;i++ ) {
419//                      rvc.add ( pdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
420//              };
421//
422// //           independent = true;
423//              //test rvc of pdfs and fill rvinds
424//              for ( i = 0;i < n;i++ ) {
425//                      // find ith rv in common rv
426//                      rvsinrv ( i ) = pdfs ( i )->_rv().dataind ( rv );
427//                      // find ith rvc in common rv
428//                      rvcinrv ( i ) = pdfs ( i )->_rvc().dataind ( rv );
429//                      // find ith rvc in common rv
430//                      irvcs_rvc ( i ) = pdfs ( i )->_rvc().dataind ( rvc );
431//                      //
432// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
433//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
434//              }
435//      };
Note: See TracBrowser for help on using the browser.