/*!
  \file
  \brief Bayesian Filtering using stochastic sampling (Particle Filters)
  \author Vaclav Smidl.

  -----------------------------------
  BDM++ - C++ library for Bayesian Decision Making under Uncertainty

  Using IT++ for numerical operations
  -----------------------------------
*/

#ifndef PARTICLES_H
#define PARTICLES_H


#include "../estim/arx_ext.h"
#include "../stat/emix.h"

namespace bdm {

//! \brief Abstract class for Marginalized Particles
class MarginalizedParticleBase : public BM {
protected:
    //! discrte particle
    dirac est_emp;
    //! internal Bayes Model
    shared_ptr<BM> bm;

    //! \brief Internal class for custom posterior - product of empirical and exact part
    class eprod_2:public eprod_base {
    protected:
        MarginalizedParticleBase &mp;
    public:
        eprod_2(MarginalizedParticleBase &m):mp(m) {}
        const epdf* factor(int i) const {
            return (i==0) ? &mp.bm->posterior() : &mp.est_emp;
        }
        const int no_factors() const {
            return 2;
        }
    } est;

public:
    MarginalizedParticleBase():est(*this) {};
    MarginalizedParticleBase(const MarginalizedParticleBase &m2):BM(m2),est(*this) {
        bm = m2.bm->_copy();
        est_emp = m2.est_emp;
        est.validate();
        validate();
    };
    void bayes(const vec &dt, const vec &cond) NOT_IMPLEMENTED_VOID;

    const eprod_2& posterior() const {
        return est;
    }

    void set_prior(const epdf *pdf0) {
        const eprod *ep=dynamic_cast<const eprod*>(pdf0);
        if (ep) { // full prior
            bdm_assert(ep->no_factors()==2,"Incompatible prod");
            bm->set_prior(ep->factor(0));
            est_emp.set_point(ep->factor(1)->sample());
        } else {
            // assume prior is only for emp;
            est_emp.set_point(pdf0->sample());
        }
    }


    /*! Create object from the following structure

    \code
    class = "MarginalizedParticleBase";
    bm = configuration of bdm::BM;          % any offspring of BM, bdm::BM::from_setting
    --- inherited fields ---
    bdm::BM::from_setting
    \endcode
    */
    void from_setting(const Setting &set) {
        BM::from_setting ( set );
        bm = UI::build<BM> ( set, "bm", UI::compulsory );
    }
    void validate() {
        BM::validate();
        //est.validate(); --pdfs not known
        bdm_assert(bm,"Internal BM is not given");
    }
};

//! \brief Particle with marginalized subspace, used in PF 
class MarginalizedParticle : public MarginalizedParticleBase {
protected:
    //! pdf with for transitional par
    shared_ptr<pdf> par; // pdf for non-linear part
    //! link from this to bm
    shared_ptr<datalink_part> cond2bm;
    //! link from cond to par
    shared_ptr<datalink_part> cond2par;
    //! link from emp 2 par
    shared_ptr<datalink_part> emp2bm;
    //! link from emp 2 par
    shared_ptr<datalink_part> emp2par;

public:
    BM* _copy() const {
        return new MarginalizedParticle(*this);
    };
    void bayes(const vec &dt, const vec &cond) {
        vec par_cond(par->dimensionc());
        cond2par->filldown(cond,par_cond); // copy ut
        emp2par->filldown(est_emp._point(),par_cond); // copy xt-1

        //sample new particle
        est_emp.set_point(par->samplecond(par_cond));
        //if (evalll)
        vec bm_cond(bm->dimensionc());
        cond2bm->filldown(cond, bm_cond);// set e.g. ut
        emp2bm->filldown(est_emp._point(), bm_cond);// set e.g. ut
        bm->bayes(dt, bm_cond);
        ll=bm->_ll();
    }

