/*!
  \file
  \brief Probability distributions for Mixtures of pdfs
  \author Vaclav Smidl.

  -----------------------------------
  BDM++ - C++ library for Bayesian Decision Making under Uncertainty

  Using IT++ for numerical operations
  -----------------------------------
*/

#ifndef EMIX_H
#define EMIX_H

#define LOG2  0.69314718055995

#include "../shared_ptr.h"
#include "exp_family.h"

namespace bdm {

//this comes first because it is used inside emix!

/*! \brief Class representing ratio of two densities
which arise e.g. by applying the Bayes rule.
It represents density in the form:
\f[
f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
\f]
where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.

In particular this type of arise by conditioning of a mixture model.

At present the only supported operation is evallogcond().
 */
class mratio: public mpdf {
protected:
	//! Nominator in the form of mpdf
	const epdf* nom;

	//!Denominator in the form of epdf
	shared_ptr<epdf> den;

	//!flag for destructor
	bool destroynom;
	//!datalink between conditional and nom
	datalink_m2e dl;
	//! dummy epdf that stores only rv and dim
	epdf iepdf;
public:
	//!Default constructor. By default, the given epdf is not copied!
	//! It is assumed that this function will be used only temporarily.
	mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : mpdf ( ), dl ( ),iepdf() {
		// adjust rv and rvc
		rvc = nom0->_rv().subt ( rv );
		dimc = rvc._dsize();
		set_ep ( iepdf );
		iepdf.set_parameters ( rv._dsize() );
		iepdf.set_rv ( rv );

		//prepare data structures
		if ( copy ) {
			it_error ( "todo" );
			destroynom = true;
		} else {
			nom = nom0;
			destroynom = false;
		}
		it_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );

		// build denominator
		den = nom->marginal ( rvc );
		dl.set_connection ( rv, rvc, nom0->_rv() );
	};
	double evallogcond ( const vec &val, const vec &cond ) {
		double tmp;
		vec nom_val ( dimension() + dimc );
		dl.pushup_cond ( nom_val, val, cond );
		tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
		return tmp;
	}
	//! Object takes ownership of nom and will destroy it
	void ownnom() {
		destroynom = true;
	}
	//! Default destructor
	~mratio() {
		if ( destroynom ) {
			delete nom;
		}
	}
};

/*!
* \brief Mixture of epdfs

Density function:
\f[
f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
\f]
where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,

*/
class emix : public epdf {
protected:
	//! weights of the components
	vec w;
	//! Component (epdfs)
	Array<shared_ptr<epdf> > Coms;

public:
	//!Default constructor
	emix ( ) : epdf ( ) {};
	//! Set weights \c w and components \c Coms
	//!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
	void set_parameters ( const vec &w, const Array<shared_ptr<epdf> > &Coms );

	vec sample() const;
	vec mean() const {
		int i;
		vec mu = zeros ( dim );
		for ( i = 0; i < w.length(); i++ ) {
			mu += w ( i ) * Coms ( i )->mean();
		}
		return mu;
	}
	vec variance() const {
		//non-central moment
		vec mom2 = zeros ( dim );
		for ( int i = 0; i < w.length(); i++ ) {
			mom2 += w ( i ) * ( Coms ( i )->variance() + pow ( Coms ( i )->mean(), 2 ) );
		}
		//central moment
		return mom2 - pow ( mean(), 2 );
	}
	double evallog ( const vec &val ) const {
		int i;
		double sum = 0.0;
		for ( i = 0; i < w.length(); i++ ) {
			sum += w ( i ) * exp ( Coms ( i )->evallog ( val ) );
		}
		if ( sum == 0.0 ) {
			sum = std::numeric_limits<double>::epsilon();
		}
		double tmp = log ( sum );
		it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
		return tmp;
	};
	vec evallog_m ( const mat &Val ) const {
		vec x = zeros ( Val.cols() );
		for ( int i = 0; i < w.length(); i++ ) {
			x += w ( i ) * exp ( Coms ( i )->evallog_m ( Val ) );
		}
		return log ( x );
	};
	//! Auxiliary function that returns pdflog for each component
	mat evallog_M ( const mat &Val ) const {
		mat X ( w.length(), Val.cols() );
		for ( int i = 0; i < w.length(); i++ ) {
			X.set_row ( i, w ( i ) *exp ( Coms ( i )->evallog_m ( Val ) ) );
		}
		return X;
	};

	shared_ptr<epdf> marginal ( const RV &rv ) const;
	//! Update already existing marginal density  \c target
	void marginal ( const RV &rv, emix &target ) const;
	shared_ptr<mpdf> condition ( const RV &rv ) const;

//Access methods
	//! returns a pointer to the internal mean value. Use with Care!
	vec& _w() {
		return w;
	}

