/*!
  \file
  \brief Bayesian Filtering for linear Gaussian models (Kalman Filter) and extensions
  \author Vaclav Smidl.

  -----------------------------------
  BDM++ - C++ library for Bayesian Decision Making under Uncertainty

  Using IT++ for numerical operations
  -----------------------------------
*/

#ifndef KF_H
#define KF_H


#include "../math/functions.h"
#include "../stat/exp_family.h"
#include "../math/chmat.h"
#include "../base/user_info.h"
//#include <../applications/pmsm/simulator_zdenek/ekf_example/pmsm_mod.h>

namespace bdm {

/*!
 * \brief Basic elements of linear state-space model

Parameter evolution model:\f[ x_{t+1} = A x_{t} + B u_t + Q^{1/2} e_t \f]
Observation model: \f[ y_t = C x_{t} + C u_t + R^{1/2} w_t. \f]
Where $e_t$ and $w_t$ are mutually independent vectors of Normal(0,1)-distributed disturbances.
 */
template<class sq_T>
class StateSpace {
protected:
	//! Matrix A
	mat A;
	//! Matrix B
	mat B;
	//! Matrix C
	mat C;
	//! Matrix D
	mat D;
	//! Matrix Q in square-root form
	sq_T Q;
	//! Matrix R in square-root form
	sq_T R;
public:
	StateSpace() :  A(), B(), C(), D(), Q(), R() {}
	//!copy constructor
	StateSpace ( const StateSpace<sq_T> &S0 ) :  A ( S0.A ), B ( S0.B ), C ( S0.C ), D ( S0.D ), Q ( S0.Q ), R ( S0.R ) {}
	//! set all matrix parameters
	void set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 );
	//! validation
	void validate();
	//! not virtual in this case
	void from_setting ( const Setting &set ) {
		UI::get ( A, set, "A", UI::compulsory );
		UI::get ( B, set, "B", UI::compulsory );
		UI::get ( C, set, "C", UI::compulsory );
		UI::get ( D, set, "D", UI::compulsory );
		mat Qtm, Rtm; // full matrices
		if ( !UI::get ( Qtm, set, "Q", UI::optional ) ) {
			vec dq;
			UI::get ( dq, set, "dQ", UI::compulsory );
			Qtm = diag ( dq );
		}
		if ( !UI::get ( Rtm, set, "R", UI::optional ) ) {
			vec dr;
			UI::get ( dr, set, "dQ", UI::compulsory );
			Rtm = diag ( dr );
		}
		R = Rtm; // automatic conversion to square-root form
		Q = Qtm;

		validate();
	}
	//! access function
	const mat& _A() const {
		return A;
	}
	//! access function
	const mat& _B() const {
		return B;
	}
	//! access function
	const mat& _C() const {
		return C;
	}
	//! access function
	const mat& _D() const {
		return D;
	}
	//! access function
	const sq_T& _Q() const {
		return Q;
	}
	//! access function
	const sq_T& _R() const {
		return R;
	}
};

//! Common abstract base for Kalman filters
template<class sq_T>
class Kalman: public BM, public StateSpace<sq_T> {
protected:
	//! id of output
	RV yrv;
	//! Kalman gain
	mat  _K;
	//!posterior
	enorm<sq_T> est;
	//!marginal on data f(y|y)
	enorm<sq_T>  fy;
public:
	Kalman<sq_T>() : BM(), StateSpace<sq_T>(), yrv(), _K(),  est() {}
	//! Copy constructor
	Kalman<sq_T> ( const Kalman<sq_T> &K0 ) : BM ( K0 ), StateSpace<sq_T> ( K0 ), yrv ( K0.yrv ), _K ( K0._K ),  est ( K0.est ), fy ( K0.fy ) {}
	//!set statistics of the posterior
	void set_statistics ( const vec &mu0, const mat &P0 ) {
		est.set_parameters ( mu0, P0 );
	};
	//!set statistics of the posterior
	void set_statistics ( const vec &mu0, const sq_T &P0 ) {
		est.set_parameters ( mu0, P0 );
	};
	//! return correctly typed posterior (covariant return)
	const enorm<sq_T>& posterior() const {
		return est;
	}
	//! load basic elements of Kalman from structure
	void from_setting ( const Setting &set ) {
		StateSpace<sq_T>::from_setting ( set );

		mat P0;
		vec mu0;
		UI::get ( mu0, set, "mu0", UI::optional );
		UI::get ( P0, set,  "P0", UI::optional );
		set_statistics ( mu0, P0 );
		// Initial values
		shared_ptr<RV> yrv_ptr = UI::build<RV>( set, "yrv", UI::optional );
		if( !yrv_ptr ) yrv_ptr = new RV();
		shared_ptr<RV> rvc_ptr = UI::build<RV>( set, "urv", UI::optional );
		if( !rvc_ptr ) rvc_ptr = new RV();
		set_yrv ( concat ( *yrv_ptr, *rvc_ptr ) );
	}
	//! validate object
	void validate() {
		StateSpace<sq_T>::validate();
		dimy = this->C.rows();
		dimc = this->B.cols();
		set_dim ( this->A.rows() );

		bdm_assert ( est.dimension(), "Statistics and model parameters mismatch" );
	}

};
/*!
* \brief Basic Kalman filter with full matrices
*/