    /*! Create object from the following structure

    \code
    class = "MarginalizedParticle";
    parameter_pdf = configuration of bdm::epdf;          % any offspring of epdf, bdm::epdf::from_setting
    --- inherited fields ---
    bdm::MarginalizedParticleBase::from_setting
    \endcode
    */    
    void from_setting(const Setting &set) {
        MarginalizedParticleBase::from_setting ( set );
        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
    }

    void to_setting(Setting &set) {
        MarginalizedParticleBase::to_setting(set);
        UI::save(par,set,"parameter_pdf");
    }
    void validate() {
        MarginalizedParticleBase::validate();
        est_emp.set_rv(par->_rv());
        if (est_emp.point.length()!=par->dimension())
            est_emp.set_point(zeros(par->dimension()));
        est.validate();

        yrv = bm->_yrv();
        dimy = bm->dimensiony();
        set_rv( concat(bm->_rv(), par->_rv()));
        set_dim( par->dimension()+bm->dimension());

        rvc = par->_rvc();
        rvc.add(bm->_rvc());
        rvc=rvc.subt(par->_rv());
        rvc=rvc.subt(par->_rv().copy_t(-1));
        rvc=rvc.subt(bm->_rv().copy_t(-1)); //

        cond2bm=new datalink_part;
        cond2par=new datalink_part;
        emp2bm  =new datalink_part;
        emp2par =new datalink_part;
        cond2bm->set_connection(bm->_rvc(), rvc);
        cond2par->set_connection(par->_rvc(), rvc);
        emp2bm->set_connection(bm->_rvc(), par->_rv());
        emp2par->set_connection(par->_rvc(), par->_rv().copy_t(-1));

        dimc = rvc._dsize();
    };
};
UIREGISTER(MarginalizedParticle);

//! Internal class which is used in PF
class BootstrapParticle : public BM {
    dirac est;
    shared_ptr<pdf> par;
    shared_ptr<pdf> obs;
    shared_ptr<datalink_part> cond2par;
    shared_ptr<datalink_part> cond2obs;
    shared_ptr<datalink_part> xt2obs;
    shared_ptr<datalink_part> xtm2par;
public:
    BM* _copy() const {
        return new BootstrapParticle(*this);
    };
    void bayes(const vec &dt, const vec &cond) {
        vec par_cond(par->dimensionc());
        cond2par->filldown(cond,par_cond); // copy ut
        xtm2par->filldown(est._point(),par_cond); // copy xt-1

        //sample new particle
        est.set_point(par->samplecond(par_cond));
        //if (evalll)
        vec obs_cond(obs->dimensionc());
        cond2obs->filldown(cond, obs_cond);// set e.g. ut
        xt2obs->filldown(est._point(), obs_cond);// set e.g. ut
        ll=obs->evallogcond(dt,obs_cond);
    }
    const dirac& posterior() const {
        return est;
    }

    void set_prior(const epdf *pdf0) {
        est.set_point(pdf0->sample());
    }

    /*! Create object from the following structure
    \code
    class = "BootstrapParticle";
    parameter_pdf = configuration of bdm::epdf;      % any offspring of epdf, bdm::epdf::from_setting
    observation_pdf = configuration of bdm::epdf;    % any offspring of epdf, bdm::epdf::from_setting
    --- inherited fields ---
    bdm::BM::from_setting
    \endcode
    */
    void from_setting(const Setting &set) {
        BM::from_setting ( set );
        par = UI::build<pdf> ( set, "parameter_pdf", UI::compulsory );
        obs = UI::build<pdf> ( set, "observation_pdf", UI::compulsory );
    }

    void validate() {
        yrv = obs->_rv();
        dimy = obs->dimension();
        set_rv( par->_rv());
        set_dim( par->dimension());

        rvc = par->_rvc().subt(par->_rv().copy_t(-1));
        rvc.add(obs->_rvc()); //

        cond2obs=new datalink_part;
        cond2par=new datalink_part;
        xt2obs  =new datalink_part;
        xtm2par =new datalink_part;
        cond2obs->set_connection(obs->_rvc(), rvc);
        cond2par->set_connection(par->_rvc(), rvc);
        xt2obs->set_connection(obs->_rvc(), _rv());
        xtm2par->set_connection(par->_rvc(), _rv().copy_t(-1));

        dimc = rvc._dsize();
    };
};
UIREGISTER(BootstrapParticle);


/*!
* \brief Trivial particle filter with proposal density equal to parameter evolution model.

Posterior density is represented by a weighted empirical density (\c eEmp ).
*/

class PF : public BM {
    //! \var log_level_enums logweights
    //! all weightes will be logged

