root/library/bdm/stat/emix.cpp @ 878

Revision 878, 11.7 kB (checked in by sarka, 15 years ago)

dim ze set_parameters do validate

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix::validate (){
6        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
7
8        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
9
10        double sum_w = sum ( w );
11        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
12        w = w / sum_w;
13
14        dim = Coms ( 0 )->dimension();
15        RV rv_tmp = Coms ( 0 )->_rv() ;
16        bool isnamed = Coms ( 0 )->isnamed();
17        for ( int i = 1; i < Coms.length(); i++ ) {
18                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
19                isnamed &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp);
20        }
21        if (isnamed)
22                epdf::set_rv ( rv_tmp); 
23}
24
25void emix::from_setting ( const Setting &set ) {
26        UI::get ( Coms, set, "pdfs", UI::compulsory );
27
28        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
29                int len = Coms.length();
30                w.set_length ( len );
31                w = 1.0 / len;
32        }
33}
34
35
36vec emix::sample() const {
37        //Sample which component
38        vec cumDist = cumsum ( w );
39        double u0;
40#pragma omp critical
41        u0 = UniRNG.sample();
42
43        int i = 0;
44        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
45                i++;
46        }
47
48        return Coms ( i )->sample();
49}
50
51vec emix::mean() const {
52        int i;
53        vec mu = zeros ( dim );
54        for ( i = 0; i < w.length(); i++ ) {
55                mu += w ( i ) * Coms ( i )->mean();
56        }
57        return mu;
58}
59
60vec emix::variance() const {
61        //non-central moment
62        vec mom2 = zeros ( dim );
63        for ( int i = 0; i < w.length(); i++ ) {
64                mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
65        }
66        //central moment
67        return mom2 - pow ( mean(), 2 );
68}
69
70double emix::evallog ( const vec &val ) const {
71        int i;
72        double sum = 0.0;
73        for ( i = 0; i < w.length(); i++ ) {
74                sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
75        }
76        if ( sum == 0.0 ) {
77                sum = std::numeric_limits<double>::epsilon();
78        }
79        double tmp = log ( sum );
80        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
81        return tmp;
82}
83
84vec emix::evallog_mat ( const mat &Val ) const {
85        vec x = zeros ( Val.cols() );
86        for ( int i = 0; i < w.length(); i++ ) {
87                x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
88        }
89        return log ( x );
90};
91
92mat emix::evallog_coms ( const mat &Val ) const {
93        mat X ( w.length(), Val.cols() );
94        for ( int i = 0; i < w.length(); i++ ) {
95                X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
96        }
97        return X;
98}
99
100shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
101        emix *tmp = new emix();
102        shared_ptr<epdf> narrow ( tmp );
103        marginal ( rv, *tmp );
104        return narrow;
105}
106
107void emix::marginal ( const RV &rv, emix &target ) const {
108        bdm_assert ( isnamed(), "rvs are not assigned" );
109
110        Array<shared_ptr<epdf> > Cn ( Coms.length() );
111        for ( int i = 0; i < Coms.length(); i++ ) {
112                Cn ( i ) = Coms ( i )->marginal ( rv );
113        }
114
115        target._w() = w;
116        target._Coms() = Cn;
117        target.validate();
118}
119
120shared_ptr<pdf> emix::condition ( const RV &rv ) const {
121        bdm_assert ( isnamed(), "rvs are not assigned" );
122        mratio *tmp = new mratio ( this, rv );
123        return shared_ptr<pdf> ( tmp );
124}
125
126void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
127        w = w0 / sum ( w0 );
128        int i;
129        for ( i = 0; i < w.length(); i++ ) {
130                bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
131        }
132        if ( copy ) {
133                Coms.set_length ( Coms0.length() );
134                for ( i = 0; i < w.length(); i++ ) {
135                        bdm_error ( "Not implemented" );
136                        // *Coms ( i ) = *Coms0 ( i );
137                }
138                destroyComs = true;
139        } else {
140                Coms = Coms0;
141                destroyComs = false;
142        }
143}
144
145void    egiwmix::validate (){
146   dim = Coms ( 0 )->dimension();
147}
148
149vec egiwmix::sample() const {
150        //Sample which component
151        vec cumDist = cumsum ( w );
152        double u0;
153#pragma omp critical
154        u0 = UniRNG.sample();
155
156        int i = 0;
157        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
158                i++;
159        }
160
161        return Coms ( i )->sample();
162}
163
164vec egiwmix::mean() const {
165        int i;
166        vec mu = zeros ( dim );
167        for ( i = 0; i < w.length(); i++ ) {
168                mu += w ( i ) * Coms ( i )->mean();
169        }
170        return mu;
171}
172
173vec egiwmix::variance() const {
174        // non-central moment
175        vec mom2 = zeros ( dim );
176        for ( int i = 0; i < w.length(); i++ ) {
177                // pow is overloaded, we have to use another approach
178                mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
179        }
180        // central moment
181        // pow is overloaded, we have to use another approach
182        return mom2 - elem_mult ( mean(), mean() );
183}
184
185shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
186        emix *tmp = new emix();
187        shared_ptr<epdf> narrow ( tmp );
188        marginal ( rv, *tmp );
189        return narrow;
190}
191
192void egiwmix::marginal ( const RV &rv, emix &target ) const {
193        bdm_assert_debug ( isnamed(), "rvs are not assigned" );
194
195        Array<shared_ptr<epdf> > Cn ( Coms.length() );
196        for ( int i = 0; i < Coms.length(); i++ ) {
197                Cn ( i ) = Coms ( i )->marginal ( rv );
198        }
199
200        target._w() = w;
201        target._Coms() = Cn;
202        target.validate();
203}
204
205egiw*   egiwmix::approx() {
206        // NB: dimx == 1 !!!
207        // The following code might look a bit spaghetti-like,
208        // consult Dedecius, K. et al.: Partial forgetting in AR models.
209
210        double sumVecCommon;                            // common part for many terms in eq.
211        int len = w.length();                           // no. of mix components
212        int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS
213        vec vecNu ( len );                                      // vector of dfms of components
214        vec vecD ( len );                                       // vector of LS reminders of comps.
215        vec vecCommon ( len );                          // vector of common parts
216        mat matVecsTheta;                               // matrix which rows are theta vects.
217
218        // fill in the vectors vecNu, vecD and matVecsTheta
219        for ( int i = 0; i < len; i++ ) {
220                vecNu.shift_left ( Coms ( i )->_nu() );
221                vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
222                matVecsTheta.append_row ( Coms ( i )->est_theta() );
223        }
224
225        // calculate the common parts and their sum
226        vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
227        sumVecCommon = sum ( vecCommon );
228
229        // LS estimator of theta
230        vec aprEstTheta ( dimLS );
231        aprEstTheta.zeros();
232        for ( int i = 0; i < len; i++ ) {
233                aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
234        }
235        aprEstTheta /= sumVecCommon;
236
237
238        // LS estimator of dfm
239        double aprNu;
240        double A = log ( sumVecCommon );                // Term 'A' in equation
241
242        for ( int i = 0; i < len; i++ ) {
243                A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
244        }
245
246        aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );
247
248
249        // LS reminder (term D(0,0) in C-syntax)
250        double aprD = aprNu / sumVecCommon;
251
252        // Aproximation of cov
253        // the following code is very numerically sensitive, thus
254        // we have to eliminate decompositions etc. as much as possible
255        mat aprC = zeros ( dimLS, dimLS );
256        for ( int i = 0; i < len; i++ ) {
257                aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
258                vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
259                aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
260        }
261
262        // Construct GiW pdf :: BEGIN
263        ldmat aprCinv ( inv ( aprC ) );
264        vec D = concat ( aprD, aprCinv._D() );
265        mat L = eye ( dimLS + 1 );
266        L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
267        L.set_submatrix ( 1, 1, aprCinv._L() );
268        ldmat aprLD ( L, D );
269
270        egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
271        return aprgiw;
272};
273
274double mprod::evallogcond ( const vec &val, const vec &cond ) {
275        int i;
276        double res = 0.0;
277        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
278                /*                      if ( pdfs(i)->_rvc().count() >0) {
279                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
280                                        }
281                                        // add logarithms
282                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
283                res += pdfs ( i )->evallogcond (
284                           dls ( i )->pushdown ( val ),
285                           dls ( i )->get_cond ( val, cond )
286                       );
287        }
288        return res;
289}
290
291vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
292        vec tmp ( Dt.cols() );
293        for ( int i = 0; i < Dt.cols(); i++ ) {
294                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
295        }
296        return tmp;
297}
298
299vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
300        vec tmp ( Dt.length() );
301        for ( int i = 0; i < Dt.length(); i++ ) {
302                tmp ( i ) = evallogcond ( Dt ( i ), cond );
303        }
304        return tmp;
305}
306
307void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
308        pdfs = mFacs;
309        dls.set_size ( mFacs.length() );
310
311        rv = get_composite_rv ( pdfs, true );
312        dim = rv._dsize();
313
314        for ( int i = 0; i < pdfs.length(); i++ ) {
315                RV rvx = pdfs ( i )->_rvc().subt ( rv );
316                rvc.add ( rvx ); // add rv to common rvc
317        }
318        dimc = rvc._dsize();
319
320        // rv and rvc established = > we can link them with pdfs
321        for ( int i = 0; i < pdfs.length(); i++ ) {
322                dls ( i ) = new datalink_m2m;
323                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
324        }
325}
326
327void mmix::validate()
328{       
329        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
330
331        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
332
333        double sum_w = sum ( w );
334        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
335        w = w / sum_w;
336
337
338        dim = Coms ( 0 )->dimension();
339        dimc = Coms ( 0 )->dimensionc();
340        RV rv_tmp = Coms ( 0 )->_rv();
341        RV rvc_tmp = Coms ( 0 )->_rvc();
342        bool isnamed = Coms ( 0 )->isnamed();
343        for ( int i = 1; i < Coms.length(); i++ ) {
344                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
345                bdm_assert ( dimc >= ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
346                isnamed &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp) & Coms(i)->_rvc().equal(rvc_tmp);
347        }
348        if (isnamed)
349        {
350                pdf::set_rv ( rv_tmp );
351                pdf::set_rvc ( rvc_tmp );
352        }
353}
354
355void mmix::from_setting ( const Setting &set ) {
356        UI::get ( Coms, set, "pdfs", UI::compulsory );
357
358        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
359                int len = Coms.length();
360                w.set_length ( len );
361                w = 1.0 / len;
362        }
363}
364
365vec mmix::samplecond ( const vec &cond ) {
366        //Sample which component
367        vec cumDist = cumsum ( w );
368        double u0;
369#pragma omp critical
370        u0 = UniRNG.sample();
371
372        int i = 0;
373        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
374                i++;
375        }
376
377        return Coms ( i )->samplecond ( cond );
378}
379
380vec eprod::mean() const {
381        vec tmp ( dim );
382        for ( int i = 0; i < epdfs.length(); i++ ) {
383                vec pom = epdfs ( i )->mean();
384                dls ( i )->pushup ( tmp, pom );
385        }
386        return tmp;
387}
388
389vec eprod::variance() const {
390        vec tmp ( dim ); //second moment
391        for ( int i = 0; i < epdfs.length(); i++ ) {
392                vec pom = epdfs ( i )->mean();
393                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
394        }
395        return tmp - pow ( mean(), 2 );
396}
397vec eprod::sample() const {
398        vec tmp ( dim );
399        for ( int i = 0; i < epdfs.length(); i++ ) {
400                vec pom = epdfs ( i )->sample();
401                dls ( i )->pushup ( tmp, pom );
402        }
403        return tmp;
404}
405double eprod::evallog ( const vec &val ) const {
406        double tmp = 0;
407        for ( int i = 0; i < epdfs.length(); i++ ) {
408                tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
409        }
410        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
411        return tmp;
412}
413
414}
415// mprod::mprod ( Array<pdf*> mFacs, bool overlap) : pdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), pdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
416//              int i;
417//              bool rvaddok;
418//              // Create rv
419//              for ( i = 0;i < n;i++ ) {
420//                      rvaddok=rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
421//                      // If rvaddok==false, pdfs overlap => assert error.
422//                      epdfs ( i ) = & ( pdfs ( i )->posterior() ); // add pointer to epdf
423//              };
424//              // Create rvc
425//              for ( i = 0;i < n;i++ ) {
426//                      rvc.add ( pdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
427//              };
428//
429// //           independent = true;
430//              //test rvc of pdfs and fill rvinds
431//              for ( i = 0;i < n;i++ ) {
432//                      // find ith rv in common rv
433//                      rvsinrv ( i ) = pdfs ( i )->_rv().dataind ( rv );
434//                      // find ith rvc in common rv
435//                      rvcinrv ( i ) = pdfs ( i )->_rvc().dataind ( rv );
436//                      // find ith rvc in common rv
437//                      irvcs_rvc ( i ) = pdfs ( i )->_rvc().dataind ( rvc );
438//                      //
439// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
440//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
441//              }
442//      };
Note: See TracBrowser for help on using the browser.