/*!
  \file
  \brief Probability distributions for Mixtures of pdfs
  \author Vaclav Smidl.

  -----------------------------------
  BDM++ - C++ library for Bayesian Decision Making under Uncertainty

  Using IT++ for numerical operations
  -----------------------------------
*/

#ifndef EMIX_H
#define EMIX_H

#define LOG2  0.69314718055995

#include "../shared_ptr.h"
#include "exp_family.h"

namespace bdm {

//this comes first because it is used inside emix!

/*! \brief Class representing ratio of two densities

which arise e.g. by applying the Bayes rule.
It represents density in the form:
\f[
f(rv|rvc) = \frac{f(rv,rvc)}{f(rvc)}
\f]
where \f$ f(rvc) = \int f(rv,rvc) d\ rv \f$.

In particular this type of arise by conditioning of a mixture model.

At present the only supported operation is evallogcond().
 */
class mratio: public pdf {
protected:
    //! Nominator in the form of pdf
    const epdf* nom;

    //!Denominator in the form of epdf
    shared_ptr<epdf> den;

    //!flag for destructor
    bool destroynom;
    //!datalink between conditional and nom
    datalink_m2e dl;
public:
    //!Default constructor. By default, the given epdf is not copied!
    //! It is assumed that this function will be used only temporarily.
    mratio ( const epdf* nom0, const RV &rv, bool copy = false ) : pdf ( ), dl ( ) {
        // adjust rv and rvc

        set_rv ( rv );
        dim = rv._dsize();

        rvc = nom0->_rv().subt ( rv );
        dimc = rvc._dsize();

        //prepare data structures
        if ( copy ) {
            bdm_error ( "todo" );
            // destroynom = true;
        } else {
            nom = nom0;
            destroynom = false;
        }
        bdm_assert_debug ( rvc.length() > 0, "Makes no sense to use this object!" );

        // build denominator
        den = nom->marginal ( rvc );
        dl.set_connection ( rv, rvc, nom0->_rv() );
    };

    double evallogcond ( const vec &val, const vec &cond ) {
        double tmp;
        vec nom_val ( dimension() + dimc );
        dl.pushup_cond ( nom_val, val, cond );
        tmp = exp ( nom->evallog ( nom_val ) - den->evallog ( cond ) );
        return tmp;
    }

    //! Returns a sample from the density conditioned on \c cond, \f$x \sim epdf(rv|cond)\f$. \param cond is numeric value of \c rv
    virtual vec samplecond ( const vec &cond ) NOT_IMPLEMENTED(0);

    //! Object takes ownership of nom and will destroy it
    void ownnom() {
        destroynom = true;
    }
    //! Default destructor
    ~mratio() {
        if ( destroynom ) {
            delete nom;
        }
    }


private:
    // not implemented
    mratio ( const mratio & );
    mratio &operator= ( const mratio & );
};

class emix; //forward

//! \brief Base class (interface) for mixtures
class emix_base : public epdf {
protected:
    //! reference to vector of weights
    vec &w;
    //! function returning ith component
    virtual const epdf * component(const int &i) const=0;

    virtual int no_coms() const=0;

public:

    emix_base(vec &w0): w(w0) {}

    void validate ();

    vec sample() const;

    vec mean() const;

    vec variance() const;

    double evallog ( const vec &val ) const;

    vec evallog_mat ( const mat &Val ) const;

    //! Auxiliary function that returns pdflog for each component
    mat evallog_coms ( const mat &Val ) const;

    shared_ptr<epdf> marginal ( const RV &rv ) const;
    //! Update already existing marginal density  \c target
    void marginal ( const RV &rv, emix &target ) const;
    shared_ptr<pdf> condition ( const RV &rv ) const;

    //Access methods
    //! returns a reference to the internal weights. Use with Care!
    vec& _w() {
        return w;
    }

    const vec& _w() const {
        return w;
    }
    //!access
    const epdf* _com(int i) const {
        return component(i);
    }

};

/*!
* \brief Mixture of epdfs

Density function:
\f[
f(x) = \sum_{i=1}^{n} w_{i} f_i(x), \quad \sum_{i=1}^n w_i = 1.
\f]
where \f$f_i(x)\f$ is any density on random variable \f$x\f$, called \a component,

*/
class emix : public emix_base {
protected:
    //! weights of the components
    vec weights;

