/*!
  \file
  \brief Probability distributions for Mixtures of pdfs
  \author Vaclav Smidl.

  -----------------------------------
  BDM++ - C++ library for Bayesian Decision Making under Uncertainty

  Using IT++ for numerical operations
  -----------------------------------
*/

#ifndef MX_H
#define MX_H

#include "libBM.h"
#include "libEF.h"
//#include <std>

using namespace itpp;

/*!
* \brief Mixture of epdfs

Density function:
\f[
f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
\f]
where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,

*/
class emix : public epdf {
	protected:
		//! weights of the components
		vec w;
		//! Component (epdfs)
		Array<epdf*> Coms;
	public:
		//!Default constructor
		emix(RV &rv) : epdf(rv) {};
		//! Set weights \c w and components \c R
		void set_parameters(const vec &w, const Array<epdf*> &Coms);

		vec sample() const;
		vec mean() const {
			int i; vec mu = zeros(rv.count());
			for (i = 0;i < w.length();i++) {mu += w(i) * Coms(i)->mean(); }
			return mu;
		}
		double evalpdflog(const vec &val) const {
			int i;
			double sum = 0.0;
			for (i = 0;i < w.length();i++) {sum += w(i) * Coms(i)->evalpdflog(val);}
			return log(sum);
		};

//Access methods
		//! returns a pointer to the internal mean value. Use with Care!
		vec& _w() {return w;}
};

/*! \brief Chain rule decomposition of epdf

Probability density in the form of Chain-rule decomposition:
\[
f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
\]
Note that
*/
class eprod: public epdf {
	protected:
		//
		int n;
		// pointers to epdfs
		Array<epdf*> epdfs;
		Array<mpdf*> mpdfs;
		//
		Array<ivec> rvinds;
		Array<ivec> rvcinds;
		//! Indicate independence of its factors
		bool independent;
		//! Indicate internal creation of mpdfs which must be destroyed
		bool intermpdfs;
	public:
		//!Constructor from list of eFacs or list of mFacs
		eprod(Array<mpdf*> mFacs): epdf(RV()), n(mFacs.length()), epdfs(n), mpdfs(mFacs), rvinds(n), rvcinds(n) {
			int i;
			intermpdfs = false;
			for (i = 0;i < n;i++) {
				rv.add(mpdfs(i)->_rv()); //add rv to common rvs.
				epdfs(i) = &(mpdfs(i)->_epdf()); // add pointer to epdf
			};
			independent = true;
			//test rvc of mpdfs and fill rvinds
			for (i = 0;i < n;i++) {
				// find ith rv in common rv
				rvinds(i) = mpdfs(i)->_rv().dataind(rv);
				// find ith rvc in common rv
				rvcinds(i) = mpdfs(i)->_rvc().dataind(rv);
				if (rvcinds(i).length()>0) {independent = false;}
			}

		};
		eprod(Array<epdf*> eFacs): epdf(RV()), n(eFacs.length()), epdfs(eFacs), mpdfs(n), rvinds(n), rvcinds(n) {
			int i;
			intermpdfs = true;
			for (i = 0;i < n;i++) {
				if (rv.add(eFacs(i)->_rv())) {it_error("Incompatible eFacs.rv!");} //add rv to common rvs.
				mpdfs(i) = new mepdf(*(epdfs(i)));
				epdfs(i) = &(mpdfs(i)->_epdf()); // add pointer to epdf
			};
			independent = true;
			//test rvc of mpdfs and fill rvinds
			for (i = 0;i < n;i++) {
				// find ith rv in common rv
				rvinds(i) = mpdfs(i)->_rv().dataind(rv);
			}
		};

		double evalpdflog(const vec &val) const {
			int i;
			double res = 0.0;
			for (i = n - 1;i > 0;i++) {
				if (rvcinds(i).length() > 0)
					{mpdfs(i)->condition(val(rvcinds(i)));}
				// add logarithms
				res += epdfs(i)->evalpdflog(val(rvinds(i)));
			}
		}
		vec sample () const{
			vec smp=zeros(rv.count());
			for (int i = (n - 1);i >= 0;i--) {
				if (rvcinds(i).length() > 0)
					{mpdfs(i)->condition(smp(rvcinds(i)));}
				set_subvector(smp,rvinds(i), epdfs(i)->sample());
			}			
			return smp;
		}
		vec mean() const{
			vec tmp(rv.count());
			if (independent) {
				for (int i=0;i<n;i++) {
					vec pom = epdfs(i)->mean();
					set_subvector(tmp,rvinds(i), pom);
				}
			}
			else {
				int N=50*rv.count();
				it_warning("eprod.mean() computed by sampling");
				tmp = zeros(rv.count());
				for (int i=0;i<N;i++){ tmp += sample();}
				tmp /=N;
			}
			return tmp;
		}
		~eprod(){if (intermpdfs) {for (int i=0;i<n;i++){delete mpdfs(i);}}};
};

/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal type

*/
class mmix : public mpdf {
	protected:
		//! Component (epdfs)
		Array<mpdf*> Coms;
		//!Internal epdf
		emix Epdf;
	public:
		//!Default constructor
		mmix(RV &rv, RV &rvc) : mpdf(rv, rvc), Epdf(rv) {ep = &Epdf;};
		//! Set weights \c w and components \c R
		void set_parameters(const vec &w, const Array<mpdf*> &Coms) {
			Array<epdf*> Eps(Coms.length());

			for (int i = 0;i < Coms.length();i++) {
				Eps(i) = & (Coms(i)->_epdf());
			}
			Epdf.set_parameters(w, Eps);
		};

		void condition(const vec &cond) {
			for (int i = 0;i < Coms.length();i++) {Coms(i)->condition(cond);}
		};
};
#endif //MX_H
