root/library/bdm/stat/merger.h @ 392

Revision 392, 9.6 kB (checked in by smidl, 15 years ago)

compilation fixes - UI_build use exceptions now!!!!

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm
20{
21using std::string;
22
23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
25
26/*!
27@brief Base class for general combination of pdfs on discrete support
28
29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
30
31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
35
36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
40
41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
61                //!Prior on the log-normal merging model
62                double beta;
63
64                //! Projection to empirical density (could also be piece-wise linear)
65                eEmp eSmp;
66
67                //! debug or not debug
68                bool DBG;
69
70                //! debugging file
71                it_file* dbg_file;
72        public:
73                //! \name Constructors
74                //! @{
75
76                //!Empty constructor
77                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
78                //!Constructor from sources
79                merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
80                //! Function setting the main internal structures
81                void set_sources (const Array<mpdf*> &Sources, bool own) {
82                        compositepdf::set_elements (Sources,own);
83                        Nsources=mpdfs.length();
84                        //set sizes
85                        dls.set_size (Sources.length());
86                        rvzs.set_size (Sources.length());
87                        zdls.set_size (Sources.length());
88
89                        rv = getrv (/* checkoverlap = */ false);
90                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
91                        // join rv and rvc - see descriprion
92                        rv.add (rvc);
93                        // get dimension
94                        dim = rv._dsize();
95
96                        // create links between sources and common rv
97                        RV xytmp;
98                        for (int i = 0;i < mpdfs.length();i++) {
99                                //Establich connection between mpdfs and merger
100                                dls (i) = new datalink_m2e;
101                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
102
103                                // find out what is missing in each mpdf
104                                xytmp = mpdfs (i)->_rv();
105                                xytmp.add (mpdfs (i)->_rvc());
106                                // z_i = common_rv-xy
107                                rvzs (i) = rv.subt (xytmp);
108                                //establish connection between extension (z_i|x,y)s and common rv
109                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
110                        };
111                }
112                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
113                void set_support (const Array<vec> &XYZ, const int dimsize) {
114                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
115                }
116                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
117                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
118                        int dim = XYZ.length();  //check with internal dim!!
119                        Npoints = prod (gridsize); 
120                        eSmp.set_parameters (Npoints, false);
121                        Array<vec> &samples = eSmp._samples();
122                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
123                        //set samples
124                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
125                        vec smpi (dim);            // ith sample
126                        vec steps =zeros(dim);            // ith sample
127                        // first corner
128                        for (int j = 0; j < dim; j++) {
129                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
130                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
131                                steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
132                        }
133                        // fill samples
134                        int act_dim=0; //active dimension
135                        for (int i = 0; i < Npoints; i++) {
136                                // copy
137                                samples(i) = smpi; 
138                                // go through all dimensions
139                                for (int j = 0;j < dim;j++) {
140                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
141//                                              ind (j) = 0; //shift back
142                                                smpi(j) = XYZ(j)(0);
143                                               
144//                                              ind (j + 1) ++; //increase the next dimension;
145                                                smpi(j+1) += steps(j+1);
146                                               
147                                                if (ind (j + 1) < gridsize (j + 1) - 1) break;
148                                        } else {
149//                                              ind (j) ++;
150                                                smpi(j) +=steps(j);
151                                                break;
152                                        }
153                                }
154                        }
155                }
156                //! set debug file
157                void set_debug_file (const string fname) {
158                        if (DBG) delete dbg_file;
159                        dbg_file = new it_file (fname);
160                        if (dbg_file) DBG = true;
161                }
162                //! Set internal parameters used in approximation
163                void set_method (MERGER_METHOD MTH, double beta0 = 0.0) {
164                        METHOD = MTH;
165                        beta = beta0;
166                }
167                //! Set support points from a pdf by drawing N samples
168                void set_support (const epdf &overall, int N) {
169                        eSmp.set_statistics (&overall, N);
170                        Npoints = N;
171                }
172
173                //! Destructor
174                virtual ~merger_base() {
175                        for (int i = 0; i < Nsources; i++) {
176                                delete dls (i);
177                                delete zdls (i);
178                        }
179                        if (DBG) delete dbg_file;
180                };
181                //!@}
182
183                //! \name Mathematical operations
184                //!@{
185
186                //!Merge given sources in given points
187                void merge () {
188                        validate();
189
190                        //check if sources overlap:
191                        bool OK = true;
192                        for (int i = 0;i < mpdfs.length(); i++) {
193                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
194                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
195                        }
196
197                        if (OK) {
198                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
199
200                                vec emptyvec (0);
201                                for (int i = 0; i < mpdfs.length(); i++) {
202                                        for (int j = 0; j < eSmp._w().length(); j++) {
203                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
204                                        }
205                                }
206
207                                vec wtmp = exp (merge_points (lW));
208                                //renormalize
209                                eSmp._w() = wtmp / sum (wtmp);
210                        } else {
211                                it_error ("Sources are not compatible - use merger_mix");
212                        }
213                };
214
215
216                //! Merge log-likelihood values in points using method specified by parameter METHOD
217                vec merge_points (mat &lW);
218
219
220                //! sample from merged density
221//! weight w is a
222                vec mean() const {
223                        const Vec<double> &w = eSmp._w();
224                        const Array<vec> &S = eSmp._samples();
225                        vec tmp = zeros (dim);
226                        for (int i = 0; i < Npoints; i++) {
227                                tmp += w (i) * S (i);
228                        }
229                        return tmp;
230                }
231                mat covariance() const {
232                        const vec &w = eSmp._w();
233                        const Array<vec> &S = eSmp._samples();
234
235                        vec mea = mean();
236
237                        cout << sum (w) << "," << w*w << endl;
238
239                        mat Tmp = zeros (dim, dim);
240                        for (int i = 0; i < Npoints; i++) {
241                                Tmp += w (i) * outer_product (S (i), S (i));
242                        }
243                        return Tmp -outer_product (mea, mea);
244                }
245                vec variance() const {
246                        const vec &w = eSmp._w();
247                        const Array<vec> &S = eSmp._samples();
248
249                        vec tmp = zeros (dim);
250                        for (int i = 0; i < Nsources; i++) {
251                                tmp += w (i) * pow (S (i), 2);
252                        }
253                        return tmp -pow (mean(), 2);
254                }
255                //!@}
256
257                //! \name Access to attributes
258                //! @{
259
260                //! Access function
261                eEmp& _Smp() {return eSmp;}
262
263                //! load from setting
264                void from_setting (const Setting& set) {
265                        // get support
266                        // find which method to use
267                        string meth_str;
268                        UI::get<string> (meth_str, set, "method");
269                        if (!strcmp (meth_str.c_str(), "arithmetic"))
270                                set_method (ARITHMETIC);
271                        else {
272                                if (!strcmp (meth_str.c_str(), "geometric"))
273                                        set_method (GEOMETRIC);
274                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
275                                        set_method (LOGNORMAL);
276                                        set.lookupValue( "beta",beta);
277                                }
278                        }
279                }
280
281                void validate() {
282                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
283                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
284                        it_assert (isnamed(),"mergers must be named");
285                }
286                //!@}
287};
288UIREGISTER(merger_base);
289
290class merger_mix : public merger_base
291{
292        protected:
293                //!Internal mixture of EF models
294                MixEF Mix;
295                //!Number of components in a mixture
296                int Ncoms;
297                //! coefficient of resampling
298                double effss_coef;
299
300        public:
301                //!\name Constructors
302                //!@{
303                merger_mix () {};
304                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
305                //! Set sources and prepare all internal structures
306                void set_sources (const Array<mpdf*> &S, bool own) {
307                        merger_base::set_sources (S,own);
308                        Nsources = S.length();
309                }
310                //! Set internal parameters used in approximation
311                void set_parameters (int Ncoms0 = 10, double effss_coef0 = 0.5) {
312                        Ncoms = Ncoms0;
313                        effss_coef = effss_coef0;
314                }
315                //!@}
316
317                //! \name Mathematical operations
318                //!@{
319
320                //!Merge values using mixture approximation
321                void merge ();
322
323                //! sample from the approximating mixture
324                vec sample () const { return Mix.posterior().sample();}
325                //! loglikelihood computed on mixture models
326                double evallog (const vec &dt) const {
327                        vec dtf = ones (dt.length() + 1);
328                        dtf.set_subvector (0, dt);
329                        return Mix.logpred (dtf);
330                }
331                //!@}
332
333                //!\name Access functions
334                //!@{
335//! Access function
336                MixEF& _Mix() {return Mix;}
337                //! Access function
338                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
339                //! from_settings
340                void from_settings(const Setting& set){
341                        merger_base::from_setting(set);
342                        set.lookupValue("ncoms",Ncoms);
343                }
344               
345                //! @}
346
347};
348UIREGISTER(merger_mix);
349
350}
351
352#endif // MER_H
Note: See TracBrowser for help on using the browser.