root/bdm/stat/emix.h @ 148

Revision 148, 4.6 kB (checked in by mido, 16 years ago)

drobny patch

  • Property svn:eol-style set to native
Line 
1/*!
2  \file
3  \brief Probability distributions for Mixtures of pdfs
4  \author Vaclav Smidl.
5
6  -----------------------------------
7  BDM++ - C++ library for Bayesian Decision Making under Uncertainty
8
9  Using IT++ for numerical operations
10  -----------------------------------
11*/
12
13#ifndef MX_H
14#define MX_H
15
16#include "libBM.h"
17#include "libEF.h"
18//#include <std>
19
20using namespace itpp;
21
22/*!
23* \brief Mixture of epdfs
24
25Density function:
26\f[
27f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
28\f]
29where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,
30
31*/
32class emix : public epdf {
33        protected:
34                //! weights of the components
35                vec w;
36                //! Component (epdfs)
37                Array<epdf*> Coms;
38        public:
39                //!Default constructor
40                emix(RV &rv) : epdf(rv) {};
41                //! Set weights \c w and components \c R
42                void set_parameters(const vec &w, const Array<epdf*> &Coms);
43
44                vec sample() const;
45                vec mean() const {
46                        int i; vec mu = zeros(rv.count());
47                        for (i = 0;i < w.length();i++) {mu += w(i) * Coms(i)->mean(); }
48                        return mu;
49                }
50                double evalpdflog(const vec &val) const {
51                        int i;
52                        double sum = 0.0;
53                        for (i = 0;i < w.length();i++) {sum += w(i) * Coms(i)->evalpdflog(val);}
54                        return log(sum);
55                };
56
57//Access methods
58                //! returns a pointer to the internal mean value. Use with Care!
59                vec& _w() {return w;}
60};
61
62/*! \brief Chain rule decomposition of epdf
63
64Probability density in the form of Chain-rule decomposition:
65\[
66f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
67\]
68Note that
69*/
70class eprod: public epdf {
71        protected:
72                //
73                int n;
74                // pointers to epdfs
75                Array<epdf*> epdfs;
76                Array<mpdf*> mpdfs;
77                //
78                Array<ivec> rvinds;
79                Array<ivec> rvcinds;
80                //! Indicate independence of its factors
81                bool independent;
82                //! Indicate internal creation of mpdfs which must be destroyed
83                bool intermpdfs;
84        public:
85                //!Constructor from list of eFacs or list of mFacs
86                eprod(Array<mpdf*> mFacs): epdf(RV()), n(mFacs.length()), epdfs(n), mpdfs(mFacs), rvinds(n), rvcinds(n) {
87                        int i;
88                        intermpdfs = false;
89                        for (i = 0;i < n;i++) {
90                                rv.add(mpdfs(i)->_rv()); //add rv to common rvs.
91                                epdfs(i) = &(mpdfs(i)->_epdf()); // add pointer to epdf
92                        };
93                        independent = true;
94                        //test rvc of mpdfs and fill rvinds
95                        for (i = 0;i < n;i++) {
96                                // find ith rv in common rv
97                                rvinds(i) = mpdfs(i)->_rv().dataind(rv);
98                                // find ith rvc in common rv
99                                rvcinds(i) = mpdfs(i)->_rvc().dataind(rv);
100                                if (rvcinds(i).length()>0) {independent = false;}
101                        }
102
103                };
104                eprod(Array<epdf*> eFacs): epdf(RV()), n(eFacs.length()), epdfs(eFacs), mpdfs(n), rvinds(n), rvcinds(n) {
105                        int i;
106                        intermpdfs = true;
107                        for (i = 0;i < n;i++) {
108                                if (rv.add(eFacs(i)->_rv())) {it_error("Incompatible eFacs.rv!");} //add rv to common rvs.
109                                mpdfs(i) = new mepdf(*(epdfs(i)));
110                                epdfs(i) = &(mpdfs(i)->_epdf()); // add pointer to epdf
111                        };
112                        independent = true;
113                        //test rvc of mpdfs and fill rvinds
114                        for (i = 0;i < n;i++) {
115                                // find ith rv in common rv
116                                rvinds(i) = mpdfs(i)->_rv().dataind(rv);
117                        }
118                };
119
120                double evalpdflog(const vec &val) const {
121                        int i;
122                        double res = 0.0;
123                        for (i = n - 1;i > 0;i++) {
124                                if (rvcinds(i).length() > 0)
125                                        {mpdfs(i)->condition(val(rvcinds(i)));}
126                                // add logarithms
127                                res += epdfs(i)->evalpdflog(val(rvinds(i)));
128                        }
129                        return res;
130                }
131                vec sample () const{
132                        vec smp=zeros(rv.count());
133                        for (int i = (n - 1);i >= 0;i--) {
134                                if (rvcinds(i).length() > 0)
135                                        {mpdfs(i)->condition(smp(rvcinds(i)));}
136                                set_subvector(smp,rvinds(i), epdfs(i)->sample());
137                        }                       
138                        return smp;
139                }
140                vec mean() const{
141                        vec tmp(rv.count());
142                        if (independent) {
143                                for (int i=0;i<n;i++) {
144                                        vec pom = epdfs(i)->mean();
145                                        set_subvector(tmp,rvinds(i), pom);
146                                }
147                        }
148                        else {
149                                int N=50*rv.count();
150                                it_warning("eprod.mean() computed by sampling");
151                                tmp = zeros(rv.count());
152                                for (int i=0;i<N;i++){ tmp += sample();}
153                                tmp /=N;
154                        }
155                        return tmp;
156                }
157                ~eprod(){if (intermpdfs) {for (int i=0;i<n;i++){delete mpdfs(i);}}};
158};
159
160/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type
161
162*/
163class mmix : public mpdf {
164        protected:
165                //! Component (epdfs)
166                Array<mpdf*> Coms;
167                //!Internal epdf
168                emix Epdf;
169        public:
170                //!Default constructor
171                mmix(RV &rv, RV &rvc) : mpdf(rv, rvc), Epdf(rv) {ep = &Epdf;};
172                //! Set weights \c w and components \c R
173                void set_parameters(const vec &w, const Array<mpdf*> &Coms) {
174                        Array<epdf*> Eps(Coms.length());
175
176                        for (int i = 0;i < Coms.length();i++) {
177                                Eps(i) = & (Coms(i)->_epdf());
178                        }
179                        Epdf.set_parameters(w, Eps);
180                };
181
182                void condition(const vec &cond) {
183                        for (int i = 0;i < Coms.length();i++) {Coms(i)->condition(cond);}
184                };
185};
186#endif //MX_H
Note: See TracBrowser for help on using the browser.