root/library/bdm/stat/merger.h @ 388

Revision 388, 9.6 kB (checked in by smidl, 15 years ago)

mergers

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm
20{
21using std::string;
22
23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
25
26/*!
27@brief Base class for general combination of pdfs on discrete support
28
29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
30
31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
35
36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
40
41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
61                //!Prior on the log-normal merging model
62                double beta;
63
64                //! Projection to empirical density (could also be piece-wise linear)
65                eEmp eSmp;
66
67                //! debug or not debug
68                bool DBG;
69
70                //! debugging file
71                it_file* dbg_file;
72        public:
73                //! \name Constructors
74                //! @{
75
76                //!Empty constructor
77                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
78                //!Constructor from sources
79                merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
80                //! Function setting the main internal structures
81                void set_sources (const Array<mpdf*> &Sources, bool own) {
82                        compositepdf::set_elements (Sources,own);
83                        //set sizes
84                        dls.set_size (Sources.length());
85                        rvzs.set_size (Sources.length());
86                        zdls.set_size (Sources.length());
87
88                        rv = getrv (/* checkoverlap = */ false);
89                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
90                        // join rv and rvc - see descriprion
91                        rv.add (rvc);
92                        // get dimension
93                        dim = rv._dsize();
94
95                        // create links between sources and common rv
96                        RV xytmp;
97                        for (int i = 0;i < mpdfs.length();i++) {
98                                //Establich connection between mpdfs and merger
99                                dls (i) = new datalink_m2e;
100                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
101
102                                // find out what is missing in each mpdf
103                                xytmp = mpdfs (i)->_rv();
104                                xytmp.add (mpdfs (i)->_rvc());
105                                // z_i = common_rv-xy
106                                rvzs (i) = rv.subt (xytmp);
107                                //establish connection between extension (z_i|x,y)s and common rv
108                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
109                        };
110                }
111                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
112                void set_support (const Array<vec> &XYZ, const int dimsize) {
113                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
114                }
115                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
116                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
117                        int dim = XYZ.length();  //check with internal dim!!
118                        Npoints = prod (gridsize); 
119                        eSmp.set_parameters (Npoints, false);
120                        Array<vec> &samples = eSmp._samples();
121                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
122                        //set samples
123                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
124                        vec smpi (dim);            // ith sample
125                        vec steps =zeros(dim);            // ith sample
126                        // first corner
127                        for (int j = 0; j < dim; j++) {
128                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
129                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
130                                steps (j) = (smpi(j) - XYZ(j)(1))/gridsize(j);
131                        }
132                        // fill samples
133                        int act_dim=0; //active dimension
134                        for (int i = 0; i < Npoints; i++) {
135                                // copy
136                                samples(i) = smpi; 
137                                // go through all dimensions
138                                for (int j = 0;j < dim;j++) {
139                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
140                                                ind (j) = 0; //shift back
141                                                smpi(j) = XYZ(j)(0);
142                                               
143                                                ind (j + 1) ++; //increase the next dimension;
144                                                smpi(j+1) += steps(j+1);
145                                               
146                                                if (ind (j + 1) < gridsize (j + 1) - 1) break;
147                                        } else {
148                                                ind (j) ++; 
149                                                smpi(j) +=steps(j);
150                                                break;
151                                        }
152                                }
153                        }
154                }
155                //! set debug file
156                void set_debug_file (const string fname) {
157                        if (DBG) delete dbg_file;
158                        dbg_file = new it_file (fname);
159                        if (dbg_file) DBG = true;
160                }
161                //! Set internal parameters used in approximation
162                void set_method (MERGER_METHOD MTH, double beta0 = 0.0) {
163                        METHOD = MTH;
164                        beta = beta0;
165                }
166                //! Set support points from a pdf by drawing N samples
167                void set_support (const epdf &overall, int N) {
168                        eSmp.set_statistics (&overall, N);
169                        Npoints = N;
170                }
171
172                //! Destructor
173                virtual ~merger_base() {
174                        for (int i = 0; i < Nsources; i++) {
175                                delete dls (i);
176                                delete zdls (i);
177                        }
178                        if (DBG) delete dbg_file;
179                };
180                //!@}
181
182                //! \name Mathematical operations
183                //!@{
184
185                //!Merge given sources in given points
186                void merge () {
187                        validate();
188
189                        //check if sources overlap:
190                        bool OK = true;
191                        for (int i = 0;i < mpdfs.length(); i++) {
192                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
193                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
194                        }
195
196                        if (OK) {
197                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
198
199                                vec emptyvec (0);
200                                for (int i = 0; i < mpdfs.length(); i++) {
201                                        for (int j = 0; j < eSmp._w().length(); j++) {
202                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
203                                        }
204                                }
205
206                                vec wtmp = exp (merge_points (lW));
207                                //renormalize
208                                eSmp._w() = wtmp / sum (wtmp);
209                        } else {
210                                it_error ("Sources are not compatible - use merger_mix");
211                        }
212                };
213
214
215                //! Merge log-likelihood values in points using method specified by parameter METHOD
216                vec merge_points (mat &lW);
217
218
219                //! sample from merged density
220//! weight w is a
221                vec mean() const {
222                        const Vec<double> &w = eSmp._w();
223                        const Array<vec> &S = eSmp._samples();
224                        vec tmp = zeros (dim);
225                        for (int i = 0; i < Npoints; i++) {
226                                tmp += w (i) * S (i);
227                        }
228                        return tmp;
229                }
230                mat covariance() const {
231                        const vec &w = eSmp._w();
232                        const Array<vec> &S = eSmp._samples();
233
234                        vec mea = mean();
235
236                        cout << sum (w) << "," << w*w << endl;
237
238                        mat Tmp = zeros (dim, dim);
239                        for (int i = 0; i < Npoints; i++) {
240                                Tmp += w (i) * outer_product (S (i), S (i));
241                        }
242                        return Tmp -outer_product (mea, mea);
243                }
244                vec variance() const {
245                        const vec &w = eSmp._w();
246                        const Array<vec> &S = eSmp._samples();
247
248                        vec tmp = zeros (dim);
249                        for (int i = 0; i < Nsources; i++) {
250                                tmp += w (i) * pow (S (i), 2);
251                        }
252                        return tmp -pow (mean(), 2);
253                }
254                //!@}
255
256                //! \name Access to attributes
257                //! @{
258
259                //! Access function
260                eEmp& _Smp() {return eSmp;}
261
262                //! load from setting
263                void from_setting (const Setting& set) {
264                        // get support
265                        // find which method to use
266                        string meth_str;
267                        UI::get<string> (meth_str, set, "method");
268                        if (!strcmp (meth_str.c_str(), "arithmetic"))
269                                set_method (ARITHMETIC);
270                        else {
271                                if (!strcmp (meth_str.c_str(), "geometric"))
272                                        set_method (GEOMETRIC);
273                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
274                                        set_method (GEOMETRIC);
275                                        set.lookupValue( "beta",beta);
276                                }
277                        }
278                        validate();
279                }
280
281                void validate() {
282                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
283                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
284                        it_assert (isnamed(),"mergers must be named");
285                }
286                //!@}
287};
288UIREGISTER(merger_base);
289
290class merger_mix : public merger_base
291{
292        protected:
293                //!Internal mixture of EF models
294                MixEF Mix;
295                //!Number of components in a mixture
296                int Ncoms;
297                //! coefficient of resampling
298                double effss_coef;
299
300        public:
301                //!\name Constructors
302                //!@{
303                merger_mix () {};
304                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
305                //! Set sources and prepare all internal structures
306                void set_sources (const Array<mpdf*> &S, bool own) {
307                        merger_base::set_sources (S,own);
308                        Nsources = S.length();
309                }
310                //! Set internal parameters used in approximation
311                void set_parameters (int Ncoms0 = 10, double effss_coef0 = 0.5) {
312                        Ncoms = Ncoms0;
313                        effss_coef = effss_coef0;
314                }
315                //!@}
316
317                //! \name Mathematical operations
318                //!@{
319
320                //!Merge values using mixture approximation
321                void merge ();
322
323                //! sample from the approximating mixture
324                vec sample () const { return Mix.posterior().sample();}
325                //! loglikelihood computed on mixture models
326                double evallog (const vec &dt) const {
327                        vec dtf = ones (dt.length() + 1);
328                        dtf.set_subvector (0, dt);
329                        return Mix.logpred (dtf);
330                }
331                //!@}
332
333                //!\name Access functions
334                //!@{
335//! Access function
336                MixEF& _Mix() {return Mix;}
337                //! Access function
338                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
339                //! from_settings
340                void from_settings(const Setting& set){
341                        merger_base::from_setting(set);
342                        set.lookupValue("ncoms",Ncoms);
343                }
344               
345                //! @}
346
347};
348UIREGISTER(merger_mix);
349
350}
351
352#endif // MER_H
Note: See TracBrowser for help on using the browser.