root/library/bdm/stat/merger.h @ 423

Revision 423, 10.1 kB (checked in by vbarta, 15 years ago)

fixed merger_base constructor to initialize debug fields (still not all fields, though...)

  • Property svn:eol-style set to native
RevLine 
[176]1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[384]13#ifndef MERGER_H
14#define MERGER_H
[176]15
[262]16
[384]17#include "../estim/mixtures.h"
[176]18
[299]19namespace bdm
20{
[384]21using std::string;
[176]22
[384]23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
[176]25
[384]26/*!
27@brief Base class for general combination of pdfs on discrete support
[176]28
[384]29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
[198]30
[384]31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
[205]35
[384]36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
[299]40
[384]41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
[399]61                //! Default for METHOD
62                static const MERGER_METHOD DFLT_METHOD;
63               
[384]64                //!Prior on the log-normal merging model
65                double beta;
[399]66                //! default for beta
67                static const double DFLT_beta;
68               
[384]69                //! Projection to empirical density (could also be piece-wise linear)
70                eEmp eSmp;
71
72                //! debug or not debug
73                bool DBG;
74
75                //! debugging file
76                it_file* dbg_file;
77        public:
78                //! \name Constructors
79                //! @{
80
81                //!Empty constructor
[423]82                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;}
83
[384]84                //!Constructor from sources
[423]85                merger_base (const Array<mpdf*> &S, bool own=false);
86
[388]87                //! Function setting the main internal structures
88                void set_sources (const Array<mpdf*> &Sources, bool own) {
89                        compositepdf::set_elements (Sources,own);
[392]90                        Nsources=mpdfs.length();
[384]91                        //set sizes
92                        dls.set_size (Sources.length());
93                        rvzs.set_size (Sources.length());
94                        zdls.set_size (Sources.length());
95
96                        rv = getrv (/* checkoverlap = */ false);
97                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
98                        // join rv and rvc - see descriprion
99                        rv.add (rvc);
100                        // get dimension
101                        dim = rv._dsize();
102
103                        // create links between sources and common rv
104                        RV xytmp;
105                        for (int i = 0;i < mpdfs.length();i++) {
106                                //Establich connection between mpdfs and merger
107                                dls (i) = new datalink_m2e;
108                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
109
110                                // find out what is missing in each mpdf
111                                xytmp = mpdfs (i)->_rv();
112                                xytmp.add (mpdfs (i)->_rvc());
113                                // z_i = common_rv-xy
114                                rvzs (i) = rv.subt (xytmp);
115                                //establish connection between extension (z_i|x,y)s and common rv
116                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
117                        };
118                }
[388]119                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
120                void set_support (const Array<vec> &XYZ, const int dimsize) {
121                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
122                }
123                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
124                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
125                        int dim = XYZ.length();  //check with internal dim!!
126                        Npoints = prod (gridsize); 
127                        eSmp.set_parameters (Npoints, false);
128                        Array<vec> &samples = eSmp._samples();
129                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
130                        //set samples
131                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
132                        vec smpi (dim);            // ith sample
133                        vec steps =zeros(dim);            // ith sample
134                        // first corner
135                        for (int j = 0; j < dim; j++) {
136                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
137                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
[392]138                                steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
[388]139                        }
140                        // fill samples
141                        for (int i = 0; i < Npoints; i++) {
142                                // copy
143                                samples(i) = smpi; 
144                                // go through all dimensions
145                                for (int j = 0;j < dim;j++) {
146                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
[395]147                                                ind (j) = 0; //shift back
[388]148                                                smpi(j) = XYZ(j)(0);
149                                               
[395]150                                                if (i<Npoints-1) {
151                                                        ind (j + 1) ++; //increase the next dimension;
152                                                        smpi(j+1) += steps(j+1);
153                                                        break;
154                                                }
[388]155                                               
156                                        } else {
[395]157                                                ind (j) ++; 
[388]158                                                smpi(j) +=steps(j);
159                                                break;
160                                        }
161                                }
162                        }
163                }
[384]164                //! set debug file
[388]165                void set_debug_file (const string fname) {
166                        if (DBG) delete dbg_file;
167                        dbg_file = new it_file (fname);
[384]168                        if (dbg_file) DBG = true;
169                }
170                //! Set internal parameters used in approximation
[399]171                void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
[388]172                        METHOD = MTH;
[384]173                        beta = beta0;
174                }
175                //! Set support points from a pdf by drawing N samples
[388]176                void set_support (const epdf &overall, int N) {
177                        eSmp.set_statistics (&overall, N);
178                        Npoints = N;
[384]179                }
[388]180
[384]181                //! Destructor
182                virtual ~merger_base() {
183                        for (int i = 0; i < Nsources; i++) {
184                                delete dls (i);
185                                delete zdls (i);
[299]186                        }
[384]187                        if (DBG) delete dbg_file;
188                };
189                //!@}
[388]190
[384]191                //! \name Mathematical operations
192                //!@{
[388]193
[384]194                //!Merge given sources in given points
[395]195                virtual void merge () {
[388]196                        validate();
197
[384]198                        //check if sources overlap:
199                        bool OK = true;
200                        for (int i = 0;i < mpdfs.length(); i++) {
201                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
202                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
[310]203                        }
[384]204
205                        if (OK) {
206                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
207
208                                vec emptyvec (0);
209                                for (int i = 0; i < mpdfs.length(); i++) {
210                                        for (int j = 0; j < eSmp._w().length(); j++) {
211                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
[299]212                                        }
213                                }
[384]214
[395]215                                vec w_nn=merge_points (lW);
216                                vec wtmp = exp (w_nn-max(w_nn));
[384]217                                //renormalize
218                                eSmp._w() = wtmp / sum (wtmp);
219                        } else {
[388]220                                it_error ("Sources are not compatible - use merger_mix");
[299]221                        }
[384]222                };
[176]223
[384]224
225                //! Merge log-likelihood values in points using method specified by parameter METHOD
226                vec merge_points (mat &lW);
[388]227
228
[384]229                //! sample from merged density
[192]230//! weight w is a
[384]231                vec mean() const {
232                        const Vec<double> &w = eSmp._w();
233                        const Array<vec> &S = eSmp._samples();
234                        vec tmp = zeros (dim);
235                        for (int i = 0; i < Npoints; i++) {
236                                tmp += w (i) * S (i);
[299]237                        }
[384]238                        return tmp;
239                }
240                mat covariance() const {
241                        const vec &w = eSmp._w();
242                        const Array<vec> &S = eSmp._samples();
[299]243
[384]244                        vec mea = mean();
[299]245
[404]246//                      cout << sum (w) << "," << w*w << endl;
[299]247
[384]248                        mat Tmp = zeros (dim, dim);
249                        for (int i = 0; i < Npoints; i++) {
250                                Tmp += w (i) * outer_product (S (i), S (i));
[299]251                        }
[384]252                        return Tmp -outer_product (mea, mea);
253                }
254                vec variance() const {
255                        const vec &w = eSmp._w();
256                        const Array<vec> &S = eSmp._samples();
[299]257
[384]258                        vec tmp = zeros (dim);
259                        for (int i = 0; i < Nsources; i++) {
260                                tmp += w (i) * pow (S (i), 2);
[299]261                        }
[384]262                        return tmp -pow (mean(), 2);
263                }
264                //!@}
[192]265
[384]266                //! \name Access to attributes
267                //! @{
268
269                //! Access function
270                eEmp& _Smp() {return eSmp;}
[388]271
272                //! load from setting
273                void from_setting (const Setting& set) {
274                        // get support
275                        // find which method to use
276                        string meth_str;
277                        UI::get<string> (meth_str, set, "method");
278                        if (!strcmp (meth_str.c_str(), "arithmetic"))
279                                set_method (ARITHMETIC);
280                        else {
281                                if (!strcmp (meth_str.c_str(), "geometric"))
282                                        set_method (GEOMETRIC);
283                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
[392]284                                        set_method (LOGNORMAL);
[388]285                                        set.lookupValue( "beta",beta);
286                                }
287                        }
[395]288                        if (set.exists("dbg_file")){ 
289                                string dbg_file;
290                                UI::get<string> (dbg_file, set, "dbg_file");
291                                set_debug_file(dbg_file);
292                        }
293                        //validate() - not used
[388]294                }
295
296                void validate() {
297                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
298                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
299                        it_assert (isnamed(),"mergers must be named");
300                }
[384]301                //!@}
302};
[388]303UIREGISTER(merger_base);
[384]304
305class merger_mix : public merger_base
306{
307        protected:
308                //!Internal mixture of EF models
309                MixEF Mix;
310                //!Number of components in a mixture
311                int Ncoms;
[395]312                //! coefficient of resampling [0,1]
[384]313                double effss_coef;
[395]314                //! stop after niter iterations
315                int stop_niter;
[399]316               
317                //! default value for Ncoms
318                static const int DFLT_Ncoms;
319                //! default value for efss_coef;
320                static const double DFLT_effss_coef;
[384]321
322        public:
323                //!\name Constructors
324                //!@{
325                merger_mix () {};
[388]326                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
[384]327                //! Set sources and prepare all internal structures
[388]328                void set_sources (const Array<mpdf*> &S, bool own) {
329                        merger_base::set_sources (S,own);
[384]330                        Nsources = S.length();
331                }
332                //! Set internal parameters used in approximation
[399]333                void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
[384]334                        Ncoms = Ncoms0;
335                        effss_coef = effss_coef0;
336                }
337                //!@}
[388]338
[384]339                //! \name Mathematical operations
340                //!@{
[388]341
[384]342                //!Merge values using mixture approximation
343                void merge ();
344
345                //! sample from the approximating mixture
346                vec sample () const { return Mix.posterior().sample();}
347                //! loglikelihood computed on mixture models
348                double evallog (const vec &dt) const {
349                        vec dtf = ones (dt.length() + 1);
350                        dtf.set_subvector (0, dt);
351                        return Mix.logpred (dtf);
352                }
353                //!@}
354
355                //!\name Access functions
356                //!@{
[192]357//! Access function
[384]358                MixEF& _Mix() {return Mix;}
359                //! Access function
360                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
[388]361                //! from_settings
[395]362                void from_setting(const Setting& set){
[388]363                        merger_base::from_setting(set);
364                        set.lookupValue("ncoms",Ncoms);
[395]365                        set.lookupValue("effss_coef",effss_coef);
366                        set.lookupValue("stop_niter",stop_niter);
[388]367                }
[384]368               
[388]369                //! @}
370
[384]371};
[388]372UIREGISTER(merger_mix);
[176]373
[254]374}
[176]375
376#endif // MER_H
Note: See TracBrowser for help on using the browser.