class KalmanFull : public Kalman<fsqmat> {
public:
	//! For EKFfull;
	KalmanFull() : Kalman<fsqmat>() {};
	//! Here dt = [yt;ut] of appropriate dimensions
	void bayes ( const vec &yt, const vec &cond = empty_vec );

	virtual KalmanFull* _copy() const {
		KalmanFull* K = new KalmanFull;
		K->set_parameters ( A, B, C, D, Q, R );
		K->set_statistics ( est._mu(), est._R() );
		return K;
	}
};
UIREGISTER ( KalmanFull );


/*! \brief Kalman filter in square root form

Trivial example:
\include kalman_simple.cpp

Complete constructor:
*/
class KalmanCh : public Kalman<chmat> {
protected:
	//! @{ \name Internal storage - needs initialize()
	//! pre array (triangular matrix)
	mat preA;
	//! post array (triangular matrix)
	mat postA;
	//!@}
public:
	//! copy constructor
	virtual KalmanCh* _copy() const {
		KalmanCh* K = new KalmanCh;
		K->set_parameters ( A, B, C, D, Q, R );
		K->set_statistics ( est._mu(), est._R() );
		K->validate();
		return K;
	}
	//! set parameters for adapt from Kalman
	void set_parameters ( const mat &A0, const mat &B0, const mat &C0, const mat &D0, const chmat &Q0, const chmat &R0 );
	//! initialize internal parametetrs
	void initialize();

	/*!\brief  Here dt = [yt;ut] of appropriate dimensions

	The following equality hold::\f[
	\left[\begin{array}{cc}
	R^{0.5}\\
	P_{t|t-1}^{0.5}C' & P_{t|t-1}^{0.5}CA'\\
	& Q^{0.5}\end{array}\right]<\mathrm{orth.oper.}>=\left[\begin{array}{cc}
	R_{y}^{0.5} & KA'\\
	& P_{t+1|t}^{0.5}\\
	\\\end{array}\right]\f]

	Thus this object evaluates only predictors! Not filtering densities.
	*/
	void bayes ( const vec &yt, const vec &cond = empty_vec );

	void from_setting ( const Setting &set ) {
		Kalman<chmat>::from_setting ( set );
		validate();
	}
	void validate() {
		Kalman<chmat>::validate();
		initialize();
	}
};
UIREGISTER ( KalmanCh );

/*!
\brief Extended Kalman Filter in full matrices

An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
*/
class EKFfull : public KalmanFull {
protected:
	//! Internal Model f(x,u)
	shared_ptr<diffbifn> pfxu;

	//! Observation Model h(x,u)
	shared_ptr<diffbifn> phxu;

public:
	//! Default constructor
	EKFfull ();

