root/library/bdm/stat/emix.cpp @ 900

Revision 900, 6.7 kB (checked in by smidl, 14 years ago)

particle bug fixing

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix_base::validate (){
6        bdm_assert ( no_coms() > 0, "There has to be at least one component." );
7
8        bdm_assert ( no_coms() == w.length(), "It is obligatory to define weights of all the components." );
9
10        double sum_w = sum ( w );
11        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
12        w = w / sum_w;
13
14        dim = component ( 0 )->dimension();
15        RV rv_tmp = component ( 0 )->_rv() ;
16        bool isnamed = component( 0 )->isnamed();
17        for ( int i = 1; i < no_coms(); i++ ) {
18                bdm_assert ( dim == ( component ( i )->dimension() ), "Component sizes do not match!" );
19                isnamed &= component(i)->isnamed() & component(i)->_rv().equal(rv_tmp);
20        }
21        if (isnamed)
22                epdf::set_rv ( rv_tmp); 
23}
24
25
26
27vec emix_base::sample() const {
28        //Sample which component
29        vec cumDist = cumsum ( w );
30        double u0;
31#pragma omp critical
32        u0 = UniRNG.sample();
33
34        int i = 0;
35        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
36                i++;
37        }
38
39        return component ( i )->sample();
40}
41
42vec emix_base::mean() const {
43        int i;
44        vec mu = zeros ( dim );
45        for ( i = 0; i < w.length(); i++ ) {
46                mu += w ( i ) * component ( i )->mean();
47        }
48        return mu;
49}
50
51vec emix_base::variance() const {
52        //non-central moment
53        vec mom2 = zeros ( dim );
54        for ( int i = 0; i < w.length(); i++ ) {
55                mom2 += w ( i ) * ( component( i )->variance() + pow ( component ( i )->mean(), 2 ) );
56        }
57        //central moment
58        return mom2 - pow ( mean(), 2 );
59}
60
61double emix_base::evallog ( const vec &val ) const {
62        int i;
63        double sum = 0.0;
64        for ( i = 0; i < w.length(); i++ ) {
65                sum += w ( i ) * exp ( component ( i )->evallog ( val ) );
66        }
67        if ( sum == 0.0 ) {
68                sum = std::numeric_limits<double>::epsilon();
69        }
70        double tmp = log ( sum );
71        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
72        return tmp;
73}
74
75vec emix_base::evallog_mat ( const mat &Val ) const {
76        vec x = zeros ( Val.cols() );
77        for ( int i = 0; i < w.length(); i++ ) {
78                x += w ( i ) * exp ( component( i )->evallog_mat ( Val ) );
79        }
80        return log ( x );
81};
82
83mat emix_base::evallog_coms ( const mat &Val ) const {
84        mat X ( w.length(), Val.cols() );
85        for ( int i = 0; i < w.length(); i++ ) {
86                X.set_row ( i, w ( i ) *exp ( component( i )->evallog_mat ( Val ) ) );
87        }
88        return X;
89}
90
91shared_ptr<epdf> emix_base::marginal ( const RV &rv ) const {
92        emix *tmp = new emix();
93        shared_ptr<epdf> narrow ( tmp );
94        marginal ( rv, *tmp );
95        return narrow;
96}
97
98void emix_base::marginal ( const RV &rv, emix &target ) const {
99        bdm_assert ( isnamed(), "rvs are not assigned" );
100
101        Array<shared_ptr<epdf> > Cn ( no_coms() );
102        for ( int i = 0; i < no_coms(); i++ ) {
103                Cn ( i ) = component ( i )->marginal ( rv );
104        }
105
106        target._w() = w;
107        target._Coms() = Cn;
108        target.validate();
109}
110
111shared_ptr<pdf> emix_base::condition ( const RV &rv ) const {
112        bdm_assert ( isnamed(), "rvs are not assigned" );
113        mratio *tmp = new mratio ( this, rv );
114        return shared_ptr<pdf> ( tmp );
115}
116
117void emix::from_setting ( const Setting &set ) {
118        UI::get ( Coms, set, "pdfs", UI::compulsory );
119        UI::get ( w, set, "weights", UI::compulsory );
120}
121
122void    emix::validate (){
123        emix_base::validate();
124        dim = Coms ( 0 )->dimension();
125}
126
127
128double mprod::evallogcond ( const vec &val, const vec &cond ) {
129        int i;
130        double res = 0.0;
131        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
132                /*                      if ( pdfs(i)->_rvc().count() >0) {
133                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
134                                        }
135                                        // add logarithms
136                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
137                res += pdfs ( i )->evallogcond (
138                           dls ( i )->pushdown ( val ),
139                           dls ( i )->get_cond ( val, cond )
140                       );
141        }
142        return res;
143}
144
145vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
146        vec tmp ( Dt.cols() );
147        for ( int i = 0; i < Dt.cols(); i++ ) {
148                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
149        }
150        return tmp;
151}
152
153vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
154        vec tmp ( Dt.length() );
155        for ( int i = 0; i < Dt.length(); i++ ) {
156                tmp ( i ) = evallogcond ( Dt ( i ), cond );
157        }
158        return tmp;
159}
160
161void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
162        pdfs = mFacs;
163        dls.set_size ( mFacs.length() );
164
165        rv = get_composite_rv ( pdfs, true );
166        dim = rv._dsize();
167
168        for ( int i = 0; i < pdfs.length(); i++ ) {
169                RV rvx = pdfs ( i )->_rvc().subt ( rv );
170                rvc.add ( rvx ); // add rv to common rvc
171        }
172        dimc = rvc._dsize();
173
174        // rv and rvc established = > we can link them with pdfs
175        for ( int i = 0; i < pdfs.length(); i++ ) {
176                dls ( i ) = new datalink_m2m;
177                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
178        }
179}
180
181void mmix::validate()
182{       
183        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
184
185        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
186
187        double sum_w = sum ( w );
188        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
189        w = w / sum_w;
190
191
192        dim = Coms ( 0 )->dimension();
193        dimc = Coms ( 0 )->dimensionc();
194        RV rv_tmp = Coms ( 0 )->_rv();
195        RV rvc_tmp = Coms ( 0 )->_rvc();
196        bool isnamed = Coms ( 0 )->isnamed();
197        for ( int i = 1; i < Coms.length(); i++ ) {
198                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
199                bdm_assert ( dimc >= ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
200                isnamed &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp) & Coms(i)->_rvc().equal(rvc_tmp);
201        }
202        if (isnamed)
203        {
204                pdf::set_rv ( rv_tmp );
205                pdf::set_rvc ( rvc_tmp );
206        }
207}
208
209void mmix::from_setting ( const Setting &set ) {
210        UI::get ( Coms, set, "pdfs", UI::compulsory );
211
212        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
213                int len = Coms.length();
214                w.set_length ( len );
215                w = 1.0 / len;
216        }
217}
218
219vec mmix::samplecond ( const vec &cond ) {
220        //Sample which component
221        vec cumDist = cumsum ( w );
222        double u0;
223#pragma omp critical
224        u0 = UniRNG.sample();
225
226        int i = 0;
227        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
228                i++;
229        }
230
231        return Coms ( i )->samplecond ( cond );
232}
233
234vec eprod_base::mean() const {
235        vec tmp ( dim );
236        for ( int i = 0; i < no_factors(); i++ ) {
237                vec pom = factor( i )->mean();
238                dls ( i )->pushup ( tmp, pom );
239        }
240        return tmp;
241}
242
243vec eprod_base::variance() const {
244        vec tmp ( dim ); //second moment
245        for ( int i = 0; i < no_factors(); i++ ) {
246                vec pom = factor ( i )->variance();
247                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
248        }
249        return tmp;
250}
251vec eprod_base::sample() const {
252        vec tmp ( dim );
253        for ( int i = 0; i < no_factors(); i++ ) {
254                vec pom = factor ( i )->sample();
255                dls ( i )->pushup ( tmp, pom );
256        }
257        return tmp;
258}
259double eprod_base::evallog ( const vec &val ) const {
260        double tmp = 0;
261        for ( int i = 0; i < no_factors(); i++ ) {
262                tmp += factor ( i )->evallog ( dls ( i )->pushdown ( val ) );
263        }
264        //bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
265        return tmp;
266}
267
268}
Note: See TracBrowser for help on using the browser.