root/library/bdm/stat/merger.h @ 569

Revision 569, 9.4 kB (checked in by smidl, 15 years ago)

new object discrete_support, merger adapted to accept this input as well

  • Property svn:eol-style set to native
RevLine 
[176]1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[384]13#ifndef MERGER_H
14#define MERGER_H
[176]15
[262]16
[384]17#include "../estim/mixtures.h"
[556]18#include "discrete.h"
[176]19
[477]20namespace bdm {
[384]21using std::string;
[176]22
[384]23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
[176]25
[384]26/*!
27@brief Base class for general combination of pdfs on discrete support
[176]28
[384]29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
[198]30
[384]31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
[205]35
[384]36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
[299]40
[384]41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
[507]45class merger_base : public epdf {
[477]46protected:
[507]47        //! Elements of composition
48        Array<shared_ptr<mpdf> > mpdfs;
49
[477]50        //! Data link for each mpdf in mpdfs
51        Array<datalink_m2e*> dls;
[507]52
[477]53        //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
54        Array<RV> rvzs;
[507]55
[477]56        //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
57        Array<datalink_m2e*> zdls;
[507]58
[477]59        //! number of support points
60        int Npoints;
[507]61
[477]62        //! number of sources
63        int Nsources;
[384]64
[477]65        //! switch of the methoh used for merging
66        MERGER_METHOD METHOD;
67        //! Default for METHOD
68        static const MERGER_METHOD DFLT_METHOD;
[384]69
[477]70        //!Prior on the log-normal merging model
71        double beta;
72        //! default for beta
73        static const double DFLT_beta;
[384]74
[477]75        //! Projection to empirical density (could also be piece-wise linear)
76        eEmp eSmp;
[384]77
[477]78        //! debug or not debug
79        bool DBG;
[423]80
[477]81        //! debugging file
82        it_file* dbg_file;
83public:
84        //! \name Constructors
85        //! @{
[423]86
[507]87        //! Default constructor
88        merger_base () : Npoints(0), Nsources(0), DBG(false), dbg_file(0) {
[477]89        }
[384]90
[477]91        //!Constructor from sources
[507]92        merger_base ( const Array<shared_ptr<mpdf> > &S );
[384]93
[477]94        //! Function setting the main internal structures
[507]95        void set_sources ( const Array<shared_ptr<mpdf> > &Sources ) {
96                mpdfs = Sources;
[477]97                Nsources = mpdfs.length();
98                //set sizes
99                dls.set_size ( Sources.length() );
100                rvzs.set_size ( Sources.length() );
101                zdls.set_size ( Sources.length() );
[384]102
[507]103                rv = get_composite_rv ( mpdfs, /* checkoverlap = */ false );
104
[477]105                RV rvc;
[507]106                // Extend rv by rvc!
107                for ( int i = 0; i < mpdfs.length(); i++ ) {
108                        RV rvx = mpdfs ( i )->_rvc().subt ( rv );
109                        rvc.add ( rvx ); // add rv to common rvc
110                }
111
[477]112                // join rv and rvc - see descriprion
113                rv.add ( rvc );
114                // get dimension
115                dim = rv._dsize();
116
117                // create links between sources and common rv
118                RV xytmp;
119                for ( int i = 0; i < mpdfs.length(); i++ ) {
120                        //Establich connection between mpdfs and merger
121                        dls ( i ) = new datalink_m2e;
122                        dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv );
123
124                        // find out what is missing in each mpdf
125                        xytmp = mpdfs ( i )->_rv();
126                        xytmp.add ( mpdfs ( i )->_rvc() );
127                        // z_i = common_rv-xy
128                        rvzs ( i ) = rv.subt ( xytmp );
129                        //establish connection between extension (z_i|x,y)s and common rv
130                        zdls ( i ) = new datalink_m2e;
131                        zdls ( i )->set_connection ( rvzs ( i ), xytmp, rv ) ;
132                };
133        }
[556]134        //! Set support points from rectangular grid
135        void set_support ( rectangular_support &Sup) {
136                Npoints = Sup.points();
[477]137                eSmp.set_parameters ( Npoints, false );
138                Array<vec> &samples = eSmp._samples();
139                eSmp._w() = ones ( Npoints ) / Npoints; //unifrom size of bins
140                //set samples
[556]141                samples(0)=Sup.first_vec();
142                for (int j=1; j < Npoints; j++ ) {
143                        samples ( j ) = Sup.next_vec();
[384]144                }
[477]145        }
[569]146        //! Set support points from dicrete grid
147        void set_support ( discrete_support &Sup) {
148                Npoints = Sup.points();
149                eSmp.set_parameters(Sup._Spoints());
150        }
[477]151        //! set debug file
152        void set_debug_file ( const string fname ) {
153                if ( DBG ) delete dbg_file;
154                dbg_file = new it_file ( fname );
155                if ( dbg_file ) DBG = true;
156        }
157        //! Set internal parameters used in approximation
158        void set_method ( MERGER_METHOD MTH = DFLT_METHOD, double beta0 = DFLT_beta ) {
159                METHOD = MTH;
160                beta = beta0;
161        }
162        //! Set support points from a pdf by drawing N samples
163        void set_support ( const epdf &overall, int N ) {
[488]164                eSmp.set_statistics ( overall, N );
[477]165                Npoints = N;
166        }
167
168        //! Destructor
169        virtual ~merger_base() {
170                for ( int i = 0; i < Nsources; i++ ) {
171                        delete dls ( i );
172                        delete zdls ( i );
[384]173                }
[477]174                if ( DBG ) delete dbg_file;
175        };
176        //!@}
[388]177
[477]178        //! \name Mathematical operations
179        //!@{
[388]180
[477]181        //!Merge given sources in given points
182        virtual void merge () {
183                validate();
[388]184
[477]185                //check if sources overlap:
186                bool OK = true;
187                for ( int i = 0; i < mpdfs.length(); i++ ) {
188                        OK &= ( rvzs ( i )._dsize() == 0 ); // z_i is empty
189                        OK &= ( mpdfs ( i )->_rvc()._dsize() == 0 ); // y_i is empty
190                }
[388]191
[477]192                if ( OK ) {
193                        mat lW = zeros ( mpdfs.length(), eSmp._w().length() );
[384]194
[477]195                        vec emptyvec ( 0 );
196                        for ( int i = 0; i < mpdfs.length(); i++ ) {
197                                for ( int j = 0; j < eSmp._w().length(); j++ ) {
198                                        lW ( i, j ) = mpdfs ( i )->evallogcond ( eSmp._samples() ( j ), emptyvec );
[299]199                                }
200                        }
[176]201
[477]202                        vec w_nn = merge_points ( lW );
203                        vec wtmp = exp ( w_nn - max ( w_nn ) );
204                        //renormalize
205                        eSmp._w() = wtmp / sum ( wtmp );
206                } else {
[565]207                        bdm_error ( "Sources are not compatible - use merger_mix" );
[477]208                }
209        };
[384]210
[388]211
[477]212        //! Merge log-likelihood values in points using method specified by parameter METHOD
213        vec merge_points ( mat &lW );
[388]214
[477]215
216        //! sample from merged density
[192]217//! weight w is a
[477]218        vec mean() const {
219                const Vec<double> &w = eSmp._w();
220                const Array<vec> &S = eSmp._samples();
221                vec tmp = zeros ( dim );
222                for ( int i = 0; i < Npoints; i++ ) {
223                        tmp += w ( i ) * S ( i );
[384]224                }
[477]225                return tmp;
226        }
227        mat covariance() const {
228                const vec &w = eSmp._w();
229                const Array<vec> &S = eSmp._samples();
[299]230
[477]231                vec mea = mean();
[299]232
[404]233//                      cout << sum (w) << "," << w*w << endl;
[299]234
[477]235                mat Tmp = zeros ( dim, dim );
236                for ( int i = 0; i < Npoints; i++ ) {
237                        Tmp += w ( i ) * outer_product ( S ( i ), S ( i ) );
[384]238                }
[477]239                return Tmp - outer_product ( mea, mea );
240        }
241        vec variance() const {
242                const vec &w = eSmp._w();
243                const Array<vec> &S = eSmp._samples();
[299]244
[477]245                vec tmp = zeros ( dim );
246                for ( int i = 0; i < Nsources; i++ ) {
247                        tmp += w ( i ) * pow ( S ( i ), 2 );
[384]248                }
[477]249                return tmp - pow ( mean(), 2 );
250        }
251        //!@}
[192]252
[477]253        //! \name Access to attributes
254        //! @{
[384]255
[477]256        //! Access function
257        eEmp& _Smp() {
258                return eSmp;
259        }
[388]260
[477]261        //! load from setting
262        void from_setting ( const Setting& set ) {
263                // get support
264                // find which method to use
265                string meth_str;
266                UI::get<string> ( meth_str, set, "method", UI::compulsory );
267                if ( !strcmp ( meth_str.c_str(), "arithmetic" ) )
268                        set_method ( ARITHMETIC );
269                else {
270                        if ( !strcmp ( meth_str.c_str(), "geometric" ) )
271                                set_method ( GEOMETRIC );
272                        else if ( !strcmp ( meth_str.c_str(), "lognormal" ) ) {
273                                set_method ( LOGNORMAL );
274                                set.lookupValue ( "beta", beta );
[388]275                        }
276                }
[477]277                string dbg_file;
278                if ( UI::get ( dbg_file, set, "dbg_file" ) )
279                        set_debug_file ( dbg_file );
280                //validate() - not used
281        }
[388]282
[477]283        void validate() {
[565]284                bdm_assert ( eSmp._w().length() > 0, "Empty support, use set_support()." );
285                bdm_assert ( dim == eSmp._samples() ( 0 ).length(), "Support points and rv are not compatible!" );
286                bdm_assert ( isnamed(), "mergers must be named" );
[477]287        }
288        //!@}
[384]289};
[477]290UIREGISTER ( merger_base );
[529]291SHAREDPTR ( merger_base );
[384]292
[536]293//! Merger using importance sampling with mixture proposal density
[477]294class merger_mix : public merger_base {
295protected:
296        //!Internal mixture of EF models
297        MixEF Mix;
298        //!Number of components in a mixture
299        int Ncoms;
300        //! coefficient of resampling [0,1]
301        double effss_coef;
302        //! stop after niter iterations
303        int stop_niter;
[384]304
[477]305        //! default value for Ncoms
306        static const int DFLT_Ncoms;
307        //! default value for efss_coef;
308        static const double DFLT_effss_coef;
[388]309
[477]310public:
311        //!\name Constructors
312        //!@{
[507]313        merger_mix ():Ncoms(0), effss_coef(0), stop_niter(0) { }
314
315        merger_mix ( const Array<shared_ptr<mpdf> > &S ):
316                Ncoms(0), effss_coef(0), stop_niter(0) {
317                set_sources ( S );
318        }
319
[477]320        //! Set sources and prepare all internal structures
[507]321        void set_sources ( const Array<shared_ptr<mpdf> > &S ) {
322                merger_base::set_sources ( S );
[477]323                Nsources = S.length();
324        }
[507]325
[477]326        //! Set internal parameters used in approximation
327        void set_parameters ( int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef ) {
328                Ncoms = Ncoms0;
329                effss_coef = effss_coef0;
330        }
331        //!@}
[388]332
[477]333        //! \name Mathematical operations
334        //!@{
[384]335
[477]336        //!Merge values using mixture approximation
337        void merge ();
[384]338
[477]339        //! sample from the approximating mixture
340        vec sample () const {
341                return Mix.posterior().sample();
342        }
343        //! loglikelihood computed on mixture models
344        double evallog ( const vec &dt ) const {
345                vec dtf = ones ( dt.length() + 1 );
346                dtf.set_subvector ( 0, dt );
347                return Mix.logpred ( dtf );
348        }
349        //!@}
350
351        //!\name Access functions
352        //!@{
[192]353//! Access function
[477]354        MixEF& _Mix() {
355                return Mix;
356        }
357        //! Access function
358        emix* proposal() {
359                emix* tmp = Mix.epredictor();
360                tmp->set_rv ( rv );
361                return tmp;
362        }
363        //! from_settings
364        void from_setting ( const Setting& set ) {
365                merger_base::from_setting ( set );
366                set.lookupValue ( "ncoms", Ncoms );
367                set.lookupValue ( "effss_coef", effss_coef );
368                set.lookupValue ( "stop_niter", stop_niter );
369        }
[388]370
[477]371        //! @}
372
[384]373};
[477]374UIREGISTER ( merger_mix );
[529]375SHAREDPTR ( merger_mix );
[176]376
[254]377}
[176]378
379#endif // MER_H
Note: See TracBrowser for help on using the browser.