root/library/bdm/stat/emix.cpp @ 956

Revision 956, 7.4 kB (checked in by sarka, 14 years ago)

to_setting

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix_base::validate (){
6        epdf::validate();
7        bdm_assert ( no_coms() > 0, "There has to be at least one component." );
8
9        bdm_assert ( no_coms() == w.length(), "It is obligatory to define weights of all the components." );
10
11        double sum_w = sum ( w );
12        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
13        w = w / sum_w;
14
15        dim = component ( 0 )->dimension();
16        RV rv_tmp = component ( 0 )->_rv() ;
17        bool isnamed = component( 0 )->isnamed();
18        for ( int i = 1; i < no_coms(); i++ ) {
19                bdm_assert ( dim == ( component ( i )->dimension() ), "Component sizes do not match!" );
20                isnamed &= component(i)->isnamed() & component(i)->_rv().equal(rv_tmp);
21        }
22        if (isnamed)
23                epdf::set_rv ( rv_tmp); 
24}
25
26
27
28vec emix_base::sample() const {
29        //Sample which component
30        vec cumDist = cumsum ( w );
31        double u0;
32#pragma omp critical
33        u0 = UniRNG.sample();
34
35        int i = 0;
36        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
37                i++;
38        }
39
40        return component ( i )->sample();
41}
42
43vec emix_base::mean() const {
44        int i;
45        vec mu = zeros ( dim );
46        for ( i = 0; i < w.length(); i++ ) {
47                mu += w ( i ) * component ( i )->mean();
48        }
49        return mu;
50}
51
52vec emix_base::variance() const {
53        //non-central moment
54        vec mom2 = zeros ( dim );
55        for ( int i = 0; i < w.length(); i++ ) {
56                mom2 += w ( i ) * ( component( i )->variance() + pow ( component ( i )->mean(), 2 ) );
57        }
58        //central moment
59        return mom2 - pow ( mean(), 2 );
60}
61
62double emix_base::evallog ( const vec &val ) const {
63        int i;
64        double sum = 0.0;
65        for ( i = 0; i < w.length(); i++ ) {
66                sum += w ( i ) * exp ( component ( i )->evallog ( val ) );
67        }
68        if ( sum == 0.0 ) {
69                sum = std::numeric_limits<double>::epsilon();
70        }
71        double tmp = log ( sum );
72        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
73        return tmp;
74}
75
76vec emix_base::evallog_mat ( const mat &Val ) const {
77        vec x = zeros ( Val.cols() );
78        for ( int i = 0; i < w.length(); i++ ) {
79                x += w ( i ) * exp ( component( i )->evallog_mat ( Val ) );
80        }
81        return log ( x );
82};
83
84mat emix_base::evallog_coms ( const mat &Val ) const {
85        mat X ( w.length(), Val.cols() );
86        for ( int i = 0; i < w.length(); i++ ) {
87                X.set_row ( i, w ( i ) *exp ( component( i )->evallog_mat ( Val ) ) );
88        }
89        return X;
90}
91
92shared_ptr<epdf> emix_base::marginal ( const RV &rv ) const {
93        emix *tmp = new emix();
94        shared_ptr<epdf> narrow ( tmp );
95        marginal ( rv, *tmp );
96        return narrow;
97}
98
99void emix_base::marginal ( const RV &rv, emix &target ) const {
100        bdm_assert ( isnamed(), "rvs are not assigned" );
101
102        Array<shared_ptr<epdf> > Cn ( no_coms() );
103        for ( int i = 0; i < no_coms(); i++ ) {
104                Cn ( i ) = component ( i )->marginal ( rv );
105        }
106
107        target._w() = w;
108        target._Coms() = Cn;
109        target.validate();
110}
111
112shared_ptr<pdf> emix_base::condition ( const RV &rv ) const {
113        bdm_assert ( isnamed(), "rvs are not assigned" );
114        mratio *tmp = new mratio ( this, rv );
115        return shared_ptr<pdf> ( tmp );
116}
117
118void emix::from_setting ( const Setting &set ) {
119        emix_base::from_setting(set);
120        UI::get ( Coms, set, "pdfs", UI::compulsory );
121        UI::get ( w, set, "weights", UI::compulsory );
122}
123void emix::to_setting  (Setting  &set) const {
124        emix_base::to_setting(set);
125        UI::save(Coms, set, "pdfs");
126        UI::save( w, set, "weights");
127}
128
129
130void    emix::validate (){
131        emix_base::validate();
132        dim = Coms ( 0 )->dimension();
133}
134
135
136double mprod::evallogcond ( const vec &val, const vec &cond ) {
137        int i;
138        double res = 0.0;
139        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
140                /*                      if ( pdfs(i)->_rvc().count() >0) {
141                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
142                                        }
143                                        // add logarithms
144                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
145                res += pdfs ( i )->evallogcond (
146                           dls ( i )->pushdown ( val ),
147                           dls ( i )->get_cond ( val, cond )
148                       );
149        }
150        return res;
151}
152
153vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
154        vec tmp ( Dt.cols() );
155        for ( int i = 0; i < Dt.cols(); i++ ) {
156                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
157        }
158        return tmp;
159}
160
161vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
162        vec tmp ( Dt.length() );
163        for ( int i = 0; i < Dt.length(); i++ ) {
164                tmp ( i ) = evallogcond ( Dt ( i ), cond );
165        }
166        return tmp;
167}
168
169void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
170        pdfs = mFacs;
171        dls.set_size ( mFacs.length() );
172
173        rv = get_composite_rv ( pdfs, true );
174        dim = rv._dsize();
175
176        for ( int i = 0; i < pdfs.length(); i++ ) {
177                RV rvx = pdfs ( i )->_rvc().subt ( rv );
178                rvc.add ( rvx ); // add rv to common rvc
179        }
180        dimc = rvc._dsize();
181
182        // rv and rvc established = > we can link them with pdfs
183        for ( int i = 0; i < pdfs.length(); i++ ) {
184                dls ( i ) = new datalink_m2m;
185                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
186        }
187}
188
189void mprod::from_setting ( const Setting &set ) {
190                pdf::from_setting(set);
191                Array<shared_ptr<pdf> > temp_array; 
192                UI::get ( temp_array, set, "pdfs", UI::compulsory );
193                set_elements ( temp_array );
194        }
195void    mprod::to_setting  (Setting  &set) const {
196                pdf::to_setting(set);
197                UI::save( pdfs, set, "pdfs");
198        }
199
200void mmix::validate()
201{       pdf::validate();
202        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
203
204        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
205
206        double sum_w = sum ( w );
207        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
208        w = w / sum_w;
209
210
211        dim = Coms ( 0 )->dimension();
212        dimc = Coms ( 0 )->dimensionc();
213        RV rv_tmp = Coms ( 0 )->_rv();
214        RV rvc_tmp = Coms ( 0 )->_rvc();
215        bool isnamed = Coms ( 0 )->isnamed();
216        for ( int i = 1; i < Coms.length(); i++ ) {
217                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
218                bdm_assert ( dimc >= ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
219                isnamed &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp) & Coms(i)->_rvc().equal(rvc_tmp);
220        }
221        if (isnamed)
222        {
223                pdf::set_rv ( rv_tmp );
224                pdf::set_rvc ( rvc_tmp );
225        }
226}
227
228void mmix::from_setting ( const Setting &set ) {
229       
230        pdf::from_setting(set);
231        UI::get ( Coms, set, "pdfs", UI::compulsory );
232
233        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
234                int len = Coms.length();
235                w.set_length ( len );
236                w = 1.0 / len;
237        }
238}
239
240void    mmix::to_setting  (Setting  &set) const {
241        pdf::to_setting(set);
242        UI::save( Coms, set, "pdfs");
243        UI::save( w, set, "weights");
244}
245
246vec mmix::samplecond ( const vec &cond ) {
247        //Sample which component
248        vec cumDist = cumsum ( w );
249        double u0;
250#pragma omp critical
251        u0 = UniRNG.sample();
252
253        int i = 0;
254        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
255                i++;
256        }
257
258        return Coms ( i )->samplecond ( cond );
259}
260
261vec eprod_base::mean() const {
262        vec tmp ( dim );
263        for ( int i = 0; i < no_factors(); i++ ) {
264                vec pom = factor( i )->mean();
265                dls ( i )->pushup ( tmp, pom );
266        }
267        return tmp;
268}
269
270vec eprod_base::variance() const {
271        vec tmp ( dim ); //second moment
272        for ( int i = 0; i < no_factors(); i++ ) {
273                vec pom = factor ( i )->variance();
274                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
275        }
276        return tmp;
277}
278vec eprod_base::sample() const {
279        vec tmp ( dim );
280        for ( int i = 0; i < no_factors(); i++ ) {
281                vec pom = factor ( i )->sample();
282                dls ( i )->pushup ( tmp, pom );
283        }
284        return tmp;
285}
286double eprod_base::evallog ( const vec &val ) const {
287        double tmp = 0;
288        for ( int i = 0; i < no_factors(); i++ ) {
289                tmp += factor ( i )->evallog ( dls ( i )->pushdown ( val ) );
290        }
291        //bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
292        return tmp;
293}
294
295}
Note: See TracBrowser for help on using the browser.