root/library/bdm/stat/merger.h @ 404

Revision 404, 10.2 kB (checked in by smidl, 15 years ago)

Change in epdf: evallog returns -inf for points out of support. Merger is aware of it now.

  • Property svn:eol-style set to native
RevLine 
[176]1/*!
2  \file
3  \brief Mergers for combination of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
[384]13#ifndef MERGER_H
14#define MERGER_H
[176]15
[262]16
[384]17#include "../estim/mixtures.h"
[176]18
[299]19namespace bdm
20{
[384]21using std::string;
[176]22
[384]23//!Merging methods
24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3};
[176]25
[384]26/*!
27@brief Base class for general combination of pdfs on discrete support
[176]28
[384]29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial.
[198]30
[384]31The merged pdfs are expected to be of the form:
32 \f[ f(x_i|y_i),  i=1..n \f]
33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ .
34Note that all variables will be joined.
[205]35
[384]36As a result of this feature, each source must be extended to common support
37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f]
38where \f$ z_i \f$ accumulate variables that were not in the original source.
39These extensions are calculated on-the-fly.
[299]40
[384]41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$.
42For merging of more general cases, use offsprings merger_mix and merger_grid.
43*/
44
45class merger_base : public compositepdf, public epdf
46{
47        protected:
48                //! Data link for each mpdf in mpdfs
49                Array<datalink_m2e*> dls;
50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$
51                Array<RV> rvzs;
52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$
53                Array<datalink_m2e*> zdls;
54                //! number of support points
55                int Npoints;
56                //! number of sources
57                int Nsources;
58
59                //! switch of the methoh used for merging
60                MERGER_METHOD METHOD;
[399]61                //! Default for METHOD
62                static const MERGER_METHOD DFLT_METHOD;
63               
[384]64                //!Prior on the log-normal merging model
65                double beta;
[399]66                //! default for beta
67                static const double DFLT_beta;
68               
[384]69                //! Projection to empirical density (could also be piece-wise linear)
70                eEmp eSmp;
71
72                //! debug or not debug
73                bool DBG;
74
75                //! debugging file
76                it_file* dbg_file;
77        public:
78                //! \name Constructors
79                //! @{
80
81                //!Empty constructor
[388]82                merger_base () : compositepdf() {DBG = false;dbg_file = NULL;};
[384]83                //!Constructor from sources
[388]84                merger_base (const Array<mpdf*> &S, bool own=false) {set_sources (S,own);};
85                //! Function setting the main internal structures
86                void set_sources (const Array<mpdf*> &Sources, bool own) {
87                        compositepdf::set_elements (Sources,own);
[392]88                        Nsources=mpdfs.length();
[384]89                        //set sizes
90                        dls.set_size (Sources.length());
91                        rvzs.set_size (Sources.length());
92                        zdls.set_size (Sources.length());
93
94                        rv = getrv (/* checkoverlap = */ false);
95                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc!
96                        // join rv and rvc - see descriprion
97                        rv.add (rvc);
98                        // get dimension
99                        dim = rv._dsize();
100
101                        // create links between sources and common rv
102                        RV xytmp;
103                        for (int i = 0;i < mpdfs.length();i++) {
104                                //Establich connection between mpdfs and merger
105                                dls (i) = new datalink_m2e;
106                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv);
107
108                                // find out what is missing in each mpdf
109                                xytmp = mpdfs (i)->_rv();
110                                xytmp.add (mpdfs (i)->_rvc());
111                                // z_i = common_rv-xy
112                                rvzs (i) = rv.subt (xytmp);
113                                //establish connection between extension (z_i|x,y)s and common rv
114                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ;
115                        };
116                }
[388]117                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. Same number of points, \c dimsize, in each dimension.
118                void set_support (const Array<vec> &XYZ, const int dimsize) {
119                        set_support(XYZ,dimsize*ones_i(XYZ.length()));
120                }
121                //! Rectangular support  each vector of XYZ specifies (begining-end) interval for each dimension. \c gridsize specifies number of points is each dimension.
122                void set_support (const Array<vec> &XYZ, const ivec &gridsize) {
123                        int dim = XYZ.length();  //check with internal dim!!
124                        Npoints = prod (gridsize); 
125                        eSmp.set_parameters (Npoints, false);
126                        Array<vec> &samples = eSmp._samples();
127                        eSmp._w() = ones (Npoints) / Npoints; //unifrom size of bins
128                        //set samples
129                        ivec ind = zeros_i (dim);      //indeces of dimensions in for cycle;
130                        vec smpi (dim);            // ith sample
131                        vec steps =zeros(dim);            // ith sample
132                        // first corner
133                        for (int j = 0; j < dim; j++) {
134                                smpi (j) = XYZ (j) (0); /* beginning of the interval*/ 
135                                it_assert(gridsize(j)!=0.0,"Zeros in gridsize!");
[392]136                                steps (j) = ( XYZ(j)(1)-smpi(j) )/gridsize(j);
[388]137                        }
138                        // fill samples
139                        for (int i = 0; i < Npoints; i++) {
140                                // copy
141                                samples(i) = smpi; 
142                                // go through all dimensions
143                                for (int j = 0;j < dim;j++) {
144                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full
[395]145                                                ind (j) = 0; //shift back
[388]146                                                smpi(j) = XYZ(j)(0);
147                                               
[395]148                                                if (i<Npoints-1) {
149                                                        ind (j + 1) ++; //increase the next dimension;
150                                                        smpi(j+1) += steps(j+1);
151                                                        break;
152                                                }
[388]153                                               
154                                        } else {
[395]155                                                ind (j) ++; 
[388]156                                                smpi(j) +=steps(j);
157                                                break;
158                                        }
159                                }
160                        }
161                }
[384]162                //! set debug file
[388]163                void set_debug_file (const string fname) {
164                        if (DBG) delete dbg_file;
165                        dbg_file = new it_file (fname);
[384]166                        if (dbg_file) DBG = true;
167                }
168                //! Set internal parameters used in approximation
[399]169                void set_method (MERGER_METHOD MTH=DFLT_METHOD, double beta0 = DFLT_beta) {
[388]170                        METHOD = MTH;
[384]171                        beta = beta0;
172                }
173                //! Set support points from a pdf by drawing N samples
[388]174                void set_support (const epdf &overall, int N) {
175                        eSmp.set_statistics (&overall, N);
176                        Npoints = N;
[384]177                }
[388]178
[384]179                //! Destructor
180                virtual ~merger_base() {
181                        for (int i = 0; i < Nsources; i++) {
182                                delete dls (i);
183                                delete zdls (i);
[299]184                        }
[384]185                        if (DBG) delete dbg_file;
186                };
187                //!@}
[388]188
[384]189                //! \name Mathematical operations
190                //!@{
[388]191
[384]192                //!Merge given sources in given points
[395]193                virtual void merge () {
[388]194                        validate();
195
[384]196                        //check if sources overlap:
197                        bool OK = true;
198                        for (int i = 0;i < mpdfs.length(); i++) {
199                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty
200                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty
[310]201                        }
[384]202
203                        if (OK) {
204                                mat lW = zeros (mpdfs.length(), eSmp._w().length());
205
206                                vec emptyvec (0);
207                                for (int i = 0; i < mpdfs.length(); i++) {
208                                        for (int j = 0; j < eSmp._w().length(); j++) {
209                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec);
[299]210                                        }
211                                }
[384]212
[395]213                                vec w_nn=merge_points (lW);
214                                vec wtmp = exp (w_nn-max(w_nn));
[384]215                                //renormalize
216                                eSmp._w() = wtmp / sum (wtmp);
217                        } else {
[388]218                                it_error ("Sources are not compatible - use merger_mix");
[299]219                        }
[384]220                };
[176]221
[384]222
223                //! Merge log-likelihood values in points using method specified by parameter METHOD
224                vec merge_points (mat &lW);
[388]225
226
[384]227                //! sample from merged density
[192]228//! weight w is a
[384]229                vec mean() const {
230                        const Vec<double> &w = eSmp._w();
231                        const Array<vec> &S = eSmp._samples();
232                        vec tmp = zeros (dim);
233                        for (int i = 0; i < Npoints; i++) {
234                                tmp += w (i) * S (i);
[299]235                        }
[384]236                        return tmp;
237                }
238                mat covariance() const {
239                        const vec &w = eSmp._w();
240                        const Array<vec> &S = eSmp._samples();
[299]241
[384]242                        vec mea = mean();
[299]243
[404]244//                      cout << sum (w) << "," << w*w << endl;
[299]245
[384]246                        mat Tmp = zeros (dim, dim);
247                        for (int i = 0; i < Npoints; i++) {
248                                Tmp += w (i) * outer_product (S (i), S (i));
[299]249                        }
[384]250                        return Tmp -outer_product (mea, mea);
251                }
252                vec variance() const {
253                        const vec &w = eSmp._w();
254                        const Array<vec> &S = eSmp._samples();
[299]255
[384]256                        vec tmp = zeros (dim);
257                        for (int i = 0; i < Nsources; i++) {
258                                tmp += w (i) * pow (S (i), 2);
[299]259                        }
[384]260                        return tmp -pow (mean(), 2);
261                }
262                //!@}
[192]263
[384]264                //! \name Access to attributes
265                //! @{
266
267                //! Access function
268                eEmp& _Smp() {return eSmp;}
[388]269
270                //! load from setting
271                void from_setting (const Setting& set) {
272                        // get support
273                        // find which method to use
274                        string meth_str;
275                        UI::get<string> (meth_str, set, "method");
276                        if (!strcmp (meth_str.c_str(), "arithmetic"))
277                                set_method (ARITHMETIC);
278                        else {
279                                if (!strcmp (meth_str.c_str(), "geometric"))
280                                        set_method (GEOMETRIC);
281                                else if (!strcmp (meth_str.c_str(), "lognormal")) {
[392]282                                        set_method (LOGNORMAL);
[388]283                                        set.lookupValue( "beta",beta);
284                                }
285                        }
[395]286                        if (set.exists("dbg_file")){ 
287                                string dbg_file;
288                                UI::get<string> (dbg_file, set, "dbg_file");
289                                set_debug_file(dbg_file);
290                        }
291                        //validate() - not used
[388]292                }
293
294                void validate() {
295                        it_assert (eSmp._w().length() > 0, "Empty support, use set_support().");
296                        it_assert (dim == eSmp._samples() (0).length(), "Support points and rv are not compatible!");
297                        it_assert (isnamed(),"mergers must be named");
298                }
[384]299                //!@}
300};
[388]301UIREGISTER(merger_base);
[384]302
303class merger_mix : public merger_base
304{
305        protected:
306                //!Internal mixture of EF models
307                MixEF Mix;
308                //!Number of components in a mixture
309                int Ncoms;
[395]310                //! coefficient of resampling [0,1]
[384]311                double effss_coef;
[395]312                //! stop after niter iterations
313                int stop_niter;
[399]314               
315                //! default value for Ncoms
316                static const int DFLT_Ncoms;
317                //! default value for efss_coef;
318                static const double DFLT_effss_coef;
[384]319
320        public:
321                //!\name Constructors
322                //!@{
323                merger_mix () {};
[388]324                merger_mix (const Array<mpdf*> &S,bool own=false) {set_sources (S,own);};
[384]325                //! Set sources and prepare all internal structures
[388]326                void set_sources (const Array<mpdf*> &S, bool own) {
327                        merger_base::set_sources (S,own);
[384]328                        Nsources = S.length();
329                }
330                //! Set internal parameters used in approximation
[399]331                void set_parameters (int Ncoms0 = DFLT_Ncoms, double effss_coef0 = DFLT_effss_coef) {
[384]332                        Ncoms = Ncoms0;
333                        effss_coef = effss_coef0;
334                }
335                //!@}
[388]336
[384]337                //! \name Mathematical operations
338                //!@{
[388]339
[384]340                //!Merge values using mixture approximation
341                void merge ();
342
343                //! sample from the approximating mixture
344                vec sample () const { return Mix.posterior().sample();}
345                //! loglikelihood computed on mixture models
346                double evallog (const vec &dt) const {
347                        vec dtf = ones (dt.length() + 1);
348                        dtf.set_subvector (0, dt);
349                        return Mix.logpred (dtf);
350                }
351                //!@}
352
353                //!\name Access functions
354                //!@{
[192]355//! Access function
[384]356                MixEF& _Mix() {return Mix;}
357                //! Access function
358                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;}
[388]359                //! from_settings
[395]360                void from_setting(const Setting& set){
[388]361                        merger_base::from_setting(set);
362                        set.lookupValue("ncoms",Ncoms);
[395]363                        set.lookupValue("effss_coef",effss_coef);
364                        set.lookupValue("stop_niter",stop_niter);
[388]365                }
[384]366               
[388]367                //! @}
368
[384]369};
[388]370UIREGISTER(merger_mix);
[176]371
[254]372}
[176]373
374#endif // MER_H
Note: See TracBrowser for help on using the browser.