00001
00029 #ifndef GMM_H
00030 #define GMM_H
00031
00032 #include <itpp/base/mat.h>
00033
00034
00035 namespace itpp
00036 {
00037
00039
00045 class GMM
00046 {
00047 public:
00048 GMM();
00049 GMM(int nomix, int dim);
00050 GMM(std::string filename);
00051 void init_from_vq(const vec &codebook, int dim);
00052
00053 void init(const vec &w_in, const mat &m_in, const mat &sigma_in);
00054 void load(std::string filename);
00055 void save(std::string filename);
00056 void set_weight(const vec &weights, bool compflag = true);
00057 void set_weight(int i, double weight, bool compflag = true);
00058 void set_mean(const mat &m_in);
00059 void set_mean(const vec &means, bool compflag = true);
00060 void set_mean(int i, const vec &means, bool compflag = true);
00061 void set_covariance(const mat &sigma_in);
00062 void set_covariance(const vec &covariances, bool compflag = true);
00063 void set_covariance(int i, const vec &covariances, bool compflag = true);
00064 int get_no_mixtures();
00065 int get_no_gaussians() const { return M; }
00066 int get_dimension();
00067 vec get_weight();
00068 double get_weight(int i);
00069 vec get_mean();
00070 vec get_mean(int i);
00071 vec get_covariance();
00072 vec get_covariance(int i);
00073 void marginalize(int d_new);
00074 void join(const GMM &newgmm);
00075 void clear();
00076 double likelihood(const vec &x);
00077 double likelihood_aposteriori(const vec &x, int mixture);
00078 vec likelihood_aposteriori(const vec &x);
00079 vec draw_sample();
00080 protected:
00081 vec m, sigma, w;
00082 int M, d;
00083 private:
00084 void compute_internals();
00085 vec normweight, normexp;
00086 };
00087
00088 inline void GMM::set_weight(const vec &weights, bool compflag) {w = weights; if (compflag) compute_internals(); }
00089 inline void GMM::set_weight(int i, double weight, bool compflag) {w(i) = weight; if (compflag) compute_internals(); }
00090 inline void GMM::set_mean(const vec &means, bool compflag) {m = means; if (compflag) compute_internals(); }
00091 inline void GMM::set_covariance(const vec &covariances, bool compflag) {sigma = covariances; if (compflag) compute_internals(); }
00092 inline int GMM::get_dimension() {return d;}
00093 inline vec GMM::get_weight() {return w;}
00094 inline double GMM::get_weight(int i) {return w(i);}
00095 inline vec GMM::get_mean() {return m;}
00096 inline vec GMM::get_mean(int i) {return m.mid(i*d, d);}
00097 inline vec GMM::get_covariance() {return sigma;}
00098 inline vec GMM::get_covariance(int i) {return sigma.mid(i*d, d);}
00099
00100 GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER = 30, bool VERBOSE = true);
00101
00103
00104 }
00105
00106 #endif // #ifndef GMM_H