	//! Set nonlinear functions for mean values and covariance matrices.
	void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const mat Q0, const mat R0 );

	//! Here dt = [yt;ut] of appropriate dimensions
	void bayes ( const vec &yt, const vec &cond = empty_vec );
	//! set estimates
	void set_statistics ( const vec &mu0, const mat &P0 ) {
		est.set_parameters ( mu0, P0 );
	};
	//! access function
	const mat _R() {
		return est._R().to_mat();
	}
	void from_setting ( const Setting &set ) {
		BM::from_setting ( set );
		shared_ptr<diffbifn> IM = UI::build<diffbifn> ( set, "IM", UI::compulsory );
		shared_ptr<diffbifn> OM = UI::build<diffbifn> ( set, "OM", UI::compulsory );

		//statistics
		int dim = IM->dimension();
		vec mu0;
		if ( !UI::get ( mu0, set, "mu0" ) )
			mu0 = zeros ( dim );

		mat P0;
		vec dP0;
		if ( UI::get ( dP0, set, "dP0" ) )
			P0 = diag ( dP0 );
		else if ( !UI::get ( P0, set, "P0" ) )
			P0 = eye ( dim );

		set_statistics ( mu0, P0 );

		//parameters
		vec dQ, dR;
		UI::get ( dQ, set, "dQ", UI::compulsory );
		UI::get ( dR, set, "dR", UI::compulsory );
		set_parameters ( IM, OM, diag ( dQ ), diag ( dR ) );

// 			pfxu = UI::build<diffbifn>(set, "IM", UI::compulsory);
// 			phxu = UI::build<diffbifn>(set, "OM", UI::compulsory);
//
// 			mat R0;
// 			UI::get(R0, set, "R",UI::compulsory);
// 			mat Q0;
// 			UI::get(Q0, set, "Q",UI::compulsory);
//
//
// 			mat P0; vec mu0;
// 			UI::get(mu0, set, "mu0", UI::optional);
// 			UI::get(P0, set,  "P0", UI::optional);
// 			set_statistics(mu0,P0);
// 			// Initial values
// 			UI::get (yrv, set, "yrv", UI::optional);
// 			UI::get (urv, set, "urv", UI::optional);
// 			set_drv(concat(yrv,urv));
//
// 			// setup StateSpace
// 			pfxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), A,true);
// 			phxu->dfdu_cond(mu0, zeros(pfxu->_dimu()), C,true);
//
	}

	void validate() {
		KalmanFull::validate();

		// check stats and IM and OM
	}
};
UIREGISTER ( EKFfull );


/*!
\brief Extended Kalman Filter in Square root

An approximation of the exact Bayesian filter with Gaussian noices and non-linear evolutions of their mean.
*/

class EKFCh : public KalmanCh {
protected:
	//! Internal Model f(x,u)
	shared_ptr<diffbifn> pfxu;

	//! Observation Model h(x,u)
	shared_ptr<diffbifn> phxu;
public:
	//! copy constructor duplicated - calls different set_parameters
	EKFCh* _copy() const {
		return new EKFCh(*this);
	}
	//! Set nonlinear functions for mean values and covariance matrices.
	void set_parameters ( const shared_ptr<diffbifn> &pfxu, const shared_ptr<diffbifn> &phxu, const chmat Q0, const chmat R0 );

	//! Here dt = [yt;ut] of appropriate dimensions
	void bayes ( const vec &yt, const vec &cond = empty_vec );

	void from_setting ( const Setting &set );

	void validate() {};
	// TODO dodelat void to_setting( Setting &set ) const;

};

UIREGISTER ( EKFCh );
SHAREDPTR ( EKFCh );


//////// INstance

/*! \brief (Switching) Multiple Model
The model runs several models in parallel and evaluates thier weights (fittness).

The statistics of the resulting density are merged using (geometric?) combination.

The next step is performed with the new statistics for all models.
*/
class MultiModel: public BM {
protected:
	//! List of models between which we switch
	Array<EKFCh*> Models;
	//! vector of model weights
	vec w;
	//! cache of model lls
	vec _lls;
	//! type of switching policy [1=maximum,2=...]
	int policy;
	//! internal statistics
	enorm<chmat> est;
public:
	//! set internal parameters
	void set_parameters ( Array<EKFCh*> A, int pol0 = 1 ) {
		Models = A;//TODO: test if evalll is set
		w.set_length ( A.length() );
		_lls.set_length ( A.length() );
		policy = pol0;

		est.set_rv ( RV ( "MM", A ( 0 )->posterior().dimension(), 0 ) );
		est.set_parameters ( A ( 0 )->posterior().mean(), A ( 0 )->posterior()._R() );
	}
	void bayes ( const vec &yt, const vec &cond = empty_vec ) {
		int n = Models.length();
		int i;
		for ( i = 0; i < n; i++ ) {
			Models ( i )->bayes ( yt );
			_lls ( i ) = Models ( i )->_ll();
		}
		double mlls = max ( _lls );
		w = exp ( _lls - mlls );
		w /= sum ( w ); //normalization
		//set statistics
		switch ( policy ) {
		case 1: {
			int mi = max_index ( w );
			const enorm<chmat> &st = Models ( mi )->posterior() ;
			est.set_parameters ( st.mean(), st._R() );
		}
		break;
		default:
			bdm_error ( "unknown policy" );
		}
		// copy result to all models
		for ( i = 0; i < n; i++ ) {
			Models ( i )->set_statistics ( est.mean(), est._R() );
		}
	}
	//! return correctly typed posterior (covariant return)
	const enorm<chmat>& posterior() const {
		return est;
	}

