root/library/bdm/stat/emix.cpp @ 675

Revision 675, 7.4 kB (checked in by mido, 15 years ago)

experiment: epdf as a descendat of mpdf

  • Property svn:eol-style set to native
RevLine 
[394]1#include "emix.h"
[107]2
[477]3namespace bdm {
[107]4
[504]5void emix::set_parameters ( const vec &w0, const Array<shared_ptr<epdf> > &Coms0 ) {
[477]6        w = w0 / sum ( w0 );
7        dim = Coms0 ( 0 )->dimension();
8        bool isnamed = Coms0 ( 0 )->isnamed();
[107]9        int i;
[464]10        RV tmp_rv;
[477]11        if ( isnamed ) tmp_rv = Coms0 ( 0 )->_rv();
12
13        for ( i = 0; i < w.length(); i++ ) {
[620]14                bdm_assert ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
15                bdm_assert ( !isnamed || tmp_rv.equal ( Coms0 ( i )->_rv() ), "Component RVs do not match!" );
[124]16        }
[504]17
18        Coms = Coms0;
19
[477]20        if ( isnamed ) epdf::set_rv ( tmp_rv ); //coms aer already OK, no need for set_rv
[107]21}
22
[124]23vec emix::sample() const {
24        //Sample which component
[107]25        vec cumDist = cumsum ( w );
[235]26        double u0;
[477]27#pragma omp critical
[235]28        u0 = UniRNG.sample();
[107]29
[477]30        int i = 0;
31        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
32                i++;
33        }
[124]34
35        return Coms ( i )->sample();
[107]36}
[165]37
[504]38shared_ptr<epdf> emix::marginal ( const RV &rv ) const {
39        emix *tmp = new emix();
40        shared_ptr<epdf> narrow(tmp);
41        marginal ( rv, *tmp );
42        return narrow;
43}
44
45void emix::marginal ( const RV &rv, emix &target ) const {
[620]46        bdm_assert ( isnamed(), "rvs are not assigned" );
[477]47
[504]48        Array<shared_ptr<epdf> > Cn ( Coms.length() );
[477]49        for ( int i = 0; i < Coms.length(); i++ ) {
50                Cn ( i ) = Coms ( i )->marginal ( rv );
51        }
[504]52
53        target.set_parameters ( w, Cn );
[182]54}
55
[504]56shared_ptr<mpdf> emix::condition ( const RV &rv ) const {
[620]57        bdm_assert ( isnamed(), "rvs are not assigned" );
[504]58        mratio *tmp = new mratio ( this, rv );
59        return shared_ptr<mpdf>(tmp);
60}
[182]61
[333]62void egiwmix::set_parameters ( const vec &w0, const Array<egiw*> &Coms0, bool copy ) {
[477]63        w = w0 / sum ( w0 );
64        dim = Coms0 ( 0 )->dimension();
[333]65        int i;
[477]66        for ( i = 0; i < w.length(); i++ ) {
[565]67                bdm_assert_debug ( dim == ( Coms0 ( i )->dimension() ), "Component sizes do not match!" );
[333]68        }
69        if ( copy ) {
[477]70                Coms.set_length ( Coms0.length() );
71                for ( i = 0; i < w.length(); i++ ) {
[565]72                        bdm_error ( "Not implemented" );
73                        // *Coms ( i ) = *Coms0 ( i );
[477]74                }
75                destroyComs = true;
76        } else {
[333]77                Coms = Coms0;
[477]78                destroyComs = false;
[333]79        }
[254]80}
[333]81
82vec egiwmix::sample() const {
83        //Sample which component
84        vec cumDist = cumsum ( w );
85        double u0;
[477]86#pragma omp critical
[333]87        u0 = UniRNG.sample();
88
[477]89        int i = 0;
90        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
91                i++;
92        }
[333]93
94        return Coms ( i )->sample();
95}
96
97vec egiwmix::mean() const {
[477]98        int i;
99        vec mu = zeros ( dim );
100        for ( i = 0; i < w.length(); i++ ) {
101                mu += w ( i ) * Coms ( i )->mean();
102        }
[333]103        return mu;
104}
105
106vec egiwmix::variance() const {
107        // non-central moment
108        vec mom2 = zeros ( dim );
[477]109        for ( int i = 0; i < w.length(); i++ ) {
[333]110                // pow is overloaded, we have to use another approach
[477]111                mom2 += w ( i ) * ( Coms ( i )->variance() + elem_mult ( Coms ( i )->mean(), Coms ( i )->mean() ) );
[333]112        }
113        // central moment
114        // pow is overloaded, we have to use another approach
[477]115        return mom2 - elem_mult ( mean(), mean() );
[333]116}
117
[504]118shared_ptr<epdf> egiwmix::marginal ( const RV &rv ) const {
119        emix *tmp = new emix();
120        shared_ptr<epdf> narrow(tmp);
121        marginal ( rv, *tmp );
122        return narrow;
123}
124
125void egiwmix::marginal ( const RV &rv, emix &target ) const {
[565]126        bdm_assert_debug ( isnamed(), "rvs are not assigned" );
[477]127
[504]128        Array<shared_ptr<epdf> > Cn ( Coms.length() );
[477]129        for ( int i = 0; i < Coms.length(); i++ ) {
130                Cn ( i ) = Coms ( i )->marginal ( rv );
131        }
[504]132
133        target.set_parameters ( w, Cn );
[333]134}
135
136egiw*   egiwmix::approx() {
137        // NB: dimx == 1 !!!
138        // The following code might look a bit spaghetti-like,
139        // consult Dedecius, K. et al.: Partial forgetting in AR models.
140
141        double sumVecCommon;                            // common part for many terms in eq.
[477]142        int len = w.length();                           // no. of mix components
143        int dimLS = Coms ( 1 )->_V()._D().length() - 1;         // dim of LS
144        vec vecNu ( len );                                      // vector of dfms of components
145        vec vecD ( len );                                       // vector of LS reminders of comps.
146        vec vecCommon ( len );                          // vector of common parts
[333]147        mat matVecsTheta;                               // matrix which rows are theta vects.
148
149        // fill in the vectors vecNu, vecD and matVecsTheta
[477]150        for ( int i = 0; i < len; i++ ) {
151                vecNu.shift_left ( Coms ( i )->_nu() );
152                vecD.shift_left ( Coms ( i )->_V()._D() ( 0 ) );
153                matVecsTheta.append_row ( Coms ( i )->est_theta() );
[333]154        }
155
156        // calculate the common parts and their sum
[477]157        vecCommon = elem_mult ( w, elem_div ( vecNu, vecD ) );
158        sumVecCommon = sum ( vecCommon );
[333]159
160        // LS estimator of theta
[477]161        vec aprEstTheta ( dimLS );
162        aprEstTheta.zeros();
163        for ( int i = 0; i < len; i++ ) {
164                aprEstTheta +=  matVecsTheta.get_row ( i ) * vecCommon ( i );
[333]165        }
166        aprEstTheta /= sumVecCommon;
[477]167
168
[333]169        // LS estimator of dfm
170        double aprNu;
[477]171        double A = log ( sumVecCommon );                // Term 'A' in equation
[333]172
[477]173        for ( int i = 0; i < len; i++ ) {
174                A += w ( i ) * ( log ( vecD ( i ) ) - psi ( 0.5 * vecNu ( i ) ) );
[333]175        }
176
[477]177        aprNu = ( 1 + sqrt ( 1 + 2 * ( A - LOG2 ) / 3 ) ) / ( 2 * ( A - LOG2 ) );
[333]178
179
180        // LS reminder (term D(0,0) in C-syntax)
181        double aprD = aprNu / sumVecCommon;
182
183        // Aproximation of cov
184        // the following code is very numerically sensitive, thus
185        // we have to eliminate decompositions etc. as much as possible
[477]186        mat aprC = zeros ( dimLS, dimLS );
187        for ( int i = 0; i < len; i++ ) {
188                aprC += Coms ( i )->est_theta_cov().to_mat() * w ( i );
189                vec tmp = ( matVecsTheta.get_row ( i ) - aprEstTheta );
190                aprC += vecCommon ( i ) * outer_product ( tmp, tmp );
[333]191        }
192
193        // Construct GiW pdf :: BEGIN
[477]194        ldmat aprCinv ( inv ( aprC ) );
195        vec D = concat ( aprD, aprCinv._D() );
196        mat L = eye ( dimLS + 1 );
197        L.set_submatrix ( 1, 0, aprCinv._L() * aprEstTheta );
198        L.set_submatrix ( 1, 1, aprCinv._L() );
199        ldmat aprLD ( L, D );
[333]200
[477]201        egiw* aprgiw = new egiw ( 1, aprLD, aprNu );
[333]202        return aprgiw;
203};
204
[507]205void mprod::set_elements (const Array<shared_ptr<mpdf> > &mFacs ) {
206        mpdfs = mFacs;
207        dls.set_size ( mFacs.length() );
208
[675]209        rv = get_composite_rv ( mpdfs, true );
210        dim = rv._dsize();
[507]211
212        for ( int i = 0; i < mpdfs.length(); i++ ) {
213                RV rvx = mpdfs ( i )->_rvc().subt ( rv );
214                rvc.add ( rvx ); // add rv to common rvc
215        }
[601]216        dimc=rvc._dsize();
[507]217
218        // rv and rvc established = > we can link them with mpdfs
219        for ( int i = 0; i < mpdfs.length(); i++ ) {
220                dls ( i ) = new datalink_m2m;
221                dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), _rv(), _rvc() );
222        }
223}
224
[488]225vec mmix::samplecond(const vec &cond) {
226        //Sample which component
227        vec cumDist = cumsum ( w );
228        double u0;
229#pragma omp critical
230        u0 = UniRNG.sample();
231
232        int i = 0;
233        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
234                i++;
235        }
236
237        return Coms ( i )->samplecond(cond);
[333]238}
[488]239
240}
[182]241// mprod::mprod ( Array<mpdf*> mFacs, bool overlap) : mpdf ( RV(), RV() ), n ( mFacs.length() ), epdfs ( n ), mpdfs ( mFacs ), rvinds ( n ), rvcinrv ( n ), irvcs_rvc ( n ) {
[175]242//              int i;
243//              bool rvaddok;
244//              // Create rv
245//              for ( i = 0;i < n;i++ ) {
246//                      rvaddok=rv.add ( mpdfs ( i )->_rv() ); //add rv to common rvs.
247//                      // If rvaddok==false, mpdfs overlap => assert error.
[271]248//                      epdfs ( i ) = & ( mpdfs ( i )->posterior() ); // add pointer to epdf
[175]249//              };
250//              // Create rvc
251//              for ( i = 0;i < n;i++ ) {
252//                      rvc.add ( mpdfs ( i )->_rvc().subt ( rv ) ); //add rv to common rvs.
253//              };
[178]254//
[175]255// //           independent = true;
256//              //test rvc of mpdfs and fill rvinds
257//              for ( i = 0;i < n;i++ ) {
258//                      // find ith rv in common rv
259//                      rvsinrv ( i ) = mpdfs ( i )->_rv().dataind ( rv );
260//                      // find ith rvc in common rv
261//                      rvcinrv ( i ) = mpdfs ( i )->_rvc().dataind ( rv );
262//                      // find ith rvc in common rv
[182]263//                      irvcs_rvc ( i ) = mpdfs ( i )->_rvc().dataind ( rvc );
[175]264//                      //
265// /*                   if ( rvcinrv ( i ).length() >0 ) {independent = false;}
[182]266//                      if ( irvcs_rvc ( i ).length() >0 ) {independent = false;}*/
[175]267//              }
268//      };
Note: See TracBrowser for help on using the browser.