    //! Component (epdfs)
    Array<shared_ptr<epdf> > Coms;

public:
    //! Default constructor
    emix ( ) : emix_base ( weights) { }

    const epdf* component(const int &i) const {
        return Coms(i).get();
    }
    void validate();


    int no_coms() const {
        return Coms.length();
    }

    /*! Create object from the following structure

    \code
    class = 'emix';

    pdfs = { list of any bdm::pdf offsprings };   % pdfs in the mixture, bdm::pdf::from_setting
    weights = [... ];                             % vector of weights of pdfs in the mixture
    --- inherited fields ---
    bdm::emix_base::from_setting
    \endcode

    */
    void from_setting ( const Setting &set );

    void to_setting  (Setting  &set) const;

    void set_rv ( const RV &rv ) {
        epdf::set_rv ( rv );
        for ( int i = 0; i < no_coms(); i++ ) {
            Coms( i )->set_rv ( rv );
        }
    }

    Array<shared_ptr<epdf> >& _Coms ( ) {
        return Coms;
    }
};
SHAREDPTR ( emix );
UIREGISTER ( emix );


/*! \brief Chain rule decomposition of epdf

Probability density in the form of Chain-rule decomposition:
\[
f(x_1,x_2,x_3) = f(x_1|x_2,x_3)f(x_2,x_3)f(x_3)
\]
Note that
*/
class mprod: public pdf {
private:
    Array<shared_ptr<pdf> > pdfs;

    //! Data link for each pdfs
    Array<shared_ptr<datalink_m2m> > dls;

public:
    //! \brief Default constructor
    mprod() { }

    /*!\brief Constructor from list of mFacs
    */
    mprod ( const Array<shared_ptr<pdf> > &mFacs ) {
        set_elements ( mFacs );
    }
    //! Set internal \c pdfs from given values
    void set_elements ( const Array<shared_ptr<pdf> > &mFacs );

    double evallogcond ( const vec &val, const vec &cond );

    vec evallogcond_mat ( const mat &Dt, const vec &cond );

    vec evallogcond_mat ( const Array<vec> &Dt, const vec &cond );

    //TODO smarter...
    vec samplecond ( const vec &cond ) {
        //! Ugly hack to help to discover if mpfs are not in proper order. Correct solution = check that explicitely.
        vec smp = std::numeric_limits<double>::infinity() * ones ( dimension() );
        vec smpi;
        // Hard assumption here!!! We are going backwards, to assure that samples that are needed from smp are already generated!
        for ( int i = ( pdfs.length() - 1 ); i >= 0; i-- ) {
            // generate contribution of this pdf
            smpi = pdfs ( i )->samplecond ( dls ( i )->get_cond ( smp , cond ) );
            // copy contribution of this pdf into smp
            dls ( i )->pushup ( smp, smpi );
        }
        return smp;
    }

    //! Create object from the following structure
    //! \code
    //! class='mprod';
    //! pdfs = { list of bdm::pdf };          % list of bdm::pdf offsprings stored in the order of chain rule, bdm::pdf::from_setting
    //! --- inherited fields ---
    //! bdm::pdf::from_setting
    //! \endcode
    //!@}
    void from_setting ( const Setting &set ) ;
    void to_setting  (Setting  &set) const;    
};
UIREGISTER ( mprod );
SHAREDPTR ( mprod );


//! \brief Base class (interface) for bdm::eprod
class eprod_base: public epdf {
protected:
    //! Array of indices
    Array<datalink*> dls;
    //! interface for a factor
public:
    virtual const epdf* factor(int i) const NOT_IMPLEMENTED(NULL);
    //!number of factors
    virtual const int no_factors() const NOT_IMPLEMENTED(0);
    //! Default constructor
	eprod_base () :  dls (0) {};
	eprod_base (const eprod_base &ep0) :  dls (ep0.dls) {};
	//! Set internal
    vec mean() const;

    vec variance() const;

    vec sample() const;

    double evallog ( const vec &val ) const;

