root/library/bdm/stat/emix.cpp @ 761

Revision 760, 10.4 kB (checked in by smidl, 14 years ago)

cleanups & stuff for SYSID like estimation

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix::set_parameters ( const vec &w0, const Array<shared_ptr<epdf> > &Coms0 ) {
6        w = w0 / sum ( w0 );
7       
8        bool isnamed = Coms0 ( 0 )->isnamed();
9        RV tmp_rv;
10        if ( isnamed ) tmp_rv = Coms0 ( 0 )->_rv();
11        Coms = Coms0;
12
13        if ( isnamed ) epdf::set_rv ( tmp_rv ); //coms aer already OK, no need for set_rv
14}
15
16void emix::validate (){
17        dim = Coms ( 0 )->dimension();
18        bool isnamed = Coms ( 0 )->isnamed();
19        int i;
20        RV tmp_rv;
21        if ( isnamed ) tmp_rv = Coms ( 0 )->_rv();
22        for ( i = 0; i < w.length(); i++ ) {
23                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
24                bdm_assert ( !isnamed || tmp_rv.equal ( Coms ( i )->_rv() ), "Component RVs do not match!" );
25        }
26
27}
28
29
30
31
32vec emix::sample() const {
33        //Sample which component
34        vec cumDist = cumsum ( w );
35        double u0;
36#pragma omp critical
37        u0 = UniRNG.sample();
38
39        int i = 0;
40        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
41                i++;
42        }
43
44        return Coms ( i )->sample();
45}
46
47vec emix::mean() const {
48        int i;
49        vec mu = zeros ( dim );
50        for ( i = 0; i < w.length(); i++ ) {
51                mu += w ( i ) * Coms ( i )->mean();
52        }
53        return mu;
54}
55
56vec emix::variance() const {
57        //non-central moment
58        vec mom2 = zeros ( dim );
59        for ( int i = 0; i < w.length(); i++ ) {
60                mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
61        }
62        //central moment
63        return mom2 - pow ( mean(), 2 );
64}
65
66double emix::evallog ( const vec &val ) const {
67        int i;
68        double sum = 0.0;
69        for ( i = 0; i < w.length(); i++ ) {
70                sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
71        }
72        if ( sum == 0.0 ) {
73                sum = std::numeric_limits<double>::epsilon();
74        }
75        double tmp = log ( sum );
76        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
77        return tmp;
78}
79
80vec emix::evallog_mat ( const mat &Val ) const {
81        vec x = zeros ( Val.cols() );
82        for ( int i = 0; i < w.length(); i++ ) {
83                x += w ( i ) * exp ( Coms ( i )->evallog_mat ( Val ) );
84        }
85        return log ( x );
86};
87
88mat emix::evallog_coms ( const mat &Val ) const {
89        mat X ( w.length(), Val.cols() );
90        for ( int i = 0; i < w.length(); i++ ) {
91                X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_mat ( Val ) ) );
92        }
93        return X;
94}
95
96shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
97        emix *tmp = new emix();
98        shared_ptr<epdf> narrow ( tmp );
99        marginal ( rv, *tmp );
100        return narrow;
101}
102
103void emix::marginal ( const RV &rv, emix &target ) const {
104        bdm_assert ( isnamed(), "rvs are not assigned" );
105
106        Array<shared_ptr<epdf> > Cn ( Coms.length() );
107        for ( int i = 0; i < Coms.length(); i++ ) {
108                Cn ( i ) = Coms ( i )->marginal ( rv );
109        }
110
111        target.set_parameters ( w, Cn );
112        target.validate();
113}
114
115shared_ptr<pdf> emix::condition ( const RV &rv ) const {
116        bdm_assert ( isnamed(), "rvs are not assigned" );
117        mratio *tmp = new mratio ( this, rv );
118        return shared_ptr<pdf> ( tmp );
119}
120
121void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
122        w = w0 / sum ( w0 );
123        dim = Coms0 ( 0 )->dimension();
124        int i;
125        for ( i = 0; i < w.length(); i++ ) {
126                bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
127        }
128        if ( copy ) {
129                Coms.set_length ( Coms0.length() );
130                for ( i = 0; i < w.length(); i++ ) {
131                        bdm_error ( "Not implemented" );
132                        // *Coms ( i ) = *Coms0 ( i );
133                }
134                destroyComs = true;
135        } else {
136                Coms = Coms0;
137                destroyComs = false;
138        }
139}
140
141vec egiwmix::sample() const {
142        //Sample which component
143        vec cumDist = cumsum ( w );
144        double u0;
145#pragma omp critical
146        u0 = UniRNG.sample();
147
148        int i = 0;
149        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
150                i++;
151        }
152
153        return Coms ( i )->sample();
154}
155
156vec egiwmix::mean() const {
157        int i;
158        vec mu = zeros ( dim );
159        for ( i = 0; i < w.length(); i++ ) {
160                mu += w ( i ) * Coms ( i )->mean();
161        }
162        return mu;
163}
164
165vec egiwmix::variance() const {
166        // non-central moment
167        vec mom2 = zeros ( dim );
168        for ( int i = 0; i < w.length(); i++ ) {
169                // pow is overloaded, we have to use another approach
170                mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
171        }
172        // central moment
173        // pow is overloaded, we have to use another approach
174        return mom2 - elem_mult ( mean(), mean() );
175}
176
177shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
178        emix *tmp = new emix();
179        shared_ptr<epdf> narrow ( tmp );
180        marginal ( rv, *tmp );
181        return narrow;
182}
183
184void egiwmix::marginal ( const RV &rv, emix &target ) const {
185        bdm_assert_debug ( isnamed(), "rvs are not assigned" );
186
187        Array<shared_ptr<epdf> > Cn ( Coms.length() );
188        for ( int i = 0; i < Coms.length(); i++ ) {
189                Cn ( i ) = Coms ( i )->marginal ( rv );
190        }
191
192        target.set_parameters ( w, Cn );
193        target.validate();
194}
195
196egiw*   egiwmix::approx() {
197        // NB: dimx == 1 !!!
198        // The following code might look a bit spaghetti-like,
199        // consult Dedecius, K. et al.: Partial forgetting in AR models.
200
201        double sumVecCommon;                            // common part for many terms in eq.
202        int len = w.length();                           // no. of mix components
203        int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS
204        vec vecNu ( len );                                      // vector of dfms of components
205        vec vecD ( len );                                       // vector of LS reminders of comps.
206        vec vecCommon ( len );                          // vector of common parts
207        mat matVecsTheta;                               // matrix which rows are theta vects.
208
209        // fill in the vectors vecNu, vecD and matVecsTheta
210        for ( int i = 0; i < len; i++ ) {
211                vecNu.shift_left ( Coms ( i )->_nu() );
212                vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
213                matVecsTheta.append_row ( Coms ( i )->est_theta() );
214        }
215
216        // calculate the common parts and their sum
217        vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
218        sumVecCommon = sum ( vecCommon );
219
220        // LS estimator of theta
221        vec aprEstTheta ( dimLS );
222        aprEstTheta.zeros();
223        for ( int i = 0; i < len; i++ ) {
224                aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
225        }
226        aprEstTheta /= sumVecCommon;
227
228
229        // LS estimator of dfm
230        double aprNu;
231        double A = log ( sumVecCommon );                // Term 'A' in equation
232
233        for ( int i = 0; i < len; i++ ) {
234                A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
235        }
236
237        aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );
238
239
240        // LS reminder (term D(0,0) in C-syntax)
241        double aprD = aprNu / sumVecCommon;
242
243        // Aproximation of cov
244        // the following code is very numerically sensitive, thus
245        // we have to eliminate decompositions etc. as much as possible
246        mat aprC = zeros ( dimLS, dimLS );
247        for ( int i = 0; i < len; i++ ) {
248                aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
249                vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
250                aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
251        }
252
253        // Construct GiW pdf :: BEGIN
254        ldmat aprCinv ( inv ( aprC ) );
255        vec D = concat ( aprD, aprCinv._D() );
256        mat L = eye ( dimLS + 1 );
257        L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
258        L.set_submatrix ( 1, 1, aprCinv._L() );
259        ldmat aprLD ( L, D );
260
261        egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
262        return aprgiw;
263};
264
265double mprod::evallogcond ( const vec &val, const vec &cond ) {
266        int i;
267        double res = 0.0;
268        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
269                /*                      if ( pdfs(i)->_rvc().count() >0) {
270                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
271                                        }
272                                        // add logarithms
273                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
274                res += pdfs ( i )->evallogcond (
275                           dls ( i )->pushdown ( val ),
276                           dls ( i )->get_cond ( val, cond )
277                       );
278        }
279        return res;
280}
281
282vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
283        vec tmp ( Dt.cols() );
284        for ( int i = 0; i < Dt.cols(); i++ ) {
285                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
286        }
287        return tmp;
288}
289
290vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
291        vec tmp ( Dt.length() );
292        for ( int i = 0; i < Dt.length(); i++ ) {
293                tmp ( i ) = evallogcond ( Dt ( i ), cond );
294        }
295        return tmp;
296}
297
298void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
299        pdfs = mFacs;
300        dls.set_size ( mFacs.length() );
301
302        rv = get_composite_rv ( pdfs, true );
303        dim = rv._dsize();
304
305        for ( int i = 0; i < pdfs.length(); i++ ) {
306                RV rvx = pdfs ( i )->_rvc().subt ( rv );
307                rvc.add ( rvx ); // add rv to common rvc
308        }
309        dimc = rvc._dsize();
310
311        // rv and rvc established = > we can link them with pdfs
312        for ( int i = 0; i < pdfs.length(); i++ ) {
313                dls ( i ) = new datalink_m2m;
314                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
315        }
316}
317
318vec mmix::samplecond ( const vec &cond ) {
319        //Sample which component
320        vec cumDist = cumsum ( w );
321        double u0;
322#pragma omp critical
323        u0 = UniRNG.sample();
324
325        int i = 0;
326        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
327                i++;
328        }
329
330        return Coms ( i )->samplecond ( cond );
331}
332
333vec eprod::mean() const {
334        vec tmp ( dim );
335        for ( int i = 0; i < epdfs.length(); i++ ) {
336                vec pom = epdfs ( i )->mean();
337                dls ( i )->pushup ( tmp, pom );
338        }
339        return tmp;
340}
341
342vec eprod::variance() const {
343        vec tmp ( dim ); //second moment
344        for ( int i = 0; i < epdfs.length(); i++ ) {
345                vec pom = epdfs ( i )->mean();
346                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
347        }
348        return tmp - pow ( mean(), 2 );
349}
350vec eprod::sample() const {
351        vec tmp ( dim );
352        for ( int i = 0; i < epdfs.length(); i++ ) {
353                vec pom = epdfs ( i )->sample();
354                dls ( i )->pushup ( tmp, pom );
355        }
356        return tmp;
357}
358double eprod::evallog ( const vec &val ) const {
359        double tmp = 0;
360        for ( int i = 0; i < epdfs.length(); i++ ) {
361                tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
362        }
363        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
364        return tmp;
365}
366
367}
368// mprod::mprod ( Array<pdf*> mFacs, bool overlap) : pdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), pdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
369//              int i;
370//              bool rvaddok;
371//              // Create rv
372//              for ( i = 0;i < n;i++ ) {
373//                      rvaddok=rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
374//                      // If rvaddok==false, pdfs overlap => assert error.
375//                      epdfs ( i ) = & ( pdfs ( i )->posterior() ); // add pointer to epdf
376//              };
377//              // Create rvc
378//              for ( i = 0;i < n;i++ ) {
379//                      rvc.add ( pdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
380//              };
381//
382// //           independent = true;
383//              //test rvc of pdfs and fill rvinds
384//              for ( i = 0;i < n;i++ ) {
385//                      // find ith rv in common rv
386//                      rvsinrv ( i ) = pdfs ( i )->_rv().dataind ( rv );
387//                      // find ith rvc in common rv
388//                      rvcinrv ( i ) = pdfs ( i )->_rvc().dataind ( rv );
389//                      // find ith rvc in common rv
390//                      irvcs_rvc ( i ) = pdfs ( i )->_rvc().dataind ( rvc );
391//                      //
392// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
393//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
394//              }
395//      };
Note: See TracBrowser for help on using the browser.