	//!access function
	shared_ptr<epdf> _Coms ( int i ) {
		return Coms ( i );
	}

	void set_rv ( const RV &rv ) {
		epdf::set_rv ( rv );
		for ( int i = 0; i < Coms.length(); i++ ) {
			Coms ( i )->set_rv ( rv );
		}
	}
};
SHAREDPTR( emix );

/*!
* \brief Mixture of egiws

*/
class egiwmix : public egiw {
protected:
	//! weights of the components
	vec w;
	//! Component (epdfs)
	Array<egiw*> Coms;
	//!Flag if owning Coms
	bool destroyComs;
public:
	//!Default constructor
	egiwmix ( ) : egiw ( ) {};

	//! Set weights \c w and components \c Coms
	//!By default Coms are copied inside. Parameter \c copy can be set to false if Coms live externally. Use method ownComs() if Coms should be destroyed by the destructor.
	void set_parameters ( const vec &w, const Array<egiw*> &Coms, bool copy = false );

	//!return expected value
	vec mean() const;

	//!return a sample from the density
	vec sample() const;

	//!return the expected variance
	vec variance() const;

	// TODO!!! Defined to follow ANSI and/or for future development
	void mean_mat ( mat &M, mat&R ) const {};
	double evallog_nn ( const vec &val ) const {
		return 0;
	};
	double lognc () const {
		return 0;
	}

	shared_ptr<epdf> marginal ( const RV &rv ) const;
	void marginal ( const RV &rv, emix &target ) const;

//Access methods
	//! returns a pointer to the internal mean value. Use with Care!
	vec& _w() {
		return w;
	}
	virtual ~egiwmix() {
		if ( destroyComs ) {
			for ( int i = 0; i < Coms.length(); i++ ) {
				delete Coms ( i );
			}
		}
	}
	//! Auxiliary function for taking ownership of the Coms()
	void ownComs() {
		destroyComs = true;
	}

	//!access function
	egiw* _Coms ( int i ) {
		return Coms ( i );
	}

	void set_rv ( const RV &rv ) {
		egiw::set_rv ( rv );
		for ( int i = 0; i < Coms.length(); i++ ) {
			Coms ( i )->set_rv ( rv );
		}
	}

	//! Approximation of a GiW mix by a single GiW pdf
	egiw* approx();
};

/*! \brief Chain rule decomposition of epdf

Probability density in the form of Chain-rule decomposition:
\[
f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
\]
Note that
*/
class mprod: public mpdf {
private:
	Array<shared_ptr<mpdf> > mpdfs;

	//! Data link for each mpdfs
	Array<shared_ptr<datalink_m2m> > dls;

protected:
	//! dummy epdf used only as storage for RV and dim
	epdf iepdf;

public:
	//! \brief Default constructor
	mprod() { }

	/*!\brief Constructor from list of mFacs
	*/
	mprod ( const Array<shared_ptr<mpdf> > &mFacs ) {
		set_elements ( mFacs );
	}
	//! Set internal \c mpdfs from given values
	void set_elements (const Array<shared_ptr<mpdf> > &mFacs );

	double evallogcond ( const vec &val, const vec &cond ) {
		int i;
		double res = 0.0;
		for ( i = mpdfs.length() - 1; i >= 0; i-- ) {
			/*			if ( mpdfs(i)->_rvc().count() >0) {
							mpdfs ( i )->condition ( dls ( i )->get_cond ( val,cond ) );
						}
						// add logarithms
						res += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );*/
			res += mpdfs ( i )->evallogcond (
			           dls ( i )->pushdown ( val ),
			           dls ( i )->get_cond ( val, cond )
			       );
		}
		return res;
	}
	vec evallogcond_m ( const mat &Dt, const vec &cond ) {
		vec tmp ( Dt.cols() );
		for ( int i = 0; i < Dt.cols(); i++ ) {
			tmp ( i ) = evallogcond ( Dt.get_col ( i ), cond );
		}
		return tmp;
	};
	vec evallogcond_m ( const Array<vec> &Dt, const vec &cond ) {
		vec tmp ( Dt.length() );
		for ( int i = 0; i < Dt.length(); i++ ) {
			tmp ( i ) = evallogcond ( Dt ( i ), cond );
		}
		return tmp;
	};


	//TODO smarter...
	vec samplecond ( const vec &cond ) {
		//! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
		vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
		vec smpi;
		// Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
		for ( int i = ( mpdfs.length() - 1 ); i >= 0; i-- ) {
			// generate contribution of this mpdf
			smpi = mpdfs(i)->samplecond(dls ( i )->get_cond ( smp , cond ));			
			// copy contribution of this pdf into smp
			dls ( i )->pushup ( smp, smpi );
		}
		return smp;
	}

