root/library/bdm/stat/merger.h @ 395

Revision 395, 9.9 kB (checked in by smidl, 15 years ago)

merging works for merger_mx

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm
20{
21using std::string;
22
23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
25
26/*!
27@brief Base class for general combination of pdfs on discrete support
28
29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
30
31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
35
36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
40
41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
61                //!Prior on the log-normal merging model
62                double beta;
63
64                //! Projection to empirical density (could also be piece-wise linear)
65                eEmp eSmp;
66
67                //! debug or not debug
68                bool DBG;
69
70                //! debugging file
71                it_file* dbg_file;
72        public:
73                //! \name Constructors
74                //! @{
75
76                //!Empty constructor
77                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
78                //!Constructor from sources
79                merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
80                //! Function setting the main internal structures
81                void set_sources (const Array<mpdf*> &Sources, bool own) {
82                        compositepdf::set_elements (Sources,own);
83                        Nsources=mpdfs.length();
84                        //set sizes
85                        dls.set_size (Sources.length());
86                        rvzs.set_size (Sources.length());
87                        zdls.set_size (Sources.length());
88
89                        rv = getrv (/* checkoverlap = */ false);
90                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
91                        // join rv and rvc - see descriprion
92                        rv.add (rvc);
93                        // get dimension
94                        dim = rv._dsize();
95
96                        // create links between sources and common rv
97                        RV xytmp;
98                        for (int i = 0;i < mpdfs.length();i++) {
99                                //Establich connection between mpdfs and merger
100                                dls (i) = new datalink_m2e;
101                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
102
103                                // find out what is missing in each mpdf
104                                xytmp = mpdfs (i)->_rv();
105                                xytmp.add (mpdfs (i)->_rvc());
106                                // z_i = common_rv-xy
107                                rvzs (i) = rv.subt (xytmp);
108                                //establish connection between extension (z_i|x,y)s and common rv
109                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
110                        };
111                }
112                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
113                void set_support (const Array<vec> &XYZ, const int dimsize) {
114                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
115                }
116                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
117                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
118                        int dim = XYZ.length();  //check with internal dim!!
119                        Npoints = prod (gridsize); 
120                        eSmp.set_parameters (Npoints, false);
121                        Array<vec> &samples = eSmp._samples();
122                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
123                        //set samples
124                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
125                        vec smpi (dim);            // ith sample
126                        vec steps =zeros(dim);            // ith sample
127                        // first corner
128                        for (int j = 0; j < dim; j++) {
129                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
130                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
131                                steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
132                        }
133                        // fill samples
134                        for (int i = 0; i < Npoints; i++) {
135                                // copy
136                                samples(i) = smpi; 
137                                // go through all dimensions
138                                for (int j = 0;j < dim;j++) {
139                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
140                                                ind (j) = 0; //shift back
141                                                smpi(j) = XYZ(j)(0);
142                                               
143                                                if (i<Npoints-1) {
144                                                        ind (j + 1) ++; //increase the next dimension;
145                                                        smpi(j+1) += steps(j+1);
146                                                        break;
147                                                }
148                                               
149                                        } else {
150                                                ind (j) ++; 
151                                                smpi(j) +=steps(j);
152                                                break;
153                                        }
154                                }
155                        }
156                }
157                //! set debug file
158                void set_debug_file (const string fname) {
159                        if (DBG) delete dbg_file;
160                        dbg_file = new it_file (fname);
161                        if (dbg_file) DBG = true;
162                }
163                //! Set internal parameters used in approximation
164                void set_method (MERGER_METHOD MTH, double beta0 = 0.0) {
165                        METHOD = MTH;
166                        beta = beta0;
167                }
168                //! Set support points from a pdf by drawing N samples
169                void set_support (const epdf &overall, int N) {
170                        eSmp.set_statistics (&overall, N);
171                        Npoints = N;
172                }
173
174                //! Destructor
175                virtual ~merger_base() {
176                        for (int i = 0; i < Nsources; i++) {
177                                delete dls (i);
178                                delete zdls (i);
179                        }
180                        if (DBG) delete dbg_file;
181                };
182                //!@}
183
184                //! \name Mathematical operations
185                //!@{
186
187                //!Merge given sources in given points
188                virtual void merge () {
189                        validate();
190
191                        //check if sources overlap:
192                        bool OK = true;
193                        for (int i = 0;i < mpdfs.length(); i++) {
194                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
195                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
196                        }
197
198                        if (OK) {
199                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
200
201                                vec emptyvec (0);
202                                for (int i = 0; i < mpdfs.length(); i++) {
203                                        for (int j = 0; j < eSmp._w().length(); j++) {
204                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
205                                        }
206                                }
207
208                                vec w_nn=merge_points (lW);
209                                vec wtmp = exp (w_nn-max(w_nn));
210                                //renormalize
211                                eSmp._w() = wtmp / sum (wtmp);
212                        } else {
213                                it_error ("Sources are not compatible - use merger_mix");
214                        }
215                };
216
217
218                //! Merge log-likelihood values in points using method specified by parameter METHOD
219                vec merge_points (mat &lW);
220
221
222                //! sample from merged density
223//! weight w is a
224                vec mean() const {
225                        const Vec<double> &w = eSmp._w();
226                        const Array<vec> &S = eSmp._samples();
227                        vec tmp = zeros (dim);
228                        for (int i = 0; i < Npoints; i++) {
229                                tmp += w (i) * S (i);
230                        }
231                        return tmp;
232                }
233                mat covariance() const {
234                        const vec &w = eSmp._w();
235                        const Array<vec> &S = eSmp._samples();
236
237                        vec mea = mean();
238
239                        cout << sum (w) << "," << w*w << endl;
240
241                        mat Tmp = zeros (dim, dim);
242                        for (int i = 0; i < Npoints; i++) {
243                                Tmp += w (i) * outer_product (S (i), S (i));
244                        }
245                        return Tmp -outer_product (mea, mea);
246                }
247                vec variance() const {
248                        const vec &w = eSmp._w();
249                        const Array<vec> &S = eSmp._samples();
250
251                        vec tmp = zeros (dim);
252                        for (int i = 0; i < Nsources; i++) {
253                                tmp += w (i) * pow (S (i), 2);
254                        }
255                        return tmp -pow (mean(), 2);
256                }
257                //!@}
258
259                //! \name Access to attributes
260                //! @{
261
262                //! Access function
263                eEmp& _Smp() {return eSmp;}
264
265                //! load from setting
266                void from_setting (const Setting& set) {
267                        // get support
268                        // find which method to use
269                        string meth_str;
270                        UI::get<string> (meth_str, set, "method");
271                        if (!strcmp (meth_str.c_str(), "arithmetic"))
272                                set_method (ARITHMETIC);
273                        else {
274                                if (!strcmp (meth_str.c_str(), "geometric"))
275                                        set_method (GEOMETRIC);
276                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
277                                        set_method (LOGNORMAL);
278                                        set.lookupValue( "beta",beta);
279                                }
280                        }
281                        if (set.exists("dbg_file")){ 
282                                string dbg_file;
283                                UI::get<string> (dbg_file, set, "dbg_file");
284                                set_debug_file(dbg_file);
285                        }
286                        //validate() - not used
287                }
288
289                void validate() {
290                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
291                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
292                        it_assert (isnamed(),"mergers must be named");
293                }
294                //!@}
295};
296UIREGISTER(merger_base);
297
298class merger_mix : public merger_base
299{
300        protected:
301                //!Internal mixture of EF models
302                MixEF Mix;
303                //!Number of components in a mixture
304                int Ncoms;
305                //! coefficient of resampling [0,1]
306                double effss_coef;
307                //! stop after niter iterations
308                int stop_niter;
309
310        public:
311                //!\name Constructors
312                //!@{
313                merger_mix () {};
314                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
315                //! Set sources and prepare all internal structures
316                void set_sources (const Array<mpdf*> &S, bool own) {
317                        merger_base::set_sources (S,own);
318                        Nsources = S.length();
319                }
320                //! Set internal parameters used in approximation
321                void set_parameters (int Ncoms0 = 10, double effss_coef0 = 0.5) {
322                        Ncoms = Ncoms0;
323                        effss_coef = effss_coef0;
324                }
325                //!@}
326
327                //! \name Mathematical operations
328                //!@{
329
330                //!Merge values using mixture approximation
331                void merge ();
332
333                //! sample from the approximating mixture
334                vec sample () const { return Mix.posterior().sample();}
335                //! loglikelihood computed on mixture models
336                double evallog (const vec &dt) const {
337                        vec dtf = ones (dt.length() + 1);
338                        dtf.set_subvector (0, dt);
339                        return Mix.logpred (dtf);
340                }
341                //!@}
342
343                //!\name Access functions
344                //!@{
345//! Access function
346                MixEF& _Mix() {return Mix;}
347                //! Access function
348                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
349                //! from_settings
350                void from_setting(const Setting& set){
351                        merger_base::from_setting(set);
352                        set.lookupValue("ncoms",Ncoms);
353                        set.lookupValue("effss_coef",effss_coef);
354                        set.lookupValue("stop_niter",stop_niter);
355                }
356               
357                //! @}
358
359};
360UIREGISTER(merger_mix);
361
362}
363
364#endif // MER_H
Note: See TracBrowser for help on using the browser.