root/library/bdm/stat/merger.h @ 565

Revision 565, 9.3 kB (checked in by vbarta, 15 years ago)

using own error macros (basically copied from IT++, but never aborting)

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18#include "discrete.h"
19
20namespace bdm {
21using std::string;
22
23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
25
26/*!
27@brief Base class for general combination of pdfs on discrete support
28
29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
30
31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
35
36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
40
41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public epdf {
46protected:
47        //! Elements of composition
48        Array<shared_ptr<mpdf> > mpdfs;
49
50        //! Data link for each mpdf in mpdfs
51        Array<datalink_m2e*> dls;
52
53        //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
54        Array<RV> rvzs;
55
56        //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
57        Array<datalink_m2e*> zdls;
58
59        //! number of support points
60        int Npoints;
61
62        //! number of sources
63        int Nsources;
64
65        //! switch of the methoh used for merging
66        MERGER_METHOD METHOD;
67        //! Default for METHOD
68        static const MERGER_METHOD DFLT_METHOD;
69
70        //!Prior on the log-normal merging model
71        double beta;
72        //! default for beta
73        static const double DFLT_beta;
74
75        //! Projection to empirical density (could also be piece-wise linear)
76        eEmp eSmp;
77
78        //! debug or not debug
79        bool DBG;
80
81        //! debugging file
82        it_file* dbg_file;
83public:
84        //! \name Constructors
85        //! @{
86
87        //! Default constructor
88        merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
89        }
90
91        //!Constructor from sources
92        merger_base ( const Array<shared_ptr<mpdf> > &S );
93
94        //! Function setting the main internal structures
95        void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
96                mpdfs = Sources;
97                Nsources = mpdfs.length();
98                //set sizes
99                dls.set_size ( Sources.length() );
100                rvzs.set_size ( Sources.length() );
101                zdls.set_size ( Sources.length() );
102
103                rv = get_composite_rv ( mpdfs, /* checkoverlap = */ false );
104
105                RV rvc;
106                // Extend rv by rvc!
107                for ( int i = 0; i < mpdfs.length(); i++ ) {
108                        RV rvx = mpdfs ( i )->_rvc().subt ( rv );
109                        rvc.add ( rvx ); // add rv to common rvc
110                }
111
112                // join rv and rvc - see descriprion
113                rv.add ( rvc );
114                // get dimension
115                dim = rv._dsize();
116
117                // create links between sources and common rv
118                RV xytmp;
119                for ( int i = 0; i < mpdfs.length(); i++ ) {
120                        //Establich connection between mpdfs and merger
121                        dls ( i ) = new datalink_m2e;
122                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
123
124                        // find out what is missing in each mpdf
125                        xytmp = mpdfs ( i )->_rv();
126                        xytmp.add ( mpdfs ( i )->_rvc() );
127                        // z_i = common_rv-xy
128                        rvzs ( i ) = rv.subt ( xytmp );
129                        //establish connection between extension (z_i|x,y)s and common rv
130                        zdls ( i ) = new datalink_m2e;
131                        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
132                };
133        }
134        //! Set support points from rectangular grid
135        void set_support ( rectangular_support &Sup) {
136                Npoints = Sup.points();
137                eSmp.set_parameters ( Npoints, false );
138                Array<vec> &samples = eSmp._samples();
139                eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
140                //set samples
141                samples(0)=Sup.first_vec();
142                for (int j=1; j < Npoints; j++ ) {
143                        samples ( j ) = Sup.next_vec();
144                }
145        }
146        //! set debug file
147        void set_debug_file ( const string fname ) {
148                if ( DBG ) delete dbg_file;
149                dbg_file = new it_file ( fname );
150                if ( dbg_file ) DBG = true;
151        }
152        //! Set internal parameters used in approximation
153        void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
154                METHOD = MTH;
155                beta = beta0;
156        }
157        //! Set support points from a pdf by drawing N samples
158        void set_support ( const epdf &overall, int N ) {
159                eSmp.set_statistics ( overall, N );
160                Npoints = N;
161        }
162
163        //! Destructor
164        virtual ~merger_base() {
165                for ( int i = 0; i < Nsources; i++ ) {
166                        delete dls ( i );
167                        delete zdls ( i );
168                }
169                if ( DBG ) delete dbg_file;
170        };
171        //!@}
172
173        //! \name Mathematical operations
174        //!@{
175
176        //!Merge given sources in given points
177        virtual void merge () {
178                validate();
179
180                //check if sources overlap:
181                bool OK = true;
182                for ( int i = 0; i < mpdfs.length(); i++ ) {
183                        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
184                        OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
185                }
186
187                if ( OK ) {
188                        mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
189
190                        vec emptyvec ( 0 );
191                        for ( int i = 0; i < mpdfs.length(); i++ ) {
192                                for ( int j = 0; j < eSmp._w().length(); j++ ) {
193                                        lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
194                                }
195                        }
196
197                        vec w_nn = merge_points ( lW );
198                        vec wtmp = exp ( w_nn - max ( w_nn ) );
199                        //renormalize
200                        eSmp._w() = wtmp / sum ( wtmp );
201                } else {
202                        bdm_error ( "Sources are not compatible - use merger_mix" );
203                }
204        };
205
206
207        //! Merge log-likelihood values in points using method specified by parameter METHOD
208        vec merge_points ( mat &lW );
209
210
211        //! sample from merged density
212//! weight w is a
213        vec mean() const {
214                const Vec<double> &w = eSmp._w();
215                const Array<vec> &S = eSmp._samples();
216                vec tmp = zeros ( dim );
217                for ( int i = 0; i < Npoints; i++ ) {
218                        tmp += w ( i ) * S ( i );
219                }
220                return tmp;
221        }
222        mat covariance() const {
223                const vec &w = eSmp._w();
224                const Array<vec> &S = eSmp._samples();
225
226                vec mea = mean();
227
228//                      cout << sum (w) << "," << w*w << endl;
229
230                mat Tmp = zeros ( dim, dim );
231                for ( int i = 0; i < Npoints; i++ ) {
232                        Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
233                }
234                return Tmp - outer_product ( mea, mea );
235        }
236        vec variance() const {
237                const vec &w = eSmp._w();
238                const Array<vec> &S = eSmp._samples();
239
240                vec tmp = zeros ( dim );
241                for ( int i = 0; i < Nsources; i++ ) {
242                        tmp += w ( i ) * pow ( S ( i ), 2 );
243                }
244                return tmp - pow ( mean(), 2 );
245        }
246        //!@}
247
248        //! \name Access to attributes
249        //! @{
250
251        //! Access function
252        eEmp& _Smp() {
253                return eSmp;
254        }
255
256        //! load from setting
257        void from_setting ( const Setting& set ) {
258                // get support
259                // find which method to use
260                string meth_str;
261                UI::get<string> ( meth_str, set, "method", UI::compulsory );
262                if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
263                        set_method ( ARITHMETIC );
264                else {
265                        if ( !strcmp ( meth_str.c_str(), "geometric" ) )
266                                set_method ( GEOMETRIC );
267                        else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
268                                set_method ( LOGNORMAL );
269                                set.lookupValue ( "beta", beta );
270                        }
271                }
272                string dbg_file;
273                if ( UI::get ( dbg_file, set, "dbg_file" ) )
274                        set_debug_file ( dbg_file );
275                //validate() - not used
276        }
277
278        void validate() {
279                bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
280                bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
281                bdm_assert ( isnamed(), "mergers must be named" );
282        }
283        //!@}
284};
285UIREGISTER ( merger_base );
286SHAREDPTR ( merger_base );
287
288//! Merger using importance sampling with mixture proposal density
289class merger_mix : public merger_base {
290protected:
291        //!Internal mixture of EF models
292        MixEF Mix;
293        //!Number of components in a mixture
294        int Ncoms;
295        //! coefficient of resampling [0,1]
296        double effss_coef;
297        //! stop after niter iterations
298        int stop_niter;
299
300        //! default value for Ncoms
301        static const int DFLT_Ncoms;
302        //! default value for efss_coef;
303        static const double DFLT_effss_coef;
304
305public:
306        //!\name Constructors
307        //!@{
308        merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
309
310        merger_mix ( const Array<shared_ptr<mpdf> > &S ):
311                Ncoms(0), effss_coef(0), stop_niter(0) {
312                set_sources ( S );
313        }
314
315        //! Set sources and prepare all internal structures
316        void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
317                merger_base::set_sources ( S );
318                Nsources = S.length();
319        }
320
321        //! Set internal parameters used in approximation
322        void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
323                Ncoms = Ncoms0;
324                effss_coef = effss_coef0;
325        }
326        //!@}
327
328        //! \name Mathematical operations
329        //!@{
330
331        //!Merge values using mixture approximation
332        void merge ();
333
334        //! sample from the approximating mixture
335        vec sample () const {
336                return Mix.posterior().sample();
337        }
338        //! loglikelihood computed on mixture models
339        double evallog ( const vec &dt ) const {
340                vec dtf = ones ( dt.length() + 1 );
341                dtf.set_subvector ( 0, dt );
342                return Mix.logpred ( dtf );
343        }
344        //!@}
345
346        //!\name Access functions
347        //!@{
348//! Access function
349        MixEF& _Mix() {
350                return Mix;
351        }
352        //! Access function
353        emix* proposal() {
354                emix* tmp = Mix.epredictor();
355                tmp->set_rv ( rv );
356                return tmp;
357        }
358        //! from_settings
359        void from_setting ( const Setting& set ) {
360                merger_base::from_setting ( set );
361                set.lookupValue ( "ncoms", Ncoms );
362                set.lookupValue ( "effss_coef", effss_coef );
363                set.lookupValue ( "stop_niter", stop_niter );
364        }
365
366        //! @}
367
368};
369UIREGISTER ( merger_mix );
370SHAREDPTR ( merger_mix );
371
372}
373
374#endif // MER_H
Note: See TracBrowser for help on using the browser.