root/library/bdm/stat/merger.h @ 536

Revision 536, 10.6 kB (checked in by smidl, 15 years ago)

removal of unused functions _e() and samplecond(,) and added documentation lines

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm {
20using std::string;
21
22//!Merging methods
23enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
24
25/*!
26@brief Base class for general combination of pdfs on discrete support
27
28Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
29
30The merged pdfs are expected to be of the form:
31 \f[ f(x_i|y_i),  i=1..n \f]
32where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
33Note that all variables will be joined.
34
35As a result of this feature, each source must be extended to common support
36\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
37where \f$ z_i \f$ accumulate variables that were not in the original source.
38These extensions are calculated on-the-fly.
39
40However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
41For merging of more general cases, use offsprings merger_mix and merger_grid.
42*/
43
44class merger_base : public epdf {
45protected:
46        //! Elements of composition
47        Array<shared_ptr<mpdf> > mpdfs;
48
49        //! Data link for each mpdf in mpdfs
50        Array<datalink_m2e*> dls;
51
52        //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
53        Array<RV> rvzs;
54
55        //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
56        Array<datalink_m2e*> zdls;
57
58        //! number of support points
59        int Npoints;
60
61        //! number of sources
62        int Nsources;
63
64        //! switch of the methoh used for merging
65        MERGER_METHOD METHOD;
66        //! Default for METHOD
67        static const MERGER_METHOD DFLT_METHOD;
68
69        //!Prior on the log-normal merging model
70        double beta;
71        //! default for beta
72        static const double DFLT_beta;
73
74        //! Projection to empirical density (could also be piece-wise linear)
75        eEmp eSmp;
76
77        //! debug or not debug
78        bool DBG;
79
80        //! debugging file
81        it_file* dbg_file;
82public:
83        //! \name Constructors
84        //! @{
85
86        //! Default constructor
87        merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
88        }
89
90        //!Constructor from sources
91        merger_base ( const Array<shared_ptr<mpdf> > &S );
92
93        //! Function setting the main internal structures
94        void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
95                mpdfs = Sources;
96                Nsources = mpdfs.length();
97                //set sizes
98                dls.set_size ( Sources.length() );
99                rvzs.set_size ( Sources.length() );
100                zdls.set_size ( Sources.length() );
101
102                rv = get_composite_rv ( mpdfs, /* checkoverlap = */ false );
103
104                RV rvc;
105                // Extend rv by rvc!
106                for ( int i = 0; i < mpdfs.length(); i++ ) {
107                        RV rvx = mpdfs ( i )->_rvc().subt ( rv );
108                        rvc.add ( rvx ); // add rv to common rvc
109                }
110
111                // join rv and rvc - see descriprion
112                rv.add ( rvc );
113                // get dimension
114                dim = rv._dsize();
115
116                // create links between sources and common rv
117                RV xytmp;
118                for ( int i = 0; i < mpdfs.length(); i++ ) {
119                        //Establich connection between mpdfs and merger
120                        dls ( i ) = new datalink_m2e;
121                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
122
123                        // find out what is missing in each mpdf
124                        xytmp = mpdfs ( i )->_rv();
125                        xytmp.add ( mpdfs ( i )->_rvc() );
126                        // z_i = common_rv-xy
127                        rvzs ( i ) = rv.subt ( xytmp );
128                        //establish connection between extension (z_i|x,y)s and common rv
129                        zdls ( i ) = new datalink_m2e;
130                        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
131                };
132        }
133        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
134        void set_support ( const Array<vec> &XYZ, const int dimsize ) {
135                set_support ( XYZ, dimsize*ones_i ( XYZ.length() ) );
136        }
137        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
138        void set_support ( const Array<vec> &XYZ, const ivec &gridsize ) {
139                int dim = XYZ.length();  //check with internal dim!!
140                Npoints = prod ( gridsize );
141                eSmp.set_parameters ( Npoints, false );
142                Array<vec> &samples = eSmp._samples();
143                eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
144                //set samples
145                ivec ind = zeros_i ( dim );    //indeces of dimensions in for cycle;
146                vec smpi ( dim );          // ith sample
147                vec steps = zeros ( dim );        // ith sample
148                // first corner
149                for ( int j = 0; j < dim; j++ ) {
150                        smpi ( j ) = XYZ ( j ) ( 0 ); /* beginning of the interval*/
151                        it_assert ( gridsize ( j ) != 0.0, "Zeros in gridsize!" );
152                        steps ( j ) = ( XYZ ( j ) ( 1 ) - smpi ( j ) ) / gridsize ( j );
153                }
154                // fill samples
155                for ( int i = 0; i < Npoints; i++ ) {
156                        // copy
157                        samples ( i ) = smpi;
158                        // go through all dimensions
159                        for ( int j = 0; j < dim; j++ ) {
160                                if ( ind ( j ) == gridsize ( j ) - 1 ) { //j-th index is full
161                                        ind ( j ) = 0; //shift back
162                                        smpi ( j ) = XYZ ( j ) ( 0 );
163
164                                        if ( i < Npoints - 1 ) {
165                                                ind ( j + 1 ) ++; //increase the next dimension;
166                                                smpi ( j + 1 ) += steps ( j + 1 );
167                                                break;
168                                        }
169
170                                } else {
171                                        ind ( j ) ++;
172                                        smpi ( j ) += steps ( j );
173                                        break;
174                                }
175                        }
176                }
177        }
178        //! set debug file
179        void set_debug_file ( const string fname ) {
180                if ( DBG ) delete dbg_file;
181                dbg_file = new it_file ( fname );
182                if ( dbg_file ) DBG = true;
183        }
184        //! Set internal parameters used in approximation
185        void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
186                METHOD = MTH;
187                beta = beta0;
188        }
189        //! Set support points from a pdf by drawing N samples
190        void set_support ( const epdf &overall, int N ) {
191                eSmp.set_statistics ( overall, N );
192                Npoints = N;
193        }
194
195        //! Destructor
196        virtual ~merger_base() {
197                for ( int i = 0; i < Nsources; i++ ) {
198                        delete dls ( i );
199                        delete zdls ( i );
200                }
201                if ( DBG ) delete dbg_file;
202        };
203        //!@}
204
205        //! \name Mathematical operations
206        //!@{
207
208        //!Merge given sources in given points
209        virtual void merge () {
210                validate();
211
212                //check if sources overlap:
213                bool OK = true;
214                for ( int i = 0; i < mpdfs.length(); i++ ) {
215                        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
216                        OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
217                }
218
219                if ( OK ) {
220                        mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
221
222                        vec emptyvec ( 0 );
223                        for ( int i = 0; i < mpdfs.length(); i++ ) {
224                                for ( int j = 0; j < eSmp._w().length(); j++ ) {
225                                        lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
226                                }
227                        }
228
229                        vec w_nn = merge_points ( lW );
230                        vec wtmp = exp ( w_nn - max ( w_nn ) );
231                        //renormalize
232                        eSmp._w() = wtmp / sum ( wtmp );
233                } else {
234                        it_error ( "Sources are not compatible - use merger_mix" );
235                }
236        };
237
238
239        //! Merge log-likelihood values in points using method specified by parameter METHOD
240        vec merge_points ( mat &lW );
241
242
243        //! sample from merged density
244//! weight w is a
245        vec mean() const {
246                const Vec<double> &w = eSmp._w();
247                const Array<vec> &S = eSmp._samples();
248                vec tmp = zeros ( dim );
249                for ( int i = 0; i < Npoints; i++ ) {
250                        tmp += w ( i ) * S ( i );
251                }
252                return tmp;
253        }
254        mat covariance() const {
255                const vec &w = eSmp._w();
256                const Array<vec> &S = eSmp._samples();
257
258                vec mea = mean();
259
260//                      cout << sum (w) << "," << w*w << endl;
261
262                mat Tmp = zeros ( dim, dim );
263                for ( int i = 0; i < Npoints; i++ ) {
264                        Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
265                }
266                return Tmp - outer_product ( mea, mea );
267        }
268        vec variance() const {
269                const vec &w = eSmp._w();
270                const Array<vec> &S = eSmp._samples();
271
272                vec tmp = zeros ( dim );
273                for ( int i = 0; i < Nsources; i++ ) {
274                        tmp += w ( i ) * pow ( S ( i ), 2 );
275                }
276                return tmp - pow ( mean(), 2 );
277        }
278        //!@}
279
280        //! \name Access to attributes
281        //! @{
282
283        //! Access function
284        eEmp& _Smp() {
285                return eSmp;
286        }
287
288        //! load from setting
289        void from_setting ( const Setting& set ) {
290                // get support
291                // find which method to use
292                string meth_str;
293                UI::get<string> ( meth_str, set, "method", UI::compulsory );
294                if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
295                        set_method ( ARITHMETIC );
296                else {
297                        if ( !strcmp ( meth_str.c_str(), "geometric" ) )
298                                set_method ( GEOMETRIC );
299                        else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
300                                set_method ( LOGNORMAL );
301                                set.lookupValue ( "beta", beta );
302                        }
303                }
304                string dbg_file;
305                if ( UI::get ( dbg_file, set, "dbg_file" ) )
306                        set_debug_file ( dbg_file );
307                //validate() - not used
308        }
309
310        void validate() {
311                it_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
312                it_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
313                it_assert ( isnamed(), "mergers must be named" );
314        }
315        //!@}
316};
317UIREGISTER ( merger_base );
318SHAREDPTR ( merger_base );
319
320//! Merger using importance sampling with mixture proposal density
321class merger_mix : public merger_base {
322protected:
323        //!Internal mixture of EF models
324        MixEF Mix;
325        //!Number of components in a mixture
326        int Ncoms;
327        //! coefficient of resampling [0,1]
328        double effss_coef;
329        //! stop after niter iterations
330        int stop_niter;
331
332        //! default value for Ncoms
333        static const int DFLT_Ncoms;
334        //! default value for efss_coef;
335        static const double DFLT_effss_coef;
336
337public:
338        //!\name Constructors
339        //!@{
340        merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
341
342        merger_mix ( const Array<shared_ptr<mpdf> > &S ):
343                Ncoms(0), effss_coef(0), stop_niter(0) {
344                set_sources ( S );
345        }
346
347        //! Set sources and prepare all internal structures
348        void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
349                merger_base::set_sources ( S );
350                Nsources = S.length();
351        }
352
353        //! Set internal parameters used in approximation
354        void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
355                Ncoms = Ncoms0;
356                effss_coef = effss_coef0;
357        }
358        //!@}
359
360        //! \name Mathematical operations
361        //!@{
362
363        //!Merge values using mixture approximation
364        void merge ();
365
366        //! sample from the approximating mixture
367        vec sample () const {
368                return Mix.posterior().sample();
369        }
370        //! loglikelihood computed on mixture models
371        double evallog ( const vec &dt ) const {
372                vec dtf = ones ( dt.length() + 1 );
373                dtf.set_subvector ( 0, dt );
374                return Mix.logpred ( dtf );
375        }
376        //!@}
377
378        //!\name Access functions
379        //!@{
380//! Access function
381        MixEF& _Mix() {
382                return Mix;
383        }
384        //! Access function
385        emix* proposal() {
386                emix* tmp = Mix.epredictor();
387                tmp->set_rv ( rv );
388                return tmp;
389        }
390        //! from_settings
391        void from_setting ( const Setting& set ) {
392                merger_base::from_setting ( set );
393                set.lookupValue ( "ncoms", Ncoms );
394                set.lookupValue ( "effss_coef", effss_coef );
395                set.lookupValue ( "stop_niter", stop_niter );
396        }
397
398        //! @}
399
400};
401UIREGISTER ( merger_mix );
402SHAREDPTR ( merger_mix );
403
404}
405
406#endif // MER_H
Note: See TracBrowser for help on using the browser.