root/library/bdm/stat/merger.h @ 488

Revision 488, 10.2 kB (checked in by smidl, 15 years ago)

changes in mpdf -> compile OK, broken tests!

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm {
20using std::string;
21
22//!Merging methods
23enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
24
25/*!
26@brief Base class for general combination of pdfs on discrete support
27
28Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
29
30The merged pdfs are expected to be of the form:
31 \f[ f(x_i|y_i),  i=1..n \f]
32where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
33Note that all variables will be joined.
34
35As a result of this feature, each source must be extended to common support
36\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
37where \f$ z_i \f$ accumulate variables that were not in the original source.
38These extensions are calculated on-the-fly.
39
40However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
41For merging of more general cases, use offsprings merger_mix and merger_grid.
42*/
43
44class merger_base : public compositepdf, public epdf {
45protected:
46        //! Data link for each mpdf in mpdfs
47        Array<datalink_m2e*> dls;
48        //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
49        Array<RV> rvzs;
50        //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
51        Array<datalink_m2e*> zdls;
52        //! number of support points
53        int Npoints;
54        //! number of sources
55        int Nsources;
56
57        //! switch of the methoh used for merging
58        MERGER_METHOD METHOD;
59        //! Default for METHOD
60        static const MERGER_METHOD DFLT_METHOD;
61
62        //!Prior on the log-normal merging model
63        double beta;
64        //! default for beta
65        static const double DFLT_beta;
66
67        //! Projection to empirical density (could also be piece-wise linear)
68        eEmp eSmp;
69
70        //! debug or not debug
71        bool DBG;
72
73        //! debugging file
74        it_file* dbg_file;
75public:
76        //! \name Constructors
77        //! @{
78
79        //!Empty constructor
80        merger_base () : compositepdf() {
81                DBG = false;
82                dbg_file = NULL;
83        }
84
85        //!Constructor from sources
86        merger_base ( const Array<mpdf*> &S, bool own = false );
87
88        //! Function setting the main internal structures
89        void set_sources ( const Array<mpdf*> &Sources, bool own ) {
90                compositepdf::set_elements ( Sources, own );
91                Nsources = mpdfs.length();
92                //set sizes
93                dls.set_size ( Sources.length() );
94                rvzs.set_size ( Sources.length() );
95                zdls.set_size ( Sources.length() );
96
97                rv = getrv ( /* checkoverlap = */ false );
98                RV rvc;
99                setrvc ( rv, rvc );  // Extend rv by rvc!
100                // join rv and rvc - see descriprion
101                rv.add ( rvc );
102                // get dimension
103                dim = rv._dsize();
104
105                // create links between sources and common rv
106                RV xytmp;
107                for ( int i = 0; i < mpdfs.length(); i++ ) {
108                        //Establich connection between mpdfs and merger
109                        dls ( i ) = new datalink_m2e;
110                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
111
112                        // find out what is missing in each mpdf
113                        xytmp = mpdfs ( i )->_rv();
114                        xytmp.add ( mpdfs ( i )->_rvc() );
115                        // z_i = common_rv-xy
116                        rvzs ( i ) = rv.subt ( xytmp );
117                        //establish connection between extension (z_i|x,y)s and common rv
118                        zdls ( i ) = new datalink_m2e;
119                        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
120                };
121        }
122        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
123        void set_support ( const Array<vec> &XYZ, const int dimsize ) {
124                set_support ( XYZ, dimsize*ones_i ( XYZ.length() ) );
125        }
126        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
127        void set_support ( const Array<vec> &XYZ, const ivec &gridsize ) {
128                int dim = XYZ.length();  //check with internal dim!!
129                Npoints = prod ( gridsize );
130                eSmp.set_parameters ( Npoints, false );
131                Array<vec> &samples = eSmp._samples();
132                eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
133                //set samples
134                ivec ind = zeros_i ( dim );    //indeces of dimensions in for cycle;
135                vec smpi ( dim );          // ith sample
136                vec steps = zeros ( dim );        // ith sample
137                // first corner
138                for ( int j = 0; j < dim; j++ ) {
139                        smpi ( j ) = XYZ ( j ) ( 0 ); /* beginning of the interval*/
140                        it_assert ( gridsize ( j ) != 0.0, "Zeros in gridsize!" );
141                        steps ( j ) = ( XYZ ( j ) ( 1 ) - smpi ( j ) ) / gridsize ( j );
142                }
143                // fill samples
144                for ( int i = 0; i < Npoints; i++ ) {
145                        // copy
146                        samples ( i ) = smpi;
147                        // go through all dimensions
148                        for ( int j = 0; j < dim; j++ ) {
149                                if ( ind ( j ) == gridsize ( j ) - 1 ) { //j-th index is full
150                                        ind ( j ) = 0; //shift back
151                                        smpi ( j ) = XYZ ( j ) ( 0 );
152
153                                        if ( i < Npoints - 1 ) {
154                                                ind ( j + 1 ) ++; //increase the next dimension;
155                                                smpi ( j + 1 ) += steps ( j + 1 );
156                                                break;
157                                        }
158
159                                } else {
160                                        ind ( j ) ++;
161                                        smpi ( j ) += steps ( j );
162                                        break;
163                                }
164                        }
165                }
166        }
167        //! set debug file
168        void set_debug_file ( const string fname ) {
169                if ( DBG ) delete dbg_file;
170                dbg_file = new it_file ( fname );
171                if ( dbg_file ) DBG = true;
172        }
173        //! Set internal parameters used in approximation
174        void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
175                METHOD = MTH;
176                beta = beta0;
177        }
178        //! Set support points from a pdf by drawing N samples
179        void set_support ( const epdf &overall, int N ) {
180                eSmp.set_statistics ( overall, N );
181                Npoints = N;
182        }
183
184        //! Destructor
185        virtual ~merger_base() {
186                for ( int i = 0; i < Nsources; i++ ) {
187                        delete dls ( i );
188                        delete zdls ( i );
189                }
190                if ( DBG ) delete dbg_file;
191        };
192        //!@}
193
194        //! \name Mathematical operations
195        //!@{
196
197        //!Merge given sources in given points
198        virtual void merge () {
199                validate();
200
201                //check if sources overlap:
202                bool OK = true;
203                for ( int i = 0; i < mpdfs.length(); i++ ) {
204                        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
205                        OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
206                }
207
208                if ( OK ) {
209                        mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
210
211                        vec emptyvec ( 0 );
212                        for ( int i = 0; i < mpdfs.length(); i++ ) {
213                                for ( int j = 0; j < eSmp._w().length(); j++ ) {
214                                        lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
215                                }
216                        }
217
218                        vec w_nn = merge_points ( lW );
219                        vec wtmp = exp ( w_nn - max ( w_nn ) );
220                        //renormalize
221                        eSmp._w() = wtmp / sum ( wtmp );
222                } else {
223                        it_error ( "Sources are not compatible - use merger_mix" );
224                }
225        };
226
227
228        //! Merge log-likelihood values in points using method specified by parameter METHOD
229        vec merge_points ( mat &lW );
230
231
232        //! sample from merged density
233//! weight w is a
234        vec mean() const {
235                const Vec<double> &w = eSmp._w();
236                const Array<vec> &S = eSmp._samples();
237                vec tmp = zeros ( dim );
238                for ( int i = 0; i < Npoints; i++ ) {
239                        tmp += w ( i ) * S ( i );
240                }
241                return tmp;
242        }
243        mat covariance() const {
244                const vec &w = eSmp._w();
245                const Array<vec> &S = eSmp._samples();
246
247                vec mea = mean();
248
249//                      cout << sum (w) << "," << w*w << endl;
250
251                mat Tmp = zeros ( dim, dim );
252                for ( int i = 0; i < Npoints; i++ ) {
253                        Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
254                }
255                return Tmp - outer_product ( mea, mea );
256        }
257        vec variance() const {
258                const vec &w = eSmp._w();
259                const Array<vec> &S = eSmp._samples();
260
261                vec tmp = zeros ( dim );
262                for ( int i = 0; i < Nsources; i++ ) {
263                        tmp += w ( i ) * pow ( S ( i ), 2 );
264                }
265                return tmp - pow ( mean(), 2 );
266        }
267        //!@}
268
269        //! \name Access to attributes
270        //! @{
271
272        //! Access function
273        eEmp& _Smp() {
274                return eSmp;
275        }
276
277        //! load from setting
278        void from_setting ( const Setting& set ) {
279                // get support
280                // find which method to use
281                string meth_str;
282                UI::get<string> ( meth_str, set, "method", UI::compulsory );
283                if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
284                        set_method ( ARITHMETIC );
285                else {
286                        if ( !strcmp ( meth_str.c_str(), "geometric" ) )
287                                set_method ( GEOMETRIC );
288                        else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
289                                set_method ( LOGNORMAL );
290                                set.lookupValue ( "beta", beta );
291                        }
292                }
293                string dbg_file;
294                if ( UI::get ( dbg_file, set, "dbg_file" ) )
295                        set_debug_file ( dbg_file );
296                //validate() - not used
297        }
298
299        void validate() {
300                it_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
301                it_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
302                it_assert ( isnamed(), "mergers must be named" );
303        }
304        //!@}
305};
306UIREGISTER ( merger_base );
307
308class merger_mix : public merger_base {
309protected:
310        //!Internal mixture of EF models
311        MixEF Mix;
312        //!Number of components in a mixture
313        int Ncoms;
314        //! coefficient of resampling [0,1]
315        double effss_coef;
316        //! stop after niter iterations
317        int stop_niter;
318
319        //! default value for Ncoms
320        static const int DFLT_Ncoms;
321        //! default value for efss_coef;
322        static const double DFLT_effss_coef;
323
324public:
325        //!\name Constructors
326        //!@{
327        merger_mix () {};
328        merger_mix ( const Array<mpdf*> &S, bool own = false ) {
329                set_sources ( S, own );
330        };
331        //! Set sources and prepare all internal structures
332        void set_sources ( const Array<mpdf*> &S, bool own ) {
333                merger_base::set_sources ( S, own );
334                Nsources = S.length();
335        }
336        //! Set internal parameters used in approximation
337        void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
338                Ncoms = Ncoms0;
339                effss_coef = effss_coef0;
340        }
341        //!@}
342
343        //! \name Mathematical operations
344        //!@{
345
346        //!Merge values using mixture approximation
347        void merge ();
348
349        //! sample from the approximating mixture
350        vec sample () const {
351                return Mix.posterior().sample();
352        }
353        //! loglikelihood computed on mixture models
354        double evallog ( const vec &dt ) const {
355                vec dtf = ones ( dt.length() + 1 );
356                dtf.set_subvector ( 0, dt );
357                return Mix.logpred ( dtf );
358        }
359        //!@}
360
361        //!\name Access functions
362        //!@{
363//! Access function
364        MixEF& _Mix() {
365                return Mix;
366        }
367        //! Access function
368        emix* proposal() {
369                emix* tmp = Mix.epredictor();
370                tmp->set_rv ( rv );
371                return tmp;
372        }
373        //! from_settings
374        void from_setting ( const Setting& set ) {
375                merger_base::from_setting ( set );
376                set.lookupValue ( "ncoms", Ncoms );
377                set.lookupValue ( "effss_coef", effss_coef );
378                set.lookupValue ( "stop_niter", stop_niter );
379        }
380
381        //! @}
382
383};
384UIREGISTER ( merger_mix );
385
386}
387
388#endif // MER_H
Note: See TracBrowser for help on using the browser.