    //!Destructor
    ~eprod_base() {
        for ( int i = 0; i < dls.length(); i++ ) {
            delete dls ( i );
        }
    }
    void validate() {
        epdf::validate();
        dls.set_length ( no_factors() );

        bool independent = true;
        dim = 0;
		rv = RV();
        for ( int i = 0; i < no_factors(); i++ ) {
            independent = rv.add ( factor ( i )->_rv() );
            dim += factor ( i )->dimension();
            bdm_assert_debug ( independent, "eprod:: given components are not independent." );
        };

        //
        int cumdim = 0;
        int dimi = 0;
        int i;
        for ( i = 0; i < no_factors(); i++ ) {
			if (!dls(i)){
				dls ( i ) = new datalink;
			}
            if ( isnamed() ) { // rvs are complete
                dls ( i )->set_connection ( factor ( i )->_rv() , rv );
            } else { //rvs are not reliable
                dimi = factor ( i )->dimension();
                dls ( i )->set_connection ( dimi, dim, linspace ( cumdim, cumdim + dimi - 1 ) );
                cumdim += dimi;
            }
        }

    }
};

//! \brief Product of independent epdfs. For dependent pdfs, use bdm::mprod.
class eprod: public eprod_base {
protected:
    Array<shared_ptr<epdf> > factors;
public:
    const epdf* factor(int i) const {
        return factors(i).get();
    }
    const int no_factors() const {
        return factors.length();
    }
    void set_parameters ( const Array<shared_ptr<epdf> > &epdfs0) {
        factors = epdfs0;
    }

    /*! Create object from the following structure

    \code
    class = 'eprod';
    pdfs = { list of any bdm::epdf offsprings };   % pdfs in the product, bdm::epdf::from_setting
    --- inherited fields ---
    bdm::eprod_base::from_setting
    \endcode

    */
    void from_setting(const Setting &set) {
        UI::get(factors,set,"pdfs",UI::compulsory);
    }
};
UIREGISTER(eprod);

//! \brief Internal class similar to eprod - factors are external pointers. To be used internally!  
class eprod_internal: public eprod_base {
protected:
    Array<epdf* > factors;
    const epdf* factor(int i) const {
        return factors(i);
    }
    const int no_factors() const {
        return factors.length();
    }
public:
    void set_parameters ( const Array<epdf *> &epdfs0) {
        factors = epdfs0;
    }
};

/*! \brief Mixture of pdfs with constant weights, all pdfs are of equal RV and RVC

*/
class mmix : public pdf {
protected:
    //! Component (pdfs)
    Array<shared_ptr<pdf> > Coms;
    //!weights of the components
    vec w;
public:
    //!Default constructor
    mmix() : Coms ( 0 ) { };

    double evallogcond ( const vec &dt, const vec &cond ) {
        double ll = 0.0;
        for ( int i = 0; i < Coms.length(); i++ ) {
            ll += Coms ( i )->evallogcond ( dt, cond );
        }
        return ll;
    }

    vec samplecond ( const vec &cond );

    //! }
    //! \endcode
    //!@}
   
    /*! Create object from the following structure
    \code
    class = 'mmix';
    pdfs = { list of components bdm::pdf };          % list of pdf offsprings, bdm::pdf::from_setting
    --- optional fields ---
    weights = [...];                                 % weights of pdfs in the mixture
    --- inherited fields ---
    bdm::pdf::from_setting
    \endcode
    \endcode
    If the optional fields are not given, they will be filled as follows:
    \code
    weights = 1/n * [1,1,1,...];
    \endcode
    */
    void from_setting ( const Setting &set );
    void     to_setting  (Setting  &set) const;
    virtual void validate();
};
SHAREDPTR ( mmix );
UIREGISTER ( mmix );


//! \brief Base class for all BM running as parallel update of internal BMs
class ProdBMBase : public BM {
protected :
    Array<vec_from_vec> bm_yt;
    Array<vec_from_2vec> bm_cond;

