root/library/bdm/stat/merger.h @ 488

Revision 488, 10.2 kB (checked in by smidl, 15 years ago)

changes in mpdf -> compile OK, broken tests!

  • Property svn:eol-style set to native
RevLine 
[176]1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[384]13#ifndef MERGER_H
14#define MERGER_H
[176]15
[262]16
[384]17#include "../estim/mixtures.h"
[176]18
[477]19namespace bdm {
[384]20using std::string;
[176]21
[384]22//!Merging methods
23enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
[176]24
[384]25/*!
26@brief Base class for general combination of pdfs on discrete support
[176]27
[384]28Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
[198]29
[384]30The merged pdfs are expected to be of the form:
31 \f[ f(x_i|y_i),  i=1..n \f]
32where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
33Note that all variables will be joined.
[205]34
[384]35As a result of this feature, each source must be extended to common support
36\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
37where \f$ z_i \f$ accumulate variables that were not in the original source.
38These extensions are calculated on-the-fly.
[299]39
[384]40However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
41For merging of more general cases, use offsprings merger_mix and merger_grid.
42*/
43
[477]44class merger_base : public compositepdf, public epdf {
45protected:
46        //! Data link for each mpdf in mpdfs
47        Array<datalink_m2e*> dls;
48        //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
49        Array<RV> rvzs;
50        //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
51        Array<datalink_m2e*> zdls;
52        //! number of support points
53        int Npoints;
54        //! number of sources
55        int Nsources;
[384]56
[477]57        //! switch of the methoh used for merging
58        MERGER_METHOD METHOD;
59        //! Default for METHOD
60        static const MERGER_METHOD DFLT_METHOD;
[384]61
[477]62        //!Prior on the log-normal merging model
63        double beta;
64        //! default for beta
65        static const double DFLT_beta;
[384]66
[477]67        //! Projection to empirical density (could also be piece-wise linear)
68        eEmp eSmp;
[384]69
[477]70        //! debug or not debug
71        bool DBG;
[423]72
[477]73        //! debugging file
74        it_file* dbg_file;
75public:
76        //! \name Constructors
77        //! @{
[423]78
[477]79        //!Empty constructor
80        merger_base () : compositepdf() {
81                DBG = false;
82                dbg_file = NULL;
83        }
[384]84
[477]85        //!Constructor from sources
86        merger_base ( const Array<mpdf*> &S, bool own = false );
[384]87
[477]88        //! Function setting the main internal structures
89        void set_sources ( const Array<mpdf*> &Sources, bool own ) {
90                compositepdf::set_elements ( Sources, own );
91                Nsources = mpdfs.length();
92                //set sizes
93                dls.set_size ( Sources.length() );
94                rvzs.set_size ( Sources.length() );
95                zdls.set_size ( Sources.length() );
[384]96
[477]97                rv = getrv ( /* checkoverlap = */ false );
98                RV rvc;
99                setrvc ( rv, rvc );  // Extend rv by rvc!
100                // join rv and rvc - see descriprion
101                rv.add ( rvc );
102                // get dimension
103                dim = rv._dsize();
104
105                // create links between sources and common rv
106                RV xytmp;
107                for ( int i = 0; i < mpdfs.length(); i++ ) {
108                        //Establich connection between mpdfs and merger
109                        dls ( i ) = new datalink_m2e;
110                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
111
112                        // find out what is missing in each mpdf
113                        xytmp = mpdfs ( i )->_rv();
114                        xytmp.add ( mpdfs ( i )->_rvc() );
115                        // z_i = common_rv-xy
116                        rvzs ( i ) = rv.subt ( xytmp );
117                        //establish connection between extension (z_i|x,y)s and common rv
118                        zdls ( i ) = new datalink_m2e;
119                        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
120                };
121        }
122        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
123        void set_support ( const Array<vec> &XYZ, const int dimsize ) {
124                set_support ( XYZ, dimsize*ones_i ( XYZ.length() ) );
125        }
126        //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
127        void set_support ( const Array<vec> &XYZ, const ivec &gridsize ) {
128                int dim = XYZ.length();  //check with internal dim!!
129                Npoints = prod ( gridsize );
130                eSmp.set_parameters ( Npoints, false );
131                Array<vec> &samples = eSmp._samples();
132                eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
133                //set samples
134                ivec ind = zeros_i ( dim );    //indeces of dimensions in for cycle;
135                vec smpi ( dim );          // ith sample
136                vec steps = zeros ( dim );        // ith sample
137                // first corner
138                for ( int j = 0; j < dim; j++ ) {
139                        smpi ( j ) = XYZ ( j ) ( 0 ); /* beginning of the interval*/
140                        it_assert ( gridsize ( j ) != 0.0, "Zeros in gridsize!" );
141                        steps ( j ) = ( XYZ ( j ) ( 1 ) - smpi ( j ) ) / gridsize ( j );
[384]142                }
[477]143                // fill samples
144                for ( int i = 0; i < Npoints; i++ ) {
145                        // copy
146                        samples ( i ) = smpi;
147                        // go through all dimensions
148                        for ( int j = 0; j < dim; j++ ) {
149                                if ( ind ( j ) == gridsize ( j ) - 1 ) { //j-th index is full
150                                        ind ( j ) = 0; //shift back
151                                        smpi ( j ) = XYZ ( j ) ( 0 );
152
153                                        if ( i < Npoints - 1 ) {
154                                                ind ( j + 1 ) ++; //increase the next dimension;
155                                                smpi ( j + 1 ) += steps ( j + 1 );
[388]156                                                break;
157                                        }
[477]158
159                                } else {
160                                        ind ( j ) ++;
161                                        smpi ( j ) += steps ( j );
162                                        break;
[388]163                                }
164                        }
165                }
[477]166        }
167        //! set debug file
168        void set_debug_file ( const string fname ) {
169                if ( DBG ) delete dbg_file;
170                dbg_file = new it_file ( fname );
171                if ( dbg_file ) DBG = true;
172        }
173        //! Set internal parameters used in approximation
174        void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
175                METHOD = MTH;
176                beta = beta0;
177        }
178        //! Set support points from a pdf by drawing N samples
179        void set_support ( const epdf &overall, int N ) {
[488]180                eSmp.set_statistics ( overall, N );
[477]181                Npoints = N;
182        }
183
184        //! Destructor
185        virtual ~merger_base() {
186                for ( int i = 0; i < Nsources; i++ ) {
187                        delete dls ( i );
188                        delete zdls ( i );
[384]189                }
[477]190                if ( DBG ) delete dbg_file;
191        };
192        //!@}
[388]193
[477]194        //! \name Mathematical operations
195        //!@{
[388]196
[477]197        //!Merge given sources in given points
198        virtual void merge () {
199                validate();
[388]200
[477]201                //check if sources overlap:
202                bool OK = true;
203                for ( int i = 0; i < mpdfs.length(); i++ ) {
204                        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
205                        OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
206                }
[388]207
[477]208                if ( OK ) {
209                        mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
[384]210
[477]211                        vec emptyvec ( 0 );
212                        for ( int i = 0; i < mpdfs.length(); i++ ) {
213                                for ( int j = 0; j < eSmp._w().length(); j++ ) {
214                                        lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
[299]215                                }
216                        }
[176]217
[477]218                        vec w_nn = merge_points ( lW );
219                        vec wtmp = exp ( w_nn - max ( w_nn ) );
220                        //renormalize
221                        eSmp._w() = wtmp / sum ( wtmp );
222                } else {
223                        it_error ( "Sources are not compatible - use merger_mix" );
224                }
225        };
[384]226
[388]227
[477]228        //! Merge log-likelihood values in points using method specified by parameter METHOD
229        vec merge_points ( mat &lW );
[388]230
[477]231
232        //! sample from merged density
[192]233//! weight w is a
[477]234        vec mean() const {
235                const Vec<double> &w = eSmp._w();
236                const Array<vec> &S = eSmp._samples();
237                vec tmp = zeros ( dim );
238                for ( int i = 0; i < Npoints; i++ ) {
239                        tmp += w ( i ) * S ( i );
[384]240                }
[477]241                return tmp;
242        }
243        mat covariance() const {
244                const vec &w = eSmp._w();
245                const Array<vec> &S = eSmp._samples();
[299]246
[477]247                vec mea = mean();
[299]248
[404]249//                      cout << sum (w) << "," << w*w << endl;
[299]250
[477]251                mat Tmp = zeros ( dim, dim );
252                for ( int i = 0; i < Npoints; i++ ) {
253                        Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
[384]254                }
[477]255                return Tmp - outer_product ( mea, mea );
256        }
257        vec variance() const {
258                const vec &w = eSmp._w();
259                const Array<vec> &S = eSmp._samples();
[299]260
[477]261                vec tmp = zeros ( dim );
262                for ( int i = 0; i < Nsources; i++ ) {
263                        tmp += w ( i ) * pow ( S ( i ), 2 );
[384]264                }
[477]265                return tmp - pow ( mean(), 2 );
266        }
267        //!@}
[192]268
[477]269        //! \name Access to attributes
270        //! @{
[384]271
[477]272        //! Access function
273        eEmp& _Smp() {
274                return eSmp;
275        }
[388]276
[477]277        //! load from setting
278        void from_setting ( const Setting& set ) {
279                // get support
280                // find which method to use
281                string meth_str;
282                UI::get<string> ( meth_str, set, "method", UI::compulsory );
283                if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
284                        set_method ( ARITHMETIC );
285                else {
286                        if ( !strcmp ( meth_str.c_str(), "geometric" ) )
287                                set_method ( GEOMETRIC );
288                        else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
289                                set_method ( LOGNORMAL );
290                                set.lookupValue ( "beta", beta );
[388]291                        }
292                }
[477]293                string dbg_file;
294                if ( UI::get ( dbg_file, set, "dbg_file" ) )
295                        set_debug_file ( dbg_file );
296                //validate() - not used
297        }
[388]298
[477]299        void validate() {
300                it_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
301                it_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
302                it_assert ( isnamed(), "mergers must be named" );
303        }
304        //!@}
[384]305};
[477]306UIREGISTER ( merger_base );
[384]307
[477]308class merger_mix : public merger_base {
309protected:
310        //!Internal mixture of EF models
311        MixEF Mix;
312        //!Number of components in a mixture
313        int Ncoms;
314        //! coefficient of resampling [0,1]
315        double effss_coef;
316        //! stop after niter iterations
317        int stop_niter;
[384]318
[477]319        //! default value for Ncoms
320        static const int DFLT_Ncoms;
321        //! default value for efss_coef;
322        static const double DFLT_effss_coef;
[388]323
[477]324public:
325        //!\name Constructors
326        //!@{
327        merger_mix () {};
328        merger_mix ( const Array<mpdf*> &S, bool own = false ) {
329                set_sources ( S, own );
330        };
331        //! Set sources and prepare all internal structures
332        void set_sources ( const Array<mpdf*> &S, bool own ) {
333                merger_base::set_sources ( S, own );
334                Nsources = S.length();
335        }
336        //! Set internal parameters used in approximation
337        void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
338                Ncoms = Ncoms0;
339                effss_coef = effss_coef0;
340        }
341        //!@}
[388]342
[477]343        //! \name Mathematical operations
344        //!@{
[384]345
[477]346        //!Merge values using mixture approximation
347        void merge ();
[384]348
[477]349        //! sample from the approximating mixture
350        vec sample () const {
351                return Mix.posterior().sample();
352        }
353        //! loglikelihood computed on mixture models
354        double evallog ( const vec &dt ) const {
355                vec dtf = ones ( dt.length() + 1 );
356                dtf.set_subvector ( 0, dt );
357                return Mix.logpred ( dtf );
358        }
359        //!@}
360
361        //!\name Access functions
362        //!@{
[192]363//! Access function
[477]364        MixEF& _Mix() {
365                return Mix;
366        }
367        //! Access function
368        emix* proposal() {
369                emix* tmp = Mix.epredictor();
370                tmp->set_rv ( rv );
371                return tmp;
372        }
373        //! from_settings
374        void from_setting ( const Setting& set ) {
375                merger_base::from_setting ( set );
376                set.lookupValue ( "ncoms", Ncoms );
377                set.lookupValue ( "effss_coef", effss_coef );
378                set.lookupValue ( "stop_niter", stop_niter );
379        }
[388]380
[477]381        //! @}
382
[384]383};
[477]384UIREGISTER ( merger_mix );
[176]385
[254]386}
[176]387
388#endif // MER_H
Note: See TracBrowser for help on using the browser.