root/library/bdm/stat/merger.h @ 423

Revision 423, 10.1 kB (checked in by vbarta, 15 years ago)

fixed merger_base constructor to initialize debug fields (still not all fields, though...)

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm
20{
21using std::string;
22
23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
25
26/*!
27@brief Base class for general combination of pdfs on discrete support
28
29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
30
31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
35
36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
40
41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
61                //! Default for METHOD
62                static const MERGER_METHOD DFLT_METHOD;
63               
64                //!Prior on the log-normal merging model
65                double beta;
66                //! default for beta
67                static const double DFLT_beta;
68               
69                //! Projection to empirical density (could also be piece-wise linear)
70                eEmp eSmp;
71
72                //! debug or not debug
73                bool DBG;
74
75                //! debugging file
76                it_file* dbg_file;
77        public:
78                //! \name Constructors
79                //! @{
80
81                //!Empty constructor
82                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;}
83
84                //!Constructor from sources
85                merger_base (const Array<mpdf*> &S, bool own=false);
86
87                //! Function setting the main internal structures
88                void set_sources (const Array<mpdf*> &Sources, bool own) {
89                        compositepdf::set_elements (Sources,own);
90                        Nsources=mpdfs.length();
91                        //set sizes
92                        dls.set_size (Sources.length());
93                        rvzs.set_size (Sources.length());
94                        zdls.set_size (Sources.length());
95
96                        rv = getrv (/* checkoverlap = */ false);
97                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
98                        // join rv and rvc - see descriprion
99                        rv.add (rvc);
100                        // get dimension
101                        dim = rv._dsize();
102
103                        // create links between sources and common rv
104                        RV xytmp;
105                        for (int i = 0;i < mpdfs.length();i++) {
106                                //Establich connection between mpdfs and merger
107                                dls (i) = new datalink_m2e;
108                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
109
110                                // find out what is missing in each mpdf
111                                xytmp = mpdfs (i)->_rv();
112                                xytmp.add (mpdfs (i)->_rvc());
113                                // z_i = common_rv-xy
114                                rvzs (i) = rv.subt (xytmp);
115                                //establish connection between extension (z_i|x,y)s and common rv
116                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
117                        };
118                }
119                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
120                void set_support (const Array<vec> &XYZ, const int dimsize) {
121                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
122                }
123                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
124                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
125                        int dim = XYZ.length();  //check with internal dim!!
126                        Npoints = prod (gridsize); 
127                        eSmp.set_parameters (Npoints, false);
128                        Array<vec> &samples = eSmp._samples();
129                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
130                        //set samples
131                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
132                        vec smpi (dim);            // ith sample
133                        vec steps =zeros(dim);            // ith sample
134                        // first corner
135                        for (int j = 0; j < dim; j++) {
136                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
137                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
138                                steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
139                        }
140                        // fill samples
141                        for (int i = 0; i < Npoints; i++) {
142                                // copy
143                                samples(i) = smpi; 
144                                // go through all dimensions
145                                for (int j = 0;j < dim;j++) {
146                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
147                                                ind (j) = 0; //shift back
148                                                smpi(j) = XYZ(j)(0);
149                                               
150                                                if (i<Npoints-1) {
151                                                        ind (j + 1) ++; //increase the next dimension;
152                                                        smpi(j+1) += steps(j+1);
153                                                        break;
154                                                }
155                                               
156                                        } else {
157                                                ind (j) ++; 
158                                                smpi(j) +=steps(j);
159                                                break;
160                                        }
161                                }
162                        }
163                }
164                //! set debug file
165                void set_debug_file (const string fname) {
166                        if (DBG) delete dbg_file;
167                        dbg_file = new it_file (fname);
168                        if (dbg_file) DBG = true;
169                }
170                //! Set internal parameters used in approximation
171                void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
172                        METHOD = MTH;
173                        beta = beta0;
174                }
175                //! Set support points from a pdf by drawing N samples
176                void set_support (const epdf &overall, int N) {
177                        eSmp.set_statistics (&overall, N);
178                        Npoints = N;
179                }
180
181                //! Destructor
182                virtual ~merger_base() {
183                        for (int i = 0; i < Nsources; i++) {
184                                delete dls (i);
185                                delete zdls (i);
186                        }
187                        if (DBG) delete dbg_file;
188                };
189                //!@}
190
191                //! \name Mathematical operations
192                //!@{
193
194                //!Merge given sources in given points
195                virtual void merge () {
196                        validate();
197
198                        //check if sources overlap:
199                        bool OK = true;
200                        for (int i = 0;i < mpdfs.length(); i++) {
201                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
202                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
203                        }
204
205                        if (OK) {
206                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
207
208                                vec emptyvec (0);
209                                for (int i = 0; i < mpdfs.length(); i++) {
210                                        for (int j = 0; j < eSmp._w().length(); j++) {
211                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
212                                        }
213                                }
214
215                                vec w_nn=merge_points (lW);
216                                vec wtmp = exp (w_nn-max(w_nn));
217                                //renormalize
218                                eSmp._w() = wtmp / sum (wtmp);
219                        } else {
220                                it_error ("Sources are not compatible - use merger_mix");
221                        }
222                };
223
224
225                //! Merge log-likelihood values in points using method specified by parameter METHOD
226                vec merge_points (mat &lW);
227
228
229                //! sample from merged density
230//! weight w is a
231                vec mean() const {
232                        const Vec<double> &w = eSmp._w();
233                        const Array<vec> &S = eSmp._samples();
234                        vec tmp = zeros (dim);
235                        for (int i = 0; i < Npoints; i++) {
236                                tmp += w (i) * S (i);
237                        }
238                        return tmp;
239                }
240                mat covariance() const {
241                        const vec &w = eSmp._w();
242                        const Array<vec> &S = eSmp._samples();
243
244                        vec mea = mean();
245
246//                      cout << sum (w) << "," << w*w << endl;
247
248                        mat Tmp = zeros (dim, dim);
249                        for (int i = 0; i < Npoints; i++) {
250                                Tmp += w (i) * outer_product (S (i), S (i));
251                        }
252                        return Tmp -outer_product (mea, mea);
253                }
254                vec variance() const {
255                        const vec &w = eSmp._w();
256                        const Array<vec> &S = eSmp._samples();
257
258                        vec tmp = zeros (dim);
259                        for (int i = 0; i < Nsources; i++) {
260                                tmp += w (i) * pow (S (i), 2);
261                        }
262                        return tmp -pow (mean(), 2);
263                }
264                //!@}
265
266                //! \name Access to attributes
267                //! @{
268
269                //! Access function
270                eEmp& _Smp() {return eSmp;}
271
272                //! load from setting
273                void from_setting (const Setting& set) {
274                        // get support
275                        // find which method to use
276                        string meth_str;
277                        UI::get<string> (meth_str, set, "method");
278                        if (!strcmp (meth_str.c_str(), "arithmetic"))
279                                set_method (ARITHMETIC);
280                        else {
281                                if (!strcmp (meth_str.c_str(), "geometric"))
282                                        set_method (GEOMETRIC);
283                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
284                                        set_method (LOGNORMAL);
285                                        set.lookupValue( "beta",beta);
286                                }
287                        }
288                        if (set.exists("dbg_file")){ 
289                                string dbg_file;
290                                UI::get<string> (dbg_file, set, "dbg_file");
291                                set_debug_file(dbg_file);
292                        }
293                        //validate() - not used
294                }
295
296                void validate() {
297                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
298                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
299                        it_assert (isnamed(),"mergers must be named");
300                }
301                //!@}
302};
303UIREGISTER(merger_base);
304
305class merger_mix : public merger_base
306{
307        protected:
308                //!Internal mixture of EF models
309                MixEF Mix;
310                //!Number of components in a mixture
311                int Ncoms;
312                //! coefficient of resampling [0,1]
313                double effss_coef;
314                //! stop after niter iterations
315                int stop_niter;
316               
317                //! default value for Ncoms
318                static const int DFLT_Ncoms;
319                //! default value for efss_coef;
320                static const double DFLT_effss_coef;
321
322        public:
323                //!\name Constructors
324                //!@{
325                merger_mix () {};
326                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
327                //! Set sources and prepare all internal structures
328                void set_sources (const Array<mpdf*> &S, bool own) {
329                        merger_base::set_sources (S,own);
330                        Nsources = S.length();
331                }
332                //! Set internal parameters used in approximation
333                void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
334                        Ncoms = Ncoms0;
335                        effss_coef = effss_coef0;
336                }
337                //!@}
338
339                //! \name Mathematical operations
340                //!@{
341
342                //!Merge values using mixture approximation
343                void merge ();
344
345                //! sample from the approximating mixture
346                vec sample () const { return Mix.posterior().sample();}
347                //! loglikelihood computed on mixture models
348                double evallog (const vec &dt) const {
349                        vec dtf = ones (dt.length() + 1);
350                        dtf.set_subvector (0, dt);
351                        return Mix.logpred (dtf);
352                }
353                //!@}
354
355                //!\name Access functions
356                //!@{
357//! Access function
358                MixEF& _Mix() {return Mix;}
359                //! Access function
360                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
361                //! from_settings
362                void from_setting(const Setting& set){
363                        merger_base::from_setting(set);
364                        set.lookupValue("ncoms",Ncoms);
365                        set.lookupValue("effss_coef",effss_coef);
366                        set.lookupValue("stop_niter",stop_niter);
367                }
368               
369                //! @}
370
371};
372UIREGISTER(merger_mix);
373
374}
375
376#endif // MER_H
Note: See TracBrowser for help on using the browser.