    //! \brief Internal class
    class eprod_bm : public eprod_base {
        ProdBMBase & pb;
    public :
		eprod_bm(ProdBMBase &pb0): pb(pb0) {}
		const epdf* factor(int i ) const {
            return &(pb.bm(i)->posterior());
        }
        const int no_factors() const {
            return pb.no_bms();
        }
    } est;
public:
	ProdBMBase():BM(),est(*this) {}
	ProdBMBase(const ProdBMBase &p0):BM(p0),bm_yt(p0.bm_yt),bm_cond(p0.bm_cond),est(*this) {
		est.validate();
	}
	virtual BM* bm(int i) NOT_IMPLEMENTED(NULL);
    virtual int no_bms() const {
        return 0;
    }
    const epdf& posterior() const {
        return est;
    }
    void set_prior(const epdf *pri) {
        const eprod_base* ep=dynamic_cast<const eprod_base*>(pri);
        if (ep) {
            bdm_assert(ep->no_factors()!=no_bms() , "Given prior has "+ num2str(ep->no_factors()) + " while this ProdBM has "+
                       num2str(no_bms()) + "BMs");
            for (int i=0; i<no_bms(); i++) {
                bm(i)->set_prior(ep->factor(i));
            }
        }
    }

    void validate() {
        est.validate();
        BM::validate();
        // set links
        bm_yt.set_length(no_bms());
        bm_cond.set_length(no_bms());

        //

        for (int i=0; i<no_bms(); i++) {
            yrv.add(bm(i)->_yrv());
            rvc.add(bm(i)->_rvc());
        }
        rvc=rvc.subt(yrv);

        dimy = yrv._dsize();
        dimc = rvc._dsize();

        for (int i=0; i<no_bms(); i++) {
			bm_yt(i).set_length(bm(i)->dimensiony());
            bm_yt(i).connect(bm(i)->_yrv(), yrv);
			bm_cond(i).set_length(bm(i)->dimensionc());
			bm_cond(i).connect(bm(i)->_rvc(), yrv, rvc);
        }
    }
    void bayes(const vec &dt, const vec &cond) {
        ll=0;
        for(int i=0; i<no_bms(); i++) {
            bm_yt(i).update(dt);
            bm_cond(i).update(dt,cond);
            bm(i)->bayes(bm_yt(i), bm_cond(i));
			ll+=bm(i)->_ll();
        }
    }
    vec samplepred( const vec &cond) {
		vec samp=zeros(dimy);
		
		for(int i=0; i<no_bms(); i++) {
			bm_cond(i).update(samp,cond);
			vec yi=bm(i)->samplepred(bm_cond(i));
			bm_yt(i)._dl().pushup(samp,yi);
		}
		return samp;
	}
	
};

class ProdBM: public ProdBMBase {
protected:
    Array<shared_ptr<BM> > BMs;
public:
	ProdBM():ProdBMBase(),BMs(){};
	ProdBM(const ProdBM &p0):ProdBMBase(p0),BMs(p0.BMs){est.validate();};
    ProdBM* _copy() const {return new ProdBM(*this);}
    virtual BM* bm(int i) {
        return BMs(i).get();
    }
    virtual int no_bms() const {
        return BMs.length();
    }
    void from_setting(const Setting &set) {
        BM::from_setting(set);
        UI::get(BMs,set,"BMs");
    }
    void to_setting(Setting &set) const {
        BM::to_setting(set);
        UI::save(BMs,set,"BMs");
    }
};
UIREGISTER(ProdBM);

//! \brief class for on-line model selection
class ModelComparator: public ProdBM {
protected:
    multiBM weights;
public:
    void bayes(const vec &yt, const vec &cond) {
        vec w_nn(no_bms());
        for (int i=0; i<no_bms(); i++) {
            bm(i)->bayes(yt,cond);
            w_nn(i) += bm(i)->_ll();
        }
        vec w=exp(w_nn-max(w_nn));
        weights.bayes(w/sum(w));
    }
    void validate() {
        ProdBM::validate();
        weights.validate();
    }

   /*! Create object from the following structure

    \code
    class = 'ModelComparator';
    --- optional fields ---
    frg = [...];                  % vector of weights 
    --- inherited fields ---
    bdm::ProdBM::from_setting
    \endcode
    */
    void from_setting(const Setting& set) {
        ProdBM::from_setting(set);
        UI::get(weights.frg, set, "frg",UI::optional);
    }

    void to_setting(Setting& set) const {
        ProdBM::to_setting(set);
        UI::save(weights.frg, set, "frg");
    }
};

}
#endif //MX_H