    //! \var log_level_enums logmeans
    //! means of particles will be logged
    LOG_LEVEL(PF,logweights,logmeans,logvars);

    class pf_mix: public emix_base {
        Array<BM*> &bms;
    public:
        pf_mix(vec &w0, Array<BM*> &bms0):emix_base(w0),bms(bms0) {}
        const epdf* component(const int &i)const {
            return &(bms(i)->posterior());
        }
        int no_coms() const {
            return bms.length();
        }
    };
protected:
    //!number of particles;
    int n;
    //!posterior density
    pf_mix est;
    //! weights;
    vec w;
    //! particles
    Array<BM*> particles;
    //! internal structure storing loglikelihood of predictions
    vec lls;

    //! which resampling method will be used
    RESAMPLING_METHOD resmethod;
    //! resampling threshold; in this case its meaning is minimum ratio of active particles
    //! For example, for 0.5 resampling is performed when the numebr of active aprticles drops belo 50%.
    double res_threshold;

    //! \name Options
    //!@{
    //!@}

public:
    //! \name Constructors
    //!@{
    PF ( ) : est(w,particles) { };

    void set_parameters ( int n0, double res_th0 = 0.5, RESAMPLING_METHOD rm = SYSTEMATIC ) {
        n = n0;
        res_threshold = res_th0;
        resmethod = rm;
    };
    void set_model ( const BM *particle0, const epdf *prior) {
        if (n>0) {
            particles.set_length(n);
            for (int i=0; i<n; i++) {
                particles(i) = particle0->_copy();
                particles(i)->set_prior(prior);
            }
        }
        // set values for posterior
        est.set_rv ( particle0->posterior()._rv() );
    };
    void set_statistics ( const vec w0, const epdf &epdf0 ) {
        //est.set_statistics ( w0, epdf0 );
    };
    /*    void set_statistics ( const eEmp &epdf0 ) {
            bdm_assert_debug ( epdf0._rv().equal ( par->_rv() ), "Incompatible input" );
            est = epdf0;
        };*/
    //!@}

    //! bayes compute weights of the
    virtual void bayes_weights();
    //! important part of particle filtering - decide if it is time to perform resampling
    virtual bool do_resampling() {
        double eff = 1.0 / ( w * w );
        return eff < ( res_threshold*n );
    }
    void bayes ( const vec &yt, const vec &cond );
    //!access function
    vec& _lls() {
        return lls;
    }
    //!access function
    RESAMPLING_METHOD _resmethod() const {
        return resmethod;
    }
    //! return correctly typed posterior (covariant return)
    const pf_mix& posterior() const {
        return est;
    }

    /*! configuration structure for basic PF
    \code
    particle        = bdm::BootstrapParticle;       % one bayes rule for each point in the empirical support
      - or -        = bdm::MarginalizedParticle;    % (in case of Marginalized Particle filtering
    prior           = epdf_class;                   % prior probability density on the empirical variable
    --- optional ---
    n               = 10;                           % number of particles
    resmethod       = 'systematic', or 'multinomial', or 'stratified'
                                                    % resampling method
    res_threshold   = 0.5;                          % resample when active particles drop below 50%
    \endcode
    */
    void from_setting ( const Setting &set ) {
        BM::from_setting ( set );
        UI::get ( log_level, set, "log_level", UI::optional );

        shared_ptr<BM> bm0 = UI::build<BM>(set, "particle",UI::compulsory);

        n =0;
        UI::get(n,set,"n",UI::optional);;
        if (n>0) {
            particles.set_length(n);
            for(int i=0; i<n; i++) {
                particles(i)=bm0->_copy();
            }
            w = ones(n)/n;
        }
        shared_ptr<epdf> pri = UI::build<epdf>(set,"prior");
        set_prior(pri.get());
        // set resampling method
        resmethod_from_set ( set );
        //set drv

        rvc = bm0->_rvc();
        dimc = bm0->dimensionc();
        BM::set_rv(bm0->_rv());
        yrv=bm0->_yrv();
        dimy = bm0->dimensiony();
    }

