root/library/bdm/stat/merger.h @ 507

Revision 507, 10.4 kB (checked in by vbarta, 15 years ago)

removed class compositepdf; keeping mpdfs of mprod and merger_base in shared pointers

  • Property svn:eol-style set to native
RevLine 
[176]1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[384]13#ifndef MERGER_H
14#define MERGER_H
[176]15
[262]16
[384]17#include "../estim/mixtures.h"
[176]18
[477]19namespace bdm {
[384]20using std::string;
[176]21
[384]22//!Merging methods
23enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
[176]24
[384]25/*!
26@brief Base class for general combination of pdfs on discrete support
[176]27
[384]28Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
[198]29
[384]30The merged pdfs are expected to be of the form:
31 \f[ f(x_i|y_i),  i=1..n \f]
32where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
33Note that all variables will be joined.
[205]34
[384]35As a result of this feature, each source must be extended to common support
36\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
37where \f$ z_i \f$ accumulate variables that were not in the original source.
38These extensions are calculated on-the-fly.
[299]39
[384]40However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
41For merging of more general cases, use offsprings merger_mix and merger_grid.
42*/
43
[507]44class merger_base : public epdf {
[477]45protected:
[507]46        //! Elements of composition
47        Array<shared_ptr<mpdf> > mpdfs;
48
[477]49        //! Data link for each mpdf in mpdfs
50        Array<datalink_m2e*> dls;
[507]51
[477]52        //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
53        Array<RV> rvzs;
[507]54
[477]55        //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
56        Array<datalink_m2e*> zdls;
[507]57
[477]58        //! number of support points
59        int Npoints;
[507]60
[477]61        //! number of sources
62        int Nsources;
[384]63
[477]64        //! switch of the methoh used for merging
65        MERGER_METHOD METHOD;
66        //! Default for METHOD
67        static const MERGER_METHOD DFLT_METHOD;
[384]68
[477]69        //!Prior on the log-normal merging model
70        double beta;
71        //! default for beta
72        static const double DFLT_beta;
[384]73
[477]74        //! Projection to empirical density (could also be piece-wise linear)
75        eEmp eSmp;
[384]76
[477]77        //! debug or not debug
78        bool DBG;
[423]79
[477]80        //! debugging file
81        it_file* dbg_file;
82public:
83        //! \name Constructors
84        //! @{
[423]85
[507]86        //! Default constructor
87        merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
[477]88        }
[384]89
[477]90        //!Constructor from sources
[507]91        merger_base ( const Array<shared_ptr<mpdf> > &S );
[384]92
[477]93        //! Function setting the main internal structures
[507]94        void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
95                mpdfs = Sources;
[477]96                Nsources = mpdfs.length();
97                //set sizes
98                dls.set_size ( Sources.length() );
99                rvzs.set_size ( Sources.length() );
100                zdls.set_size ( Sources.length() );
[384]101
[507]102                rv = get_composite_rv ( mpdfs, /* checkoverlap = */ false );
103
[477]104                RV rvc;
[507]105                // Extend rv by rvc!
106                for ( int i = 0; i < mpdfs.length(); i++ ) {
107                        RV rvx = mpdfs ( i )->_rvc().subt ( rv );
108                        rvc.add ( rvx ); // add rv to common rvc
109                }
110
[477]111                // join rv and rvc - see descriprion
112                rv.add ( rvc );
113                // get dimension
114                dim = rv._dsize();
115
116                // create links between sources and common rv
117                RV xytmp;
118                for ( int i = 0; i < mpdfs.length(); i++ ) {
119                        //Establich connection between mpdfs and merger
120                        dls ( i ) = new datalink_m2e;
121                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
122
123                        // find out what is missing in each mpdf
124                        xytmp = mpdfs ( i )->_rv();
125                        xytmp.add ( mpdfs ( i )->_rvc() );
126                        // z_i = common_rv-xy
127                        rvzs ( i ) = rv.subt ( xytmp );
128                        //establish connection between extension (z_i|x,y)s and common rv
129                        zdls ( i ) = new datalink_m2e;
130                        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
131                };
132        }
133        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
134        void set_support ( const Array<vec> &XYZ, const int dimsize ) {
135                set_support ( XYZ, dimsize*ones_i ( XYZ.length() ) );
136        }
137        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
138        void set_support ( const Array<vec> &XYZ, const ivec &gridsize ) {
139                int dim = XYZ.length();  //check with internal dim!!
140                Npoints = prod ( gridsize );
141                eSmp.set_parameters ( Npoints, false );
142                Array<vec> &samples = eSmp._samples();
143                eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
144                //set samples
145                ivec ind = zeros_i ( dim );    //indeces of dimensions in for cycle;
146                vec smpi ( dim );          // ith sample
147                vec steps = zeros ( dim );        // ith sample
148                // first corner
149                for ( int j = 0; j < dim; j++ ) {
150                        smpi ( j ) = XYZ ( j ) ( 0 ); /* beginning of the interval*/
151                        it_assert ( gridsize ( j ) != 0.0, "Zeros in gridsize!" );
152                        steps ( j ) = ( XYZ ( j ) ( 1 ) - smpi ( j ) ) / gridsize ( j );
[384]153                }
[477]154                // fill samples
155                for ( int i = 0; i < Npoints; i++ ) {
156                        // copy
157                        samples ( i ) = smpi;
158                        // go through all dimensions
159                        for ( int j = 0; j < dim; j++ ) {
160                                if ( ind ( j ) == gridsize ( j ) - 1 ) { //j-th index is full
161                                        ind ( j ) = 0; //shift back
162                                        smpi ( j ) = XYZ ( j ) ( 0 );
163
164                                        if ( i < Npoints - 1 ) {
165                                                ind ( j + 1 ) ++; //increase the next dimension;
166                                                smpi ( j + 1 ) += steps ( j + 1 );
[388]167                                                break;
168                                        }
[477]169
170                                } else {
171                                        ind ( j ) ++;
172                                        smpi ( j ) += steps ( j );
173                                        break;
[388]174                                }
175                        }
176                }
[477]177        }
178        //! set debug file
179        void set_debug_file ( const string fname ) {
180                if ( DBG ) delete dbg_file;
181                dbg_file = new it_file ( fname );
182                if ( dbg_file ) DBG = true;
183        }
184        //! Set internal parameters used in approximation
185        void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
186                METHOD = MTH;
187                beta = beta0;
188        }
189        //! Set support points from a pdf by drawing N samples
190        void set_support ( const epdf &overall, int N ) {
[488]191                eSmp.set_statistics ( overall, N );
[477]192                Npoints = N;
193        }
194
195        //! Destructor
196        virtual ~merger_base() {
197                for ( int i = 0; i < Nsources; i++ ) {
198                        delete dls ( i );
199                        delete zdls ( i );
[384]200                }
[477]201                if ( DBG ) delete dbg_file;
202        };
203        //!@}
[388]204
[477]205        //! \name Mathematical operations
206        //!@{
[388]207
[477]208        //!Merge given sources in given points
209        virtual void merge () {
210                validate();
[388]211
[477]212                //check if sources overlap:
213                bool OK = true;
214                for ( int i = 0; i < mpdfs.length(); i++ ) {
215                        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
216                        OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
217                }
[388]218
[477]219                if ( OK ) {
220                        mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
[384]221
[477]222                        vec emptyvec ( 0 );
223                        for ( int i = 0; i < mpdfs.length(); i++ ) {
224                                for ( int j = 0; j < eSmp._w().length(); j++ ) {
225                                        lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
[299]226                                }
227                        }
[176]228
[477]229                        vec w_nn = merge_points ( lW );
230                        vec wtmp = exp ( w_nn - max ( w_nn ) );
231                        //renormalize
232                        eSmp._w() = wtmp / sum ( wtmp );
233                } else {
234                        it_error ( "Sources are not compatible - use merger_mix" );
235                }
236        };
[384]237
[388]238
[477]239        //! Merge log-likelihood values in points using method specified by parameter METHOD
240        vec merge_points ( mat &lW );
[388]241
[477]242
243        //! sample from merged density
[192]244//! weight w is a
[477]245        vec mean() const {
246                const Vec<double> &w = eSmp._w();
247                const Array<vec> &S = eSmp._samples();
248                vec tmp = zeros ( dim );
249                for ( int i = 0; i < Npoints; i++ ) {
250                        tmp += w ( i ) * S ( i );
[384]251                }
[477]252                return tmp;
253        }
254        mat covariance() const {
255                const vec &w = eSmp._w();
256                const Array<vec> &S = eSmp._samples();
[299]257
[477]258                vec mea = mean();
[299]259
[404]260//                      cout << sum (w) << "," << w*w << endl;
[299]261
[477]262                mat Tmp = zeros ( dim, dim );
263                for ( int i = 0; i < Npoints; i++ ) {
264                        Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
[384]265                }
[477]266                return Tmp - outer_product ( mea, mea );
267        }
268        vec variance() const {
269                const vec &w = eSmp._w();
270                const Array<vec> &S = eSmp._samples();
[299]271
[477]272                vec tmp = zeros ( dim );
273                for ( int i = 0; i < Nsources; i++ ) {
274                        tmp += w ( i ) * pow ( S ( i ), 2 );
[384]275                }
[477]276                return tmp - pow ( mean(), 2 );
277        }
278        //!@}
[192]279
[477]280        //! \name Access to attributes
281        //! @{
[384]282
[477]283        //! Access function
284        eEmp& _Smp() {
285                return eSmp;
286        }
[388]287
[477]288        //! load from setting
289        void from_setting ( const Setting& set ) {
290                // get support
291                // find which method to use
292                string meth_str;
293                UI::get<string> ( meth_str, set, "method", UI::compulsory );
294                if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
295                        set_method ( ARITHMETIC );
296                else {
297                        if ( !strcmp ( meth_str.c_str(), "geometric" ) )
298                                set_method ( GEOMETRIC );
299                        else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
300                                set_method ( LOGNORMAL );
301                                set.lookupValue ( "beta", beta );
[388]302                        }
303                }
[477]304                string dbg_file;
305                if ( UI::get ( dbg_file, set, "dbg_file" ) )
306                        set_debug_file ( dbg_file );
307                //validate() - not used
308        }
[388]309
[477]310        void validate() {
311                it_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
312                it_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
313                it_assert ( isnamed(), "mergers must be named" );
314        }
315        //!@}
[384]316};
[477]317UIREGISTER ( merger_base );
[384]318
[477]319class merger_mix : public merger_base {
320protected:
321        //!Internal mixture of EF models
322        MixEF Mix;
323        //!Number of components in a mixture
324        int Ncoms;
325        //! coefficient of resampling [0,1]
326        double effss_coef;
327        //! stop after niter iterations
328        int stop_niter;
[384]329
[477]330        //! default value for Ncoms
331        static const int DFLT_Ncoms;
332        //! default value for efss_coef;
333        static const double DFLT_effss_coef;
[388]334
[477]335public:
336        //!\name Constructors
337        //!@{
[507]338        merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
339
340        merger_mix ( const Array<shared_ptr<mpdf> > &S ):
341                Ncoms(0), effss_coef(0), stop_niter(0) {
342                set_sources ( S );
343        }
344
[477]345        //! Set sources and prepare all internal structures
[507]346        void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
347                merger_base::set_sources ( S );
[477]348                Nsources = S.length();
349        }
[507]350
[477]351        //! Set internal parameters used in approximation
352        void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
353                Ncoms = Ncoms0;
354                effss_coef = effss_coef0;
355        }
356        //!@}
[388]357
[477]358        //! \name Mathematical operations
359        //!@{
[384]360
[477]361        //!Merge values using mixture approximation
362        void merge ();
[384]363
[477]364        //! sample from the approximating mixture
365        vec sample () const {
366                return Mix.posterior().sample();
367        }
368        //! loglikelihood computed on mixture models
369        double evallog ( const vec &dt ) const {
370                vec dtf = ones ( dt.length() + 1 );
371                dtf.set_subvector ( 0, dt );
372                return Mix.logpred ( dtf );
373        }
374        //!@}
375
376        //!\name Access functions
377        //!@{
[192]378//! Access function
[477]379        MixEF& _Mix() {
380                return Mix;
381        }
382        //! Access function
383        emix* proposal() {
384                emix* tmp = Mix.epredictor();
385                tmp->set_rv ( rv );
386                return tmp;
387        }
388        //! from_settings
389        void from_setting ( const Setting& set ) {
390                merger_base::from_setting ( set );
391                set.lookupValue ( "ncoms", Ncoms );
392                set.lookupValue ( "effss_coef", effss_coef );
393                set.lookupValue ( "stop_niter", stop_niter );
394        }
[388]395
[477]396        //! @}
397
[384]398};
[477]399UIREGISTER ( merger_mix );
[176]400
[254]401}
[176]402
403#endif // MER_H
Note: See TracBrowser for help on using the browser.