root/library/bdm/stat/merger.h @ 471

Revision 471, 10.1 kB (checked in by mido, 15 years ago)

1) ad UserInfo?: UI::get a UI::build predelany tak, ze vraci fals / NULL v pripade neexistence pozadovaneho Settingu, pridana enumericky typ UI::SettingPresence?, predelany stavajici UI implementace, dodelana UI dokumentace
2) dokoncena konfigurace ASTYLERU, brzy bude aplikovan
3) doxygen nastaven tak, ze vytvari soubor doxy_warnings.txt

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MERGER_H
14#define MERGER_H
15
16
17#include "../estim/mixtures.h"
18
19namespace bdm
20{
21using std::string;
22
23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
25
26/*!
27@brief Base class for general combination of pdfs on discrete support
28
29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
30
31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
35
36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
40
41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
61                //! Default for METHOD
62                static const MERGER_METHOD DFLT_METHOD;
63               
64                //!Prior on the log-normal merging model
65                double beta;
66                //! default for beta
67                static const double DFLT_beta;
68               
69                //! Projection to empirical density (could also be piece-wise linear)
70                eEmp eSmp;
71
72                //! debug or not debug
73                bool DBG;
74
75                //! debugging file
76                it_file* dbg_file;
77        public:
78                //! \name Constructors
79                //! @{
80
81                //!Empty constructor
82                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;}
83
84                //!Constructor from sources
85                merger_base (const Array<mpdf*> &S, bool own=false);
86
87                //! Function setting the main internal structures
88                void set_sources (const Array<mpdf*> &Sources, bool own) {
89                        compositepdf::set_elements (Sources,own);
90                        Nsources=mpdfs.length();
91                        //set sizes
92                        dls.set_size (Sources.length());
93                        rvzs.set_size (Sources.length());
94                        zdls.set_size (Sources.length());
95
96                        rv = getrv (/* checkoverlap = */ false);
97                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
98                        // join rv and rvc - see descriprion
99                        rv.add (rvc);
100                        // get dimension
101                        dim = rv._dsize();
102
103                        // create links between sources and common rv
104                        RV xytmp;
105                        for (int i = 0;i < mpdfs.length();i++) {
106                                //Establich connection between mpdfs and merger
107                                dls (i) = new datalink_m2e;
108                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
109
110                                // find out what is missing in each mpdf
111                                xytmp = mpdfs (i)->_rv();
112                                xytmp.add (mpdfs (i)->_rvc());
113                                // z_i = common_rv-xy
114                                rvzs (i) = rv.subt (xytmp);
115                                //establish connection between extension (z_i|x,y)s and common rv
116                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
117                        };
118                }
119                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
120                void set_support (const Array<vec> &XYZ, const int dimsize) {
121                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
122                }
123                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
124                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
125                        int dim = XYZ.length();  //check with internal dim!!
126                        Npoints = prod (gridsize); 
127                        eSmp.set_parameters (Npoints, false);
128                        Array<vec> &samples = eSmp._samples();
129                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
130                        //set samples
131                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
132                        vec smpi (dim);            // ith sample
133                        vec steps =zeros(dim);            // ith sample
134                        // first corner
135                        for (int j = 0; j < dim; j++) {
136                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
137                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
138                                steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
139                        }
140                        // fill samples
141                        for (int i = 0; i < Npoints; i++) {
142                                // copy
143                                samples(i) = smpi; 
144                                // go through all dimensions
145                                for (int j = 0;j < dim;j++) {
146                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
147                                                ind (j) = 0; //shift back
148                                                smpi(j) = XYZ(j)(0);
149                                               
150                                                if (i<Npoints-1) {
151                                                        ind (j + 1) ++; //increase the next dimension;
152                                                        smpi(j+1) += steps(j+1);
153                                                        break;
154                                                }
155                                               
156                                        } else {
157                                                ind (j) ++; 
158                                                smpi(j) +=steps(j);
159                                                break;
160                                        }
161                                }
162                        }
163                }
164                //! set debug file
165                void set_debug_file (const string fname) {
166                        if (DBG) delete dbg_file;
167                        dbg_file = new it_file (fname);
168                        if (dbg_file) DBG = true;
169                }
170                //! Set internal parameters used in approximation
171                void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
172                        METHOD = MTH;
173                        beta = beta0;
174                }
175                //! Set support points from a pdf by drawing N samples
176                void set_support (const epdf &overall, int N) {
177                        eSmp.set_statistics (&overall, N);
178                        Npoints = N;
179                }
180
181                //! Destructor
182                virtual ~merger_base() {
183                        for (int i = 0; i < Nsources; i++) {
184                                delete dls (i);
185                                delete zdls (i);
186                        }
187                        if (DBG) delete dbg_file;
188                };
189                //!@}
190
191                //! \name Mathematical operations
192                //!@{
193
194                //!Merge given sources in given points
195                virtual void merge () {
196                        validate();
197
198                        //check if sources overlap:
199                        bool OK = true;
200                        for (int i = 0;i < mpdfs.length(); i++) {
201                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
202                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
203                        }
204
205                        if (OK) {
206                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
207
208                                vec emptyvec (0);
209                                for (int i = 0; i < mpdfs.length(); i++) {
210                                        for (int j = 0; j < eSmp._w().length(); j++) {
211                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
212                                        }
213                                }
214
215                                vec w_nn=merge_points (lW);
216                                vec wtmp = exp (w_nn-max(w_nn));
217                                //renormalize
218                                eSmp._w() = wtmp / sum (wtmp);
219                        } else {
220                                it_error ("Sources are not compatible - use merger_mix");
221                        }
222                };
223
224
225                //! Merge log-likelihood values in points using method specified by parameter METHOD
226                vec merge_points (mat &lW);
227
228
229                //! sample from merged density
230//! weight w is a
231                vec mean() const {
232                        const Vec<double> &w = eSmp._w();
233                        const Array<vec> &S = eSmp._samples();
234                        vec tmp = zeros (dim);
235                        for (int i = 0; i < Npoints; i++) {
236                                tmp += w (i) * S (i);
237                        }
238                        return tmp;
239                }
240                mat covariance() const {
241                        const vec &w = eSmp._w();
242                        const Array<vec> &S = eSmp._samples();
243
244                        vec mea = mean();
245
246//                      cout << sum (w) << "," << w*w << endl;
247
248                        mat Tmp = zeros (dim, dim);
249                        for (int i = 0; i < Npoints; i++) {
250                                Tmp += w (i) * outer_product (S (i), S (i));
251                        }
252                        return Tmp -outer_product (mea, mea);
253                }
254                vec variance() const {
255                        const vec &w = eSmp._w();
256                        const Array<vec> &S = eSmp._samples();
257
258                        vec tmp = zeros (dim);
259                        for (int i = 0; i < Nsources; i++) {
260                                tmp += w (i) * pow (S (i), 2);
261                        }
262                        return tmp -pow (mean(), 2);
263                }
264                //!@}
265
266                //! \name Access to attributes
267                //! @{
268
269                //! Access function
270                eEmp& _Smp() {return eSmp;}
271
272                //! load from setting
273                void from_setting (const Setting& set) {
274                        // get support
275                        // find which method to use
276                        string meth_str;
277                        UI::get<string> (meth_str, set, "method", UI::compulsory);
278                        if (!strcmp (meth_str.c_str(), "arithmetic"))
279                                set_method (ARITHMETIC);
280                        else {
281                                if (!strcmp (meth_str.c_str(), "geometric"))
282                                        set_method (GEOMETRIC);
283                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
284                                        set_method (LOGNORMAL);
285                                        set.lookupValue( "beta",beta);
286                                }
287                        }
288                        string dbg_file;
289                        if (UI::get(dbg_file, set, "dbg_file"))
290                                set_debug_file(dbg_file);
291                        //validate() - not used
292                }
293
294                void validate() {
295                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
296                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
297                        it_assert (isnamed(),"mergers must be named");
298                }
299                //!@}
300};
301UIREGISTER(merger_base);
302
303class merger_mix : public merger_base
304{
305        protected:
306                //!Internal mixture of EF models
307                MixEF Mix;
308                //!Number of components in a mixture
309                int Ncoms;
310                //! coefficient of resampling [0,1]
311                double effss_coef;
312                //! stop after niter iterations
313                int stop_niter;
314               
315                //! default value for Ncoms
316                static const int DFLT_Ncoms;
317                //! default value for efss_coef;
318                static const double DFLT_effss_coef;
319
320        public:
321                //!\name Constructors
322                //!@{
323                merger_mix () {};
324                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
325                //! Set sources and prepare all internal structures
326                void set_sources (const Array<mpdf*> &S, bool own) {
327                        merger_base::set_sources (S,own);
328                        Nsources = S.length();
329                }
330                //! Set internal parameters used in approximation
331                void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
332                        Ncoms = Ncoms0;
333                        effss_coef = effss_coef0;
334                }
335                //!@}
336
337                //! \name Mathematical operations
338                //!@{
339
340                //!Merge values using mixture approximation
341                void merge ();
342
343                //! sample from the approximating mixture
344                vec sample () const { return Mix.posterior().sample();}
345                //! loglikelihood computed on mixture models
346                double evallog (const vec &dt) const {
347                        vec dtf = ones (dt.length() + 1);
348                        dtf.set_subvector (0, dt);
349                        return Mix.logpred (dtf);
350                }
351                //!@}
352
353                //!\name Access functions
354                //!@{
355//! Access function
356                MixEF& _Mix() {return Mix;}
357                //! Access function
358                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
359                //! from_settings
360                void from_setting(const Setting& set){
361                        merger_base::from_setting(set);
362                        set.lookupValue("ncoms",Ncoms);
363                        set.lookupValue("effss_coef",effss_coef);
364                        set.lookupValue("stop_niter",stop_niter);
365                }
366               
367                //! @}
368
369};
370UIREGISTER(merger_mix);
371
372}
373
374#endif // MER_H
Note: See TracBrowser for help on using the browser.