    void log_register ( bdm::logger& L, const string& prefix ) {
        BM::log_register(L,prefix);
        if (log_level[logweights]) {
            L.add_vector( log_level, logweights, RV ( particles.length()), prefix);
        }
        if (log_level[logmeans]) {
            for (int i=0; i<particles.length(); i++) {
                L.add_vector( log_level, logmeans, RV ( particles(i)->dimension() ), prefix , i);
            }
        }
        if (log_level[logvars]) {
            for (int i=0; i<particles.length(); i++) {
                L.add_vector( log_level, logvars, RV ( particles(i)->dimension() ), prefix , i);
            }
        }
    };
    void log_write ( ) const {
        BM::log_write();
        if (log_level[logweights]) {
            log_level.store( logweights, w);
        }
        if (log_level[logmeans]) {
            for (int i=0; i<particles.length(); i++) {
                log_level.store( logmeans, particles(i)->posterior().mean(), i);
            }
        }
        if (log_level[logvars]) {
            for (int i=0; i<particles.length(); i++) {
                log_level.store( logvars, particles(i)->posterior().variance(), i);
            }
        }

    }

    void set_prior(const epdf *pri) {
        const emix_base *emi=dynamic_cast<const emix_base*>(pri);
        if (emi) {
            bdm_assert(particles.length()>0, "initial particle is not assigned");
            n = emi->_w().length();
            int old_n = particles.length();
            if (n!=old_n) {
                particles.set_length(n,true);
            }
            for(int i=old_n; i<n; i++) {
                particles(i)=particles(0)->_copy();
            }

            for (int i =0; i<n; i++) {
                particles(i)->set_prior(emi->_com(i));
            }
        } else {
            // try to find "n"
            bdm_assert(n>0, "Field 'n' must be filled when prior is not of type emix");
            for (int i =0; i<n; i++) {
                particles(i)->set_prior(pri);
            }

        }
    }
    //! auxiliary function reading parameter 'resmethod' from configuration file
    void resmethod_from_set ( const Setting &set ) {
        string resmeth;
        if ( UI::get ( resmeth, set, "resmethod", UI::optional ) ) {
            if ( resmeth == "systematic" ) {
                resmethod = SYSTEMATIC;
            } else  {
                if ( resmeth == "multinomial" ) {
                    resmethod = MULTINOMIAL;
                } else {
                    if ( resmeth == "stratified" ) {
                        resmethod = STRATIFIED;
                    } else {
                        bdm_error ( "Unknown resampling method" );
                    }
                }
            }
        } else {
            resmethod = SYSTEMATIC;
        };
        if ( !UI::get ( res_threshold, set, "res_threshold", UI::optional ) ) {
            res_threshold = 0.9;
        }
        //validate();
    }