	//! Load from structure with elements:
	//!  \code
	//! { class='mprod';
	//!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule
	//! }
	//! \endcode
	//!@}
	void from_setting ( const Setting &set ) {
		Array<shared_ptr<mpdf> > atmp; //temporary Array
		UI::get ( atmp, set, "mpdfs", UI::compulsory );
		set_elements ( atmp );
	}

private:
	// not implemented
	mprod ( const mprod & );
	mprod &operator=( const mprod & );
};
UIREGISTER ( mprod );
SHAREDPTR ( mprod );

//! Product of independent epdfs. For dependent pdfs, use mprod.
class eprod: public epdf {
protected:
	//! Components (epdfs)
	Array<const epdf*> epdfs;
	//! Array of indeces
	Array<datalink*> dls;
public:
	//! Default constructor
	eprod () : epdfs ( 0 ), dls ( 0 ) {};
	//! Set internal 
	void set_parameters ( const Array<const epdf*> &epdfs0, bool named = true ) {
		epdfs = epdfs0;//.set_length ( epdfs0.length() );
		dls.set_length ( epdfs.length() );

		bool independent = true;
		if ( named ) {
			for ( int i = 0; i < epdfs.length(); i++ ) {
				independent = rv.add ( epdfs ( i )->_rv() );
				it_assert_debug ( independent == true, "eprod:: given components are not independent." );
			}
			dim = rv._dsize();
		} else {
			dim = 0;
			for ( int i = 0; i < epdfs.length(); i++ ) {
				dim += epdfs ( i )->dimension();
			}
		}
		//
		int cumdim = 0;
		int dimi = 0;
		int i;
		for ( i = 0; i < epdfs.length(); i++ ) {
			dls ( i ) = new datalink;
			if ( named ) {
				dls ( i )->set_connection ( epdfs ( i )->_rv() , rv );
			} else {
				dimi = epdfs ( i )->dimension();
				dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
				cumdim += dimi;
			}
		}
	}

	vec mean() const {
		vec tmp ( dim );
		for ( int i = 0; i < epdfs.length(); i++ ) {
			vec pom = epdfs ( i )->mean();
			dls ( i )->pushup ( tmp, pom );
		}
		return tmp;
	}
	vec variance() const {
		vec tmp ( dim ); //second moment
		for ( int i = 0; i < epdfs.length(); i++ ) {
			vec pom = epdfs ( i )->mean();
			dls ( i )->pushup ( tmp, pow ( pom, 2 ) );
		}
		return tmp - pow ( mean(), 2 );
	}
	vec sample() const {
		vec tmp ( dim );
		for ( int i = 0; i < epdfs.length(); i++ ) {
			vec pom = epdfs ( i )->sample();
			dls ( i )->pushup ( tmp, pom );
		}
		return tmp;
	}
	double evallog ( const vec &val ) const {
		double tmp = 0;
		for ( int i = 0; i < epdfs.length(); i++ ) {
			tmp += epdfs ( i )->evallog ( dls ( i )->pushdown ( val ) );
		}
		it_assert_debug ( std::isfinite ( tmp ), "Infinite" );
		return tmp;
	}
	//!access function
	const epdf* operator () ( int i ) const {
		it_assert_debug ( i < epdfs.length(), "wrong index" );
		return epdfs ( i );
	}

	//!Destructor
	~eprod() {
		for ( int i = 0; i < epdfs.length(); i++ ) {
			delete dls ( i );
		}
	}
};


/*! \brief Mixture of mpdfs with constant weights, all mpdfs are of equal RV and RVC

*/
class mmix : public mpdf {
protected:
	//! Component (mpdfs)
	Array<shared_ptr<mpdf> > Coms;
	//!weights of the components
	vec w;
	//! dummy epdfs
	epdf dummy_epdf;
public:
	//!Default constructor
	mmix() : Coms(0), dummy_epdf() { set_ep(dummy_epdf);	}

	//! Set weights \c w and components \c R
	void set_parameters ( const vec &w0, const Array<shared_ptr<mpdf> > &Coms0 ) {
		//!\todo check if all components are OK
		Coms = Coms0;
		w=w0;	

		if (Coms0.length()>0){
			set_rv(Coms(0)->_rv());
			dummy_epdf.set_parameters(Coms(0)->_rv()._dsize());
			set_rvc(Coms(0)->_rvc());
			dimc = rvc._dsize();
		}
	}
	double evallogcond (const vec &dt, const vec &cond) {
		double ll=0.0;
		for (int i=0;i<Coms.length();i++){
			ll+=Coms(i)->evallogcond(dt,cond);
		}
		return ll;
	}

	vec samplecond (const vec &cond);

};

}
#endif //MX_H