	void from_setting ( const Setting &set );

};
UIREGISTER ( MultiModel );
SHAREDPTR ( MultiModel );

//! conversion of outer ARX model (mlnorm) to state space model
/*!
The model is constructed as:
\f[ x_{t+1} = Ax_t + B u_t + R^{1/2} e_t, y_t=Cx_t+Du_t + R^{1/2}w_t, \f]
For example, for:
Using Frobenius form, see [].

For easier use in the future, indices theta_in_A and theta_in_C are set. TODO - explain
*/
//template<class sq_T>
class StateCanonical: public StateSpace<fsqmat> {
protected:
	//! remember connection from theta ->A
	datalink_part th2A;
	//! remember connection from theta ->C
	datalink_part th2C;
	//! remember connection from theta ->D
	datalink_part th2D;
	//!cached first row of A
	vec A1row;
	//!cached first row of C
	vec C1row;
	//!cached first row of D
	vec D1row;

public:
	//! set up this object to match given mlnorm
	void connect_mlnorm ( const mlnorm<fsqmat> &ml );

	//! fast function to update parameters from ml - not checked for compatibility!!
	void update_from ( const mlnorm<fsqmat> &ml );
};
/*!
State-Space representation of multivariate autoregressive model.
The original model:
\f[ y_t = \theta [\ldots y_{t-k}, \ldots u_{t-l}, \ldots z_{t-m}]' + \Sigma^{-1/2} e_t \f]
where \f$ k,l,m \f$ are maximum delayes of corresponding variables in the regressor.

The transformed state is:
\f[ x_t = [y_{t} \ldots y_{t-k-1}, u_{t} \ldots u_{t-l-1}, z_{t} \ldots z_{t-m-1}]\f]

The state accumulates all delayed values starting from time \f$ t \f$ .


*/
class StateFromARX: public StateSpace<chmat> {
protected:
	//! remember connection from theta ->A
	datalink_part th2A;
	//! remember connection from theta ->B
	datalink_part th2B;
	//!function adds n diagonal elements from given starting point r,c
	void diagonal_part ( mat &A, int r, int c, int n ) {
		for ( int i = 0; i < n; i++ ) {
			A ( r, c ) = 1.0;
			r++;
			c++;
		}
	};
	//! similar to ARX.have_constant
	bool have_constant;
public:
	//! set up this object to match given mlnorm
	//! Note that state-space and common mpdf use different meaning of \f$ _t \f$ in \f$ u_t \f$.
	//!While mlnorm typically assumes that \f$ u_t \rightarrow y_t \f$ in state space it is \f$ u_{t-1} \rightarrow y_t \f$
	//! For consequences in notation of internal variable xt see arx2statespace_notes.lyx.
	void connect_mlnorm ( const mlnorm<chmat> &ml, RV &xrv, RV &urv );

	//! fast function to update parameters from ml - not checked for compatibility!!
	void update_from ( const mlnorm<chmat> &ml );

	//! access function
	bool _have_constant() const {
		return have_constant;
	}
};

/////////// INSTANTIATION

template<class sq_T>
void StateSpace<sq_T>::set_parameters ( const mat &A0, const  mat &B0, const  mat &C0, const  mat &D0, const  sq_T &Q0, const sq_T &R0 ) {

	A = A0;
	B = B0;
	C = C0;
	D = D0;
	R = R0;
	Q = Q0;
	validate();
}

template<class sq_T>
void StateSpace<sq_T>::validate() {
	bdm_assert ( A.cols() == A.rows(), "KalmanFull: A is not square" );
	bdm_assert ( B.rows() == A.rows(), "KalmanFull: B is not compatible" );
	bdm_assert ( C.cols() == A.rows(), "KalmanFull: C is not compatible" );
	bdm_assert ( ( D.rows() == C.rows() ) && ( D.cols() == B.cols() ), "KalmanFull: D is not compatible" );
	bdm_assert ( ( Q.cols() == A.rows() ) && ( Q.rows() == A.rows() ), "KalmanFull: Q is not compatible" );
	bdm_assert ( ( R.cols() == C.rows() ) && ( R.rows() == C.rows() ), "KalmanFull: R is not compatible" );
}

}
#endif // KF_H


