root/library/bdm/stat/emix.cpp @ 675

Revision 675, 7.4 kB (checked in by mido, 15 years ago)

experiment: epdf as a descendat of mpdf

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix::set_parameters ( const vec &w0, const Array<shared_ptr<epdf> > &Coms0 ) {
6        w = w0 / sum ( w0 );
7        dim = Coms0 ( 0 )->dimension();
8        bool isnamed = Coms0 ( 0 )->isnamed();
9        int i;
10        RV tmp_rv;
11        if ( isnamed ) tmp_rv = Coms0 ( 0 )->_rv();
12
13        for ( i = 0; i < w.length(); i++ ) {
14                bdm_assert ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
15                bdm_assert ( !isnamed || tmp_rv.equal ( Coms0 ( i )->_rv() ), "Component RVs do not match!" );
16        }
17
18        Coms = Coms0;
19
20        if ( isnamed ) epdf::set_rv ( tmp_rv ); //coms aer already OK, no need for set_rv
21}
22
23vec emix::sample() const {
24        //Sample which component
25        vec cumDist = cumsum ( w );
26        double u0;
27#pragma omp critical
28        u0 = UniRNG.sample();
29
30        int i = 0;
31        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
32                i++;
33        }
34
35        return Coms ( i )->sample();
36}
37
38shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
39        emix *tmp = new emix();
40        shared_ptr<epdf> narrow(tmp);
41        marginal ( rv, *tmp );
42        return narrow;
43}
44
45void emix::marginal ( const RV &rv, emix &target ) const {
46        bdm_assert ( isnamed(), "rvs are not assigned" );
47
48        Array<shared_ptr<epdf> > Cn ( Coms.length() );
49        for ( int i = 0; i < Coms.length(); i++ ) {
50                Cn ( i ) = Coms ( i )->marginal ( rv );
51        }
52
53        target.set_parameters ( w, Cn );
54}
55
56shared_ptr<mpdf> emix::condition ( const RV &rv ) const {
57        bdm_assert ( isnamed(), "rvs are not assigned" );
58        mratio *tmp = new mratio ( this, rv );
59        return shared_ptr<mpdf>(tmp);
60}
61
62void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
63        w = w0 / sum ( w0 );
64        dim = Coms0 ( 0 )->dimension();
65        int i;
66        for ( i = 0; i < w.length(); i++ ) {
67                bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
68        }
69        if ( copy ) {
70                Coms.set_length ( Coms0.length() );
71                for ( i = 0; i < w.length(); i++ ) {
72                        bdm_error ( "Not implemented" );
73                        // *Coms ( i ) = *Coms0 ( i );
74                }
75                destroyComs = true;
76        } else {
77                Coms = Coms0;
78                destroyComs = false;
79        }
80}
81
82vec egiwmix::sample() const {
83        //Sample which component
84        vec cumDist = cumsum ( w );
85        double u0;
86#pragma omp critical
87        u0 = UniRNG.sample();
88
89        int i = 0;
90        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
91                i++;
92        }
93
94        return Coms ( i )->sample();
95}
96
97vec egiwmix::mean() const {
98        int i;
99        vec mu = zeros ( dim );
100        for ( i = 0; i < w.length(); i++ ) {
101                mu += w ( i ) * Coms ( i )->mean();
102        }
103        return mu;
104}
105
106vec egiwmix::variance() const {
107        // non-central moment
108        vec mom2 = zeros ( dim );
109        for ( int i = 0; i < w.length(); i++ ) {
110                // pow is overloaded, we have to use another approach
111                mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
112        }
113        // central moment
114        // pow is overloaded, we have to use another approach
115        return mom2 - elem_mult ( mean(), mean() );
116}
117
118shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
119        emix *tmp = new emix();
120        shared_ptr<epdf> narrow(tmp);
121        marginal ( rv, *tmp );
122        return narrow;
123}
124
125void egiwmix::marginal ( const RV &rv, emix &target ) const {
126        bdm_assert_debug ( isnamed(), "rvs are not assigned" );
127
128        Array<shared_ptr<epdf> > Cn ( Coms.length() );
129        for ( int i = 0; i < Coms.length(); i++ ) {
130                Cn ( i ) = Coms ( i )->marginal ( rv );
131        }
132
133        target.set_parameters ( w, Cn );
134}
135
136egiw*   egiwmix::approx() {
137        // NB: dimx == 1 !!!
138        // The following code might look a bit spaghetti-like,
139        // consult Dedecius, K. et al.: Partial forgetting in AR models.
140
141        double sumVecCommon;                            // common part for many terms in eq.
142        int len = w.length();                           // no. of mix components
143        int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS
144        vec vecNu ( len );                                      // vector of dfms of components
145        vec vecD ( len );                                       // vector of LS reminders of comps.
146        vec vecCommon ( len );                          // vector of common parts
147        mat matVecsTheta;                               // matrix which rows are theta vects.
148
149        // fill in the vectors vecNu, vecD and matVecsTheta
150        for ( int i = 0; i < len; i++ ) {
151                vecNu.shift_left ( Coms ( i )->_nu() );
152                vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
153                matVecsTheta.append_row ( Coms ( i )->est_theta() );
154        }
155
156        // calculate the common parts and their sum
157        vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
158        sumVecCommon = sum ( vecCommon );
159
160        // LS estimator of theta
161        vec aprEstTheta ( dimLS );
162        aprEstTheta.zeros();
163        for ( int i = 0; i < len; i++ ) {
164                aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
165        }
166        aprEstTheta /= sumVecCommon;
167
168
169        // LS estimator of dfm
170        double aprNu;
171        double A = log ( sumVecCommon );                // Term 'A' in equation
172
173        for ( int i = 0; i < len; i++ ) {
174                A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
175        }
176
177        aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );
178
179
180        // LS reminder (term D(0,0) in C-syntax)
181        double aprD = aprNu / sumVecCommon;
182
183        // Aproximation of cov
184        // the following code is very numerically sensitive, thus
185        // we have to eliminate decompositions etc. as much as possible
186        mat aprC = zeros ( dimLS, dimLS );
187        for ( int i = 0; i < len; i++ ) {
188                aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
189                vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
190                aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
191        }
192
193        // Construct GiW pdf :: BEGIN
194        ldmat aprCinv ( inv ( aprC ) );
195        vec D = concat ( aprD, aprCinv._D() );
196        mat L = eye ( dimLS + 1 );
197        L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
198        L.set_submatrix ( 1, 1, aprCinv._L() );
199        ldmat aprLD ( L, D );
200
201        egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
202        return aprgiw;
203};
204
205void mprod::set_elements (const Array<shared_ptr<mpdf> > &mFacs ) {
206        mpdfs = mFacs;
207        dls.set_size ( mFacs.length() );
208
209        rv = get_composite_rv ( mpdfs, true );
210        dim = rv._dsize();
211
212        for ( int i = 0; i < mpdfs.length(); i++ ) {
213                RV rvx = mpdfs ( i )->_rvc().subt ( rv );
214                rvc.add ( rvx ); // add rv to common rvc
215        }
216        dimc=rvc._dsize();
217
218        // rv and rvc established = > we can link them with mpdfs
219        for ( int i = 0; i < mpdfs.length(); i++ ) {
220                dls ( i ) = new datalink_m2m;
221                dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
222        }
223}
224
225vec mmix::samplecond(const vec &cond) {
226        //Sample which component
227        vec cumDist = cumsum ( w );
228        double u0;
229#pragma omp critical
230        u0 = UniRNG.sample();
231
232        int i = 0;
233        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
234                i++;
235        }
236
237        return Coms ( i )->samplecond(cond);
238}
239
240}
241// mprod::mprod ( Array<mpdf*> mFacs, bool overlap) : mpdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), mpdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
242//              int i;
243//              bool rvaddok;
244//              // Create rv
245//              for ( i = 0;i < n;i++ ) {
246//                      rvaddok=rv.add ( mpdfs ( i )->_rv() ); //add rv to common rvs.
247//                      // If rvaddok==false, mpdfs overlap => assert error.
248//                      epdfs ( i ) = & ( mpdfs ( i )->posterior() ); // add pointer to epdf
249//              };
250//              // Create rvc
251//              for ( i = 0;i < n;i++ ) {
252//                      rvc.add ( mpdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
253//              };
254//
255// //           independent = true;
256//              //test rvc of mpdfs and fill rvinds
257//              for ( i = 0;i < n;i++ ) {
258//                      // find ith rv in common rv
259//                      rvsinrv ( i ) = mpdfs ( i )->_rv().dataind ( rv );
260//                      // find ith rvc in common rv
261//                      rvcinrv ( i ) = mpdfs ( i )->_rvc().dataind ( rv );
262//                      // find ith rvc in common rv
263//                      irvcs_rvc ( i ) = mpdfs ( i )->_rvc().dataind ( rvc );
264//                      //
265// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
266//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
267//              }
268//      };
Note: See TracBrowser for help on using the browser.