00001
00029 #ifndef MOG_GENERIC_H
00030 #define MOG_GENERIC_H
00031
00032 #include <itpp/base/vec.h>
00033 #include <itpp/base/mat.h>
00034 #include <itpp/base/array.h>
00035
00036
00037 namespace itpp
00038 {
00039
00056 class MOG_generic
00057 {
00058
00059 public:
00060
00066 MOG_generic() { init(); }
00067
00071 MOG_generic(const std::string &name_in) { load(name_in); }
00072
00078 MOG_generic(const int &K_in, const int &D_in, bool full_in = false) { init(K_in, D_in, full_in); }
00079
00087 MOG_generic(Array<vec> &means_in, bool full_in = false) { init(means_in, full_in); }
00088
00095 MOG_generic(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { init(means_in, diag_covs_in, weights_in); }
00096
00103 MOG_generic(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { init(means_in, full_covs_in, weights_in); }
00104
00106 virtual ~MOG_generic() { cleanup(); }
00107
00112 void init();
00113
00119 void init(const int &K_in, const int &D_in, bool full_in = false);
00120
00128 void init(Array<vec> &means_in, bool full_in = false);
00129
00136 void init(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in);
00137
00144 void init(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in);
00145
00150 virtual void cleanup();
00151
00153 bool is_valid() const { return valid; }
00154
00156 bool is_full() const { return full; }
00157
00159 int get_K() const { if (valid) return(K); else return(0); }
00160
00162 int get_D() const { if (valid) return(D); else return(0); }
00163
00165 vec get_weights() const { vec tmp; if (valid) { tmp = weights; } return tmp; }
00166
00168 Array<vec> get_means() const { Array<vec> tmp; if (valid) { tmp = means; } return tmp; }
00169
00171 Array<vec> get_diag_covs() const { Array<vec> tmp; if (valid && !full) { tmp = diag_covs; } return tmp; }
00172
00174 Array<mat> get_full_covs() const { Array<mat> tmp; if (valid && full) { tmp = full_covs; } return tmp; }
00175
00179 void set_means(Array<vec> &means_in);
00180
00184 void set_diag_covs(Array<vec> &diag_covs_in);
00185
00189 void set_full_covs(Array<mat> &full_covs_in);
00190
00194 void set_weights(vec &weights_in);
00195
00197 void set_means_zero();
00198
00200 void set_diag_covs_unity();
00201
00203 void set_full_covs_unity();
00204
00206 void set_weights_uniform();
00207
00213 void set_checks(bool do_checks_in) { do_checks = do_checks_in; }
00214
00218 void set_paranoid(bool paranoid_in) { paranoid = paranoid_in; }
00219
00223 virtual void load(const std::string &name_in);
00224
00228 virtual void save(const std::string &name_in) const;
00229
00246 virtual void join(const MOG_generic &B_in);
00247
00255 virtual void convert_to_diag();
00256
00262 virtual void convert_to_full();
00263
00265 virtual double log_lhood_single_gaus(const vec &x_in, const int k);
00266
00268 virtual double log_lhood(const vec &x_in);
00269
00271 virtual double lhood(const vec &x_in);
00272
00274 virtual double avg_log_lhood(const Array<vec> &X_in);
00275
00276 protected:
00277
00279 bool do_checks;
00280
00282 bool valid;
00283
00285 bool full;
00286
00288 bool paranoid;
00289
00291 int K;
00292
00294 int D;
00295
00297 Array<vec> means;
00298
00300 Array<vec> diag_covs;
00301
00303 Array<mat> full_covs;
00304
00306 vec weights;
00307
00309 double log_max_K;
00310
00316 vec log_det_etc;
00317
00319 vec log_weights;
00320
00322 Array<mat> full_covs_inv;
00323
00325 Array<vec> diag_covs_inv_etc;
00326
00328 bool check_size(const vec &x_in) const;
00329
00331 bool check_size(const Array<vec> &X_in) const;
00332
00334 bool check_array_uniformity(const Array<vec> & A) const;
00335
00337 void set_means_internal(Array<vec> &means_in);
00339 void set_diag_covs_internal(Array<vec> &diag_covs_in);
00341 void set_full_covs_internal(Array<mat> &full_covs_in);
00343 void set_weights_internal(vec &_weigths);
00344
00346 void set_means_zero_internal();
00348 void set_diag_covs_unity_internal();
00350 void set_full_covs_unity_internal();
00352 void set_weights_uniform_internal();
00353
00355 void convert_to_diag_internal();
00357 void convert_to_full_internal();
00358
00360 virtual void setup_means();
00361
00363 virtual void setup_covs();
00364
00366 virtual void setup_weights();
00367
00369 virtual void setup_misc();
00370
00372 virtual double log_lhood_single_gaus_internal(const vec &x_in, const int k);
00374 virtual double log_lhood_internal(const vec &x_in);
00376 virtual double lhood_internal(const vec &x_in);
00377
00378 private:
00379 vec tmpvecD;
00380 vec tmpvecK;
00381
00382 };
00383
00384 }
00385
00386 #endif // #ifndef MOG_GENERIC_H