root/library/bdm/stat/merger.h @ 404

Revision 404, 10.2 kB (checked in by smidl, 15 years ago)

Change in epdf: evallog returns -inf for points out of support. Merger is aware of it now.

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm
20{
21using std::string;
22
23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
25
26/*!
27@brief Base class for general combination of pdfs on discrete support
28
29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
30
31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
35
36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
40
41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
61                //! Default for METHOD
62                static const MERGER_METHOD DFLT_METHOD;
63               
64                //!Prior on the log-normal merging model
65                double beta;
66                //! default for beta
67                static const double DFLT_beta;
68               
69                //! Projection to empirical density (could also be piece-wise linear)
70                eEmp eSmp;
71
72                //! debug or not debug
73                bool DBG;
74
75                //! debugging file
76                it_file* dbg_file;
77        public:
78                //! \name Constructors
79                //! @{
80
81                //!Empty constructor
82                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
83                //!Constructor from sources
84                merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
85                //! Function setting the main internal structures
86                void set_sources (const Array<mpdf*> &Sources, bool own) {
87                        compositepdf::set_elements (Sources,own);
88                        Nsources=mpdfs.length();
89                        //set sizes
90                        dls.set_size (Sources.length());
91                        rvzs.set_size (Sources.length());
92                        zdls.set_size (Sources.length());
93
94                        rv = getrv (/* checkoverlap = */ false);
95                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
96                        // join rv and rvc - see descriprion
97                        rv.add (rvc);
98                        // get dimension
99                        dim = rv._dsize();
100
101                        // create links between sources and common rv
102                        RV xytmp;
103                        for (int i = 0;i < mpdfs.length();i++) {
104                                //Establich connection between mpdfs and merger
105                                dls (i) = new datalink_m2e;
106                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
107
108                                // find out what is missing in each mpdf
109                                xytmp = mpdfs (i)->_rv();
110                                xytmp.add (mpdfs (i)->_rvc());
111                                // z_i = common_rv-xy
112                                rvzs (i) = rv.subt (xytmp);
113                                //establish connection between extension (z_i|x,y)s and common rv
114                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
115                        };
116                }
117                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
118                void set_support (const Array<vec> &XYZ, const int dimsize) {
119                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
120                }
121                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
122                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
123                        int dim = XYZ.length();  //check with internal dim!!
124                        Npoints = prod (gridsize); 
125                        eSmp.set_parameters (Npoints, false);
126                        Array<vec> &samples = eSmp._samples();
127                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
128                        //set samples
129                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
130                        vec smpi (dim);            // ith sample
131                        vec steps =zeros(dim);            // ith sample
132                        // first corner
133                        for (int j = 0; j < dim; j++) {
134                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
135                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
136                                steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
137                        }
138                        // fill samples
139                        for (int i = 0; i < Npoints; i++) {
140                                // copy
141                                samples(i) = smpi; 
142                                // go through all dimensions
143                                for (int j = 0;j < dim;j++) {
144                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
145                                                ind (j) = 0; //shift back
146                                                smpi(j) = XYZ(j)(0);
147                                               
148                                                if (i<Npoints-1) {
149                                                        ind (j + 1) ++; //increase the next dimension;
150                                                        smpi(j+1) += steps(j+1);
151                                                        break;
152                                                }
153                                               
154                                        } else {
155                                                ind (j) ++; 
156                                                smpi(j) +=steps(j);
157                                                break;
158                                        }
159                                }
160                        }
161                }
162                //! set debug file
163                void set_debug_file (const string fname) {
164                        if (DBG) delete dbg_file;
165                        dbg_file = new it_file (fname);
166                        if (dbg_file) DBG = true;
167                }
168                //! Set internal parameters used in approximation
169                void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
170                        METHOD = MTH;
171                        beta = beta0;
172                }
173                //! Set support points from a pdf by drawing N samples
174                void set_support (const epdf &overall, int N) {
175                        eSmp.set_statistics (&overall, N);
176                        Npoints = N;
177                }
178
179                //! Destructor
180                virtual ~merger_base() {
181                        for (int i = 0; i < Nsources; i++) {
182                                delete dls (i);
183                                delete zdls (i);
184                        }
185                        if (DBG) delete dbg_file;
186                };
187                //!@}
188
189                //! \name Mathematical operations
190                //!@{
191
192                //!Merge given sources in given points
193                virtual void merge () {
194                        validate();
195
196                        //check if sources overlap:
197                        bool OK = true;
198                        for (int i = 0;i < mpdfs.length(); i++) {
199                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
200                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
201                        }
202
203                        if (OK) {
204                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
205
206                                vec emptyvec (0);
207                                for (int i = 0; i < mpdfs.length(); i++) {
208                                        for (int j = 0; j < eSmp._w().length(); j++) {
209                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
210                                        }
211                                }
212
213                                vec w_nn=merge_points (lW);
214                                vec wtmp = exp (w_nn-max(w_nn));
215                                //renormalize
216                                eSmp._w() = wtmp / sum (wtmp);
217                        } else {
218                                it_error ("Sources are not compatible - use merger_mix");
219                        }
220                };
221
222
223                //! Merge log-likelihood values in points using method specified by parameter METHOD
224                vec merge_points (mat &lW);
225
226
227                //! sample from merged density
228//! weight w is a
229                vec mean() const {
230                        const Vec<double> &w = eSmp._w();
231                        const Array<vec> &S = eSmp._samples();
232                        vec tmp = zeros (dim);
233                        for (int i = 0; i < Npoints; i++) {
234                                tmp += w (i) * S (i);
235                        }
236                        return tmp;
237                }
238                mat covariance() const {
239                        const vec &w = eSmp._w();
240                        const Array<vec> &S = eSmp._samples();
241
242                        vec mea = mean();
243
244//                      cout << sum (w) << "," << w*w << endl;
245
246                        mat Tmp = zeros (dim, dim);
247                        for (int i = 0; i < Npoints; i++) {
248                                Tmp += w (i) * outer_product (S (i), S (i));
249                        }
250                        return Tmp -outer_product (mea, mea);
251                }
252                vec variance() const {
253                        const vec &w = eSmp._w();
254                        const Array<vec> &S = eSmp._samples();
255
256                        vec tmp = zeros (dim);
257                        for (int i = 0; i < Nsources; i++) {
258                                tmp += w (i) * pow (S (i), 2);
259                        }
260                        return tmp -pow (mean(), 2);
261                }
262                //!@}
263
264                //! \name Access to attributes
265                //! @{
266
267                //! Access function
268                eEmp& _Smp() {return eSmp;}
269
270                //! load from setting
271                void from_setting (const Setting& set) {
272                        // get support
273                        // find which method to use
274                        string meth_str;
275                        UI::get<string> (meth_str, set, "method");
276                        if (!strcmp (meth_str.c_str(), "arithmetic"))
277                                set_method (ARITHMETIC);
278                        else {
279                                if (!strcmp (meth_str.c_str(), "geometric"))
280                                        set_method (GEOMETRIC);
281                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
282                                        set_method (LOGNORMAL);
283                                        set.lookupValue( "beta",beta);
284                                }
285                        }
286                        if (set.exists("dbg_file")){ 
287                                string dbg_file;
288                                UI::get<string> (dbg_file, set, "dbg_file");
289                                set_debug_file(dbg_file);
290                        }
291                        //validate() - not used
292                }
293
294                void validate() {
295                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
296                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
297                        it_assert (isnamed(),"mergers must be named");
298                }
299                //!@}
300};
301UIREGISTER(merger_base);
302
303class merger_mix : public merger_base
304{
305        protected:
306                //!Internal mixture of EF models
307                MixEF Mix;
308                //!Number of components in a mixture
309                int Ncoms;
310                //! coefficient of resampling [0,1]
311                double effss_coef;
312                //! stop after niter iterations
313                int stop_niter;
314               
315                //! default value for Ncoms
316                static const int DFLT_Ncoms;
317                //! default value for efss_coef;
318                static const double DFLT_effss_coef;
319
320        public:
321                //!\name Constructors
322                //!@{
323                merger_mix () {};
324                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
325                //! Set sources and prepare all internal structures
326                void set_sources (const Array<mpdf*> &S, bool own) {
327                        merger_base::set_sources (S,own);
328                        Nsources = S.length();
329                }
330                //! Set internal parameters used in approximation
331                void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
332                        Ncoms = Ncoms0;
333                        effss_coef = effss_coef0;
334                }
335                //!@}
336
337                //! \name Mathematical operations
338                //!@{
339
340                //!Merge values using mixture approximation
341                void merge ();
342
343                //! sample from the approximating mixture
344                vec sample () const { return Mix.posterior().sample();}
345                //! loglikelihood computed on mixture models
346                double evallog (const vec &dt) const {
347                        vec dtf = ones (dt.length() + 1);
348                        dtf.set_subvector (0, dt);
349                        return Mix.logpred (dtf);
350                }
351                //!@}
352
353                //!\name Access functions
354                //!@{
355//! Access function
356                MixEF& _Mix() {return Mix;}
357                //! Access function
358                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
359                //! from_settings
360                void from_setting(const Setting& set){
361                        merger_base::from_setting(set);
362                        set.lookupValue("ncoms",Ncoms);
363                        set.lookupValue("effss_coef",effss_coef);
364                        set.lookupValue("stop_niter",stop_niter);
365                }
366               
367                //! @}
368
369};
370UIREGISTER(merger_mix);
371
372}
373
374#endif // MER_H
Note: See TracBrowser for help on using the browser.