    void validate() {
        BM::validate();
        est.validate();
        bdm_assert ( n>0, "empty particle pool" );
        n = w.length();
        lls = zeros ( n );

        if ( particles(0)->_rv()._dsize() > 0 ) {
            bdm_assert (  particles(0)->_rv()._dsize() == est.dimension(), "MPF:: Mismatch of RV " +particles(0)->_rv().to_string() +
                          " of size (" +num2str(particles(0)->_rv()._dsize())+") and dimension of posterior ("+num2str(est.dimension()) + ")" );
        }
    }
    //! resample posterior density (from outside - see MPF)
    void resample ( ) {
        ivec ind = zeros_i ( n );
        bdm::resample(w,ind,resmethod);
        // copy the internals according to ind
        for (int i = 0; i < n; i++ ) {
            if ( ind ( i ) != i ) {
                delete particles(i);
                particles( i ) = particles( ind ( i ) )->_copy();
            }
            w ( i ) = 1.0 / n;
        }
    }
    //! access function
    Array<BM*>& _particles() {
        return particles;
    }
    ~PF() {
        for (int i=0; i<particles.length(); i++) {
            delete particles(i);
        }
    }

};
UIREGISTER ( PF );

/*! Marginalized particle for state-space models with unknown parameters of distribuution of residues on \f$v_t\f$.

\f{eqnarray*}{
    x_t &=& g(x_{t-1}) + v_t,\\
    y_t &\sim &fy(x_t),
    \f}

    This particle is a only a shell creating the residues calling internal estimator of their parameters. The internal estimator can be of any compatible type, e.g. ARX for Gaussian residues with unknown mean and variance.

    */
class NoiseParticleX : public MarginalizedParticleBase {
protected:
    //! function transforming xt, ut -> x_t+1
    shared_ptr<fnc> g; // pdf for non-linear part
    //! function transforming xt,ut -> yt
    shared_ptr<pdf> fy; // pdf for non-linear part

    RV rvx;
    RV rvxc;
    RV rvyc;

    //!link from condition to f
    datalink_part cond2g;
    //!link from condition to h
    datalink_part cond2fy;
    //!link from xt to f
    datalink_part x2g;
    //!link from xt to h
    datalink_part x2fy;

public:
    BM* _copy() const {
        return new NoiseParticleX(*this);
    };
    void bayes(const vec &dt, const vec &cond) {
        shared_ptr<epdf> pred_v=bm->epredictor();

        vec vt=pred_v->sample();

        //new sample
        vec &xtm=est_emp.point;
        vec g_args(g->dimensionc());
        x2g.filldown(xtm,g_args);
        cond2g.filldown(cond,g_args);
        vec xt = g->eval(g_args) + vt;
        est_emp.point=xt;

        // the vector [v_t] updates bm,
        bm->bayes(vt);

        // residue of observation
        vec fy_args(fy->dimensionc());
        x2fy.filldown(xt,fy_args);
        cond2fy.filldown(cond,fy_args);

        ll=bm->_ll() + fy->evallogcond(dt,fy_args);
    }
    void from_setting(const Setting &set) {
        MarginalizedParticleBase::from_setting(set); //reads bm, yrv,rvc, bm_rv, etc...

        g=UI::build<fnc>(set,"g",UI::compulsory);
        fy=UI::build<pdf>(set,"fy",UI::compulsory);
        UI::get(rvx,set,"rvx",UI::compulsory);
        est_emp.set_rv(rvx);

        UI::get(rvxc,set,"rvxc",UI::compulsory);
        UI::get(rvyc,set,"rvyc",UI::compulsory);

    }
    void validate() {
        MarginalizedParticleBase::validate();

        dimy = fy->dimension();
        bm->set_yrv(rvx);

        est_emp.set_rv(rvx);
        est_emp.set_dim(rvx._dsize());
        est.validate();
        //
        //check dimensions
        rvc = rvxc.subt(rvx.copy_t(-1));
        rvc.add( rvyc);
        rvc=rvc.subt(rvx);

        bdm_assert(g->dimension()==rvx._dsize(),"rvx is not described");
        bdm_assert(g->dimensionc()==rvxc._dsize(),"rvxc is not described");
        bdm_assert(fy->dimensionc()==rvyc._dsize(),"rvyc is not described");

        bdm_assert(bm->dimensiony()==g->dimension(),
                   "Incompatible noise estimator of dimension " +
                   num2str(bm->dimensiony()) + " does not match dimension of g , " +
                   num2str(g->dimension()));

        dimc = rvc._dsize();

        //establish datalinks
        x2g.set_connection(rvxc, rvx.copy_t(-1));
        cond2g.set_connection(rvxc, rvc);

        x2fy.set_connection(rvyc, rvx);
        cond2fy.set_connection(rvyc, rvc);
    }
};
UIREGISTER(NoiseParticleX);

/*! Marginalized particle for state-space models with unknown parameters of residues distribution

\f{eqnarray*}{
    x_t &=& g(x_{t-1}) + v_t,\\
    z_t &= &h(x_{t-1}) + w_t,
    \f}

    This particle is a only a shell creating the residues calling internal estimator of their parameters. The internal estimator can be of any compatible type, e.g. ARX for Gaussian residues with unknown mean and variance.

    */
class NoiseParticle : public MarginalizedParticleBase {
protected:
    //! function transforming xt, ut -> x_t+1
    shared_ptr<fnc> g; // pdf for non-linear part
    //! function transforming xt,ut -> yt
    shared_ptr<fnc> h; // pdf for non-linear part

    RV rvx;
    RV rvxc;
    RV rvyc;

    //!link from condition to f
    datalink_part cond2g;
    //!link from condition to h
    datalink_part cond2h;
    //!link from xt to f
    datalink_part x2g;
    //!link from xt to h
    datalink_part x2h;

public:
    BM* _copy() const {
        return new NoiseParticle(*this);
    };
    void bayes(const vec &dt, const vec &cond) {
        shared_ptr<epdf> pred_vw=bm->epredictor();
        shared_ptr<epdf> pred_v = pred_vw->marginal(rvx);

        vec vt=pred_v->sample();

        //new sample
        vec &xtm=est_emp.point;
        vec g_args(g->dimensionc());
        x2g.filldown(xtm,g_args);
        cond2g.filldown(cond,g_args);
        vec xt = g->eval(g_args) + vt;
        est_emp.point=xt;

        // residue of observation
        vec h_args(h->dimensionc());
        x2h.filldown(xt,h_args);
        cond2h.filldown(cond,h_args);
        vec wt = dt-h->eval(h_args);
        // the vector [v_t,w_t] is now complete
        bm->bayes(concat(vt,wt));
        ll=bm->_ll();
    }
    void from_setting(const Setting &set) {
        MarginalizedParticleBase::from_setting(set); //reads bm, yrv,rvc, bm_rv, etc...

        UI::get(g,set,"g",UI::compulsory);
        UI::get(h,set,"h",UI::compulsory);
        UI::get(rvx,set,"rvx",UI::compulsory);
        est_emp.set_rv(rvx);

        UI::get(rvxc,set,"rvxc",UI::compulsory);
        UI::get(rvyc,set,"rvyc",UI::compulsory);

    }
    void validate() {
        MarginalizedParticleBase::validate();

        dimy = h->dimension();
        bm->set_yrv(concat(rvx,yrv));

        est_emp.set_rv(rvx);
        est_emp.set_dim(rvx._dsize());
        est.validate();
        //
        //check dimensions
        rvc = rvxc.subt(rvx.copy_t(-1));
        rvc.add( rvyc);
        rvc=rvc.subt(rvx);

        bdm_assert(g->dimension()==rvx._dsize(),"rvx is not described");
        bdm_assert(g->dimensionc()==rvxc._dsize(),"rvxc is not described");
        bdm_assert(h->dimension()==rvyc._dsize(),"rvyc is not described");

        bdm_assert(bm->dimensiony()==g->dimension()+h->dimension(),
                   "Incompatible noise estimator of dimension " +
                   num2str(bm->dimensiony()) + " does not match dimension of g and h, " +
                   num2str(g->dimension())+" and "+ num2str(h->dimension()) );

        dimc = rvc._dsize();

        //establish datalinks
        x2g.set_connection(rvxc, rvx.copy_t(-1));
        cond2g.set_connection(rvxc, rvc);

        x2h.set_connection(rvyc, rvx);
        cond2h.set_connection(rvyc, rvc);
    }
};
UIREGISTER(NoiseParticle);


}
#endif // KF_H


