root/library/bdm/stat/emix.cpp @ 964

Revision 964, 7.4 kB (checked in by smidl, 14 years ago)

Corrections in ARX and PF

  • Property svn:eol-style set to native
Line 
1#include "emix.h"
2
3namespace bdm {
4
5void emix_base::validate (){
6        epdf::validate();
7        bdm_assert ( no_coms() > 0, "There has to be at least one component." );
8
9        bdm_assert ( no_coms() == w.length(), "It is obligatory to define weights of all the components." );
10
11        double sum_w = sum ( w );
12        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
13        w = w / sum_w;
14
15        dim = component ( 0 )->dimension();
16        RV rv_tmp = component ( 0 )->_rv() ;
17        bool isnamed = component( 0 )->isnamed();
18        for ( int i = 1; i < no_coms(); i++ ) {
19                bdm_assert ( dim == ( component ( i )->dimension() ), "Component sizes do not match!" );
20                isnamed &= component(i)->isnamed() & component(i)->_rv().equal(rv_tmp);
21        }
22        if (isnamed)
23                epdf::set_rv ( rv_tmp); 
24}
25
26
27
28vec emix_base::sample() const {
29        //Sample which component
30        vec cumDist = cumsum ( w );
31        double u0;
32#pragma omp critical
33        u0 = UniRNG.sample();
34
35        int i = 0;
36        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
37                i++;
38        }
39
40        return component ( i )->sample();
41}
42
43vec emix_base::mean() const {
44        int i;
45        vec mu = zeros ( dim );
46        for ( i = 0; i < w.length(); i++ ) {
47                mu += w ( i ) * component ( i )->mean();
48        }
49        return mu;
50}
51
52vec emix_base::variance() const {
53        //non-central moment
54        vec mom2 = zeros ( dim );
55        vec mom1 = zeros ( dim );
56       
57        for ( int i = 0; i < w.length(); i++ ) {
58                vec vi=component( i )->variance();
59                vec mi=component ( i )->mean();
60                mom2 += w ( i ) * ( vi + pow ( mi, 2 ) );
61                mom1 += w(i) * mi;
62        }
63        //central moment
64        return mom2 - pow ( mom1, 2 );
65}
66
67double emix_base::evallog ( const vec &val ) const {
68        int i;
69        double sum = 0.0;
70        for ( i = 0; i < w.length(); i++ ) {
71                sum += w ( i ) * exp ( component ( i )->evallog ( val ) );
72        }
73        if ( sum == 0.0 ) {
74                sum = std::numeric_limits<double>::epsilon();
75        }
76        double tmp = log ( sum );
77        bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
78        return tmp;
79}
80
81vec emix_base::evallog_mat ( const mat &Val ) const {
82        vec x = zeros ( Val.cols() );
83        for ( int i = 0; i < w.length(); i++ ) {
84                x += w ( i ) * exp ( component( i )->evallog_mat ( Val ) );
85        }
86        return log ( x );
87};
88
89mat emix_base::evallog_coms ( const mat &Val ) const {
90        mat X ( w.length(), Val.cols() );
91        for ( int i = 0; i < w.length(); i++ ) {
92                X.set_row ( i, w ( i ) *exp ( component( i )->evallog_mat ( Val ) ) );
93        }
94        return X;
95}
96
97shared_ptr<epdf> emix_base::marginal ( const RV &rv ) const {
98        emix *tmp = new emix();
99        shared_ptr<epdf> narrow ( tmp );
100        marginal ( rv, *tmp );
101        return narrow;
102}
103
104void emix_base::marginal ( const RV &rv, emix &target ) const {
105        bdm_assert ( isnamed(), "rvs are not assigned" );
106
107        Array<shared_ptr<epdf> > Cn ( no_coms() );
108        for ( int i = 0; i < no_coms(); i++ ) {
109                Cn ( i ) = component ( i )->marginal ( rv );
110        }
111
112        target._w() = w;
113        target._Coms() = Cn;
114        target.validate();
115}
116
117shared_ptr<pdf> emix_base::condition ( const RV &rv ) const {
118        bdm_assert ( isnamed(), "rvs are not assigned" );
119        mratio *tmp = new mratio ( this, rv );
120        return shared_ptr<pdf> ( tmp );
121}
122
123void emix::from_setting ( const Setting &set ) {
124        emix_base::from_setting(set);
125        UI::get ( Coms, set, "pdfs", UI::compulsory );
126        UI::get ( w, set, "weights", UI::compulsory );
127}
128void emix::to_setting  (Setting  &set) const {
129        emix_base::to_setting(set);
130        UI::save(Coms, set, "pdfs");
131        UI::save( w, set, "weights");
132}
133
134
135void    emix::validate (){
136        emix_base::validate();
137        dim = Coms ( 0 )->dimension();
138}
139
140
141double mprod::evallogcond ( const vec &val, const vec &cond ) {
142        int i;
143        double res = 0.0;
144        for ( i = pdfs.length() - 1; i >= 0; i-- ) {
145                /*                      if ( pdfs(i)->_rvc().count() >0) {
146                                                pdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
147                                        }
148                                        // add logarithms
149                                        res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
150                res += pdfs ( i )->evallogcond (
151                           dls ( i )->pushdown ( val ),
152                           dls ( i )->get_cond ( val, cond )
153                       );
154        }
155        return res;
156}
157
158vec mprod::evallogcond_mat ( const mat &Dt, const vec &cond ) {
159        vec tmp ( Dt.cols() );
160        for ( int i = 0; i < Dt.cols(); i++ ) {
161                tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
162        }
163        return tmp;
164}
165
166vec mprod::evallogcond_mat ( const Array<vec> &Dt, const vec &cond ) {
167        vec tmp ( Dt.length() );
168        for ( int i = 0; i < Dt.length(); i++ ) {
169                tmp ( i ) = evallogcond ( Dt ( i ), cond );
170        }
171        return tmp;
172}
173
174void mprod::set_elements ( const Array<shared_ptr<pdf> > &mFacs ) {
175        pdfs = mFacs;
176        dls.set_size ( mFacs.length() );
177
178        rv = get_composite_rv ( pdfs, true );
179        dim = rv._dsize();
180
181        for ( int i = 0; i < pdfs.length(); i++ ) {
182                RV rvx = pdfs ( i )->_rvc().subt ( rv );
183                rvc.add ( rvx ); // add rv to common rvc
184        }
185        dimc = rvc._dsize();
186
187        // rv and rvc established = > we can link them with pdfs
188        for ( int i = 0; i < pdfs.length(); i++ ) {
189                dls ( i ) = new datalink_m2m;
190                dls ( i )->set_connection ( pdfs ( i )->_rv(), pdfs ( i )->_rvc(), _rv(), _rvc() );
191        }
192}
193
194void mprod::from_setting ( const Setting &set ) {
195                pdf::from_setting(set);
196                Array<shared_ptr<pdf> > temp_array; 
197                UI::get ( temp_array, set, "pdfs", UI::compulsory );
198                set_elements ( temp_array );
199        }
200void    mprod::to_setting  (Setting  &set) const {
201                pdf::to_setting(set);
202                UI::save( pdfs, set, "pdfs");
203        }
204
205void mmix::validate()
206{       pdf::validate();
207        bdm_assert ( Coms.length() > 0, "There has to be at least one component." );
208
209        bdm_assert ( Coms.length() == w.length(), "It is obligatory to define weights of all the components." );
210
211        double sum_w = sum ( w );
212        bdm_assert ( sum_w != 0, "There has to be a component with non-zero weight." );
213        w = w / sum_w;
214
215
216        dim = Coms ( 0 )->dimension();
217        dimc = Coms ( 0 )->dimensionc();
218        RV rv_tmp = Coms ( 0 )->_rv();
219        RV rvc_tmp = Coms ( 0 )->_rvc();
220        bool isnamed = Coms ( 0 )->isnamed();
221        for ( int i = 1; i < Coms.length(); i++ ) {
222                bdm_assert ( dim == ( Coms ( i )->dimension() ), "Component sizes do not match!" );
223                bdm_assert ( dimc >= ( Coms ( i )->dimensionc() ), "Component sizes do not match!" );
224                isnamed &= Coms(i)->isnamed() & Coms(i)->_rv().equal(rv_tmp) & Coms(i)->_rvc().equal(rvc_tmp);
225        }
226        if (isnamed)
227        {
228                pdf::set_rv ( rv_tmp );
229                pdf::set_rvc ( rvc_tmp );
230        }
231}
232
233void mmix::from_setting ( const Setting &set ) {
234       
235        pdf::from_setting(set);
236        UI::get ( Coms, set, "pdfs", UI::compulsory );
237
238        if ( !UI::get ( w, set, "weights", UI::optional ) ) {
239                int len = Coms.length();
240                w.set_length ( len );
241                w = 1.0 / len;
242        }
243}
244
245void    mmix::to_setting  (Setting  &set) const {
246        pdf::to_setting(set);
247        UI::save( Coms, set, "pdfs");
248        UI::save( w, set, "weights");
249}
250
251vec mmix::samplecond ( const vec &cond ) {
252        //Sample which component
253        vec cumDist = cumsum ( w );
254        double u0;
255#pragma omp critical
256        u0 = UniRNG.sample();
257
258        int i = 0;
259        while ( ( cumDist ( i ) < u0 ) && ( i < ( w.length() - 1 ) ) ) {
260                i++;
261        }
262
263        return Coms ( i )->samplecond ( cond );
264}
265
266vec eprod_base::mean() const {
267        vec tmp ( dim );
268        for ( int i = 0; i < no_factors(); i++ ) {
269                vec pom = factor( i )->mean();
270                dls ( i )->pushup ( tmp, pom );
271        }
272        return tmp;
273}
274
275vec eprod_base::variance() const {
276        vec tmp ( dim ); //second moment
277        for ( int i = 0; i < no_factors(); i++ ) {
278                vec pom = factor ( i )->variance();
279                dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
280        }
281        return tmp;
282}
283vec eprod_base::sample() const {
284        vec tmp ( dim );
285        for ( int i = 0; i < no_factors(); i++ ) {
286                vec pom = factor ( i )->sample();
287                dls ( i )->pushup ( tmp, pom );
288        }
289        return tmp;
290}
291double eprod_base::evallog ( const vec &val ) const {
292        double tmp = 0;
293        for ( int i = 0; i < no_factors(); i++ ) {
294                tmp += factor ( i )->evallog ( dls ( i )->pushdown ( val ) );
295        }
296        //bdm_assert_debug ( std::isfinite ( tmp ), "Infinite" );
297        return tmp;
298}
299
300}
Note: See TracBrowser for help on using the browser.