00001
00029 #ifndef MOG_DIAG_H
00030 #define MOG_DIAG_H
00031
00032 #include <itpp/stat/mog_generic.h>
00033
00034
00035 namespace itpp
00036 {
00037
00054 class MOG_diag : public MOG_generic
00055 {
00056
00057 public:
00058
00064 MOG_diag() { zero_all_ptrs(); init(); }
00065
00069 MOG_diag(const std::string &name) { zero_all_ptrs(); load(name); }
00070
00076 MOG_diag(const int &K_in, const int &D_in, bool full_in = false) { zero_all_ptrs(); init(K_in, D_in, full_in); }
00077
00084 MOG_diag(Array<vec> &means_in, bool) { zero_all_ptrs(); init(means_in, false); }
00085
00092 MOG_diag(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { zero_all_ptrs(); init(means_in, diag_covs_in, weights_in); }
00093
00101 MOG_diag(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { zero_all_ptrs(); init(means_in, full_covs_in, weights_in); convert_to_diag(); }
00102
00104 ~MOG_diag() { cleanup(); }
00105
00110 void cleanup() { free_all_ptrs(); MOG_generic::cleanup(); }
00111
00117 void load(const std::string &name_in);
00118
00120 void convert_to_full() {};
00121
00123 double log_lhood_single_gaus(const double * c_x_in, const int k) const;
00124
00126 double log_lhood_single_gaus(const vec &x_in, const int k) const;
00127
00129 double log_lhood(const double * c_x_in);
00130
00132 double log_lhood(const vec &x_in);
00133
00135 double lhood(const double * c_x_in);
00136
00138 double lhood(const vec &x_in);
00139
00141 double avg_log_lhood(const double ** c_x_in, int N);
00142
00144 double avg_log_lhood(const Array<vec> & X_in);
00145
00146 protected:
00147
00148 void setup_means();
00149 void setup_covs();
00150 void setup_weights();
00151 void setup_misc();
00152
00154 double log_lhood_single_gaus_internal(const double * c_x_in, const int k) const;
00156 double log_lhood_single_gaus_internal(const vec &x_in, const int k) const;
00158 double log_lhood_internal(const double * c_x_in);
00160 double log_lhood_internal(const vec &x_in);
00162 double lhood_internal(const double * c_x_in);
00164 double lhood_internal(const vec &x_in);
00165
00167 double ** enable_c_access(Array<vec> & A_in);
00168
00170 int ** enable_c_access(Array<ivec> & A_in);
00171
00173 double * enable_c_access(vec & v_in);
00174
00176 int * enable_c_access(ivec & v_in);
00177
00179 double ** disable_c_access(double ** A_in);
00180
00182 int ** disable_c_access(int ** A_in);
00183
00185 double * disable_c_access(double * v_in);
00186
00188 int * disable_c_access(int * v_in);
00189
00191 void zero_all_ptrs();
00193 void free_all_ptrs();
00194
00196 double ** c_means;
00197
00199 double ** c_diag_covs;
00200
00202 double ** c_diag_covs_inv_etc;
00203
00205 double * c_weights;
00206
00208 double * c_log_weights;
00209
00211 double * c_log_det_etc;
00212
00213 private:
00214
00215 vec tmpvecK;
00216 double * c_tmpvecK;
00217
00218 };
00219
00220 }
00221
00222 #endif // #ifndef MOG_DIAG_H
00223