#include "ctrlbase.h"

namespace bdm {

const bool STRICT_RV = true; //empty RV NOT allowed

//! extended class representing function \f$f(x) = Ax+B\f$
class linfnEx: public linfn {
  public:
    //! Identification of returned value \f$f(x)\f$
    RV rv_ret;
	//!default constructor
    linfnEx ( ) : linfn() { };
    linfnEx ( const mat &A0, const vec &B0 ) : linfn(A0, B0) { };
};

  //! Universal LQG controller
class LQG_universal : public Controller{
public:
	//! Controller inputs
		//! loss function
		Array<quadraticfn> Losses;
		//! loss in final time
		quadraticfn finalLoss;
		//! model of evolutin
		Array<linfnEx> Models;
//		RV model_rv_ret;

		//! control law rv is public member in Controller class	 
		//! input data rvc is protected member in Controller class
		 
		//! control horizon
		int horizon;

	//! Constructor
		LQG_universal() {
			horizon = 0;
			curtime = -1;
		}
		
  protected:
    //! control law: rv = L [rvc, 1]
    mat L;
    
    //! Matrix pre_qr
    mat pre_qr;
    
    //! Matrix post_qr
    mat post_qr;

	//! time+1 optimized loss to be added to current one
	mat tolC;

	int curtime;
    
public:
    //! function redesigning the control strategy
    virtual void redesign() {
		if (curtime == -1) curtime = horizon;
				
		if (curtime > 0){
			curtime--;
			generateLmat(curtime);
		}
		//cout << "time " << curtime << endl;
		//report time 0 reached - LQG designing complete
		//if(curtime == 0) cout << "time 0 reached" << endl;
	}
    //! returns designed control action
    virtual vec ctrlaction ( const vec &cond ) const {
        return L * concat(cond, 1.0);
    }

    void from_setting ( const Setting &set ) {
      UI::get(Losses, set, "losses",UI::compulsory);
      UI::get(Models, set, "models",UI::compulsory);
    }
    //! access function
    const RV& _rv() {
        return rv;
    }
    //! access function
    const RV& _rvc() {
        return rvc;
    }

	void set_rvc(RV _rvc) {rvc = _rvc;}

    //! register this controller with given datasource under name "name"
    virtual void log_register ( logger &L, const string &prefix ) { }
    //! write requested values into the logger
    virtual void log_write ( ) const { }

	//! access debug function
	mat getL(){ return L; }

	void resetTime() { curtime = -1; }

	//! check if model and losses is correct and consistent
	virtual void validate(){
		/*
			RV:findself hleda cela rv jako vektory, pri nenalezeni je -1
			RV:dataind hleda datove slozky, tedy indexy v poli skalaru, pri nenalezeni vynecha
		*/
		// (0) nonempty
		bdm_assert((Models.size() > 0), "VALIDATION FAILED! Models array empty.");
		bdm_assert((Losses.size() > 0), "VALIDATION FAILED! Losses array empty.");
		if( (Models.size() <= 0) || (Losses.size() <= 0) ) return;

		// (1) test Models array rv - acceptable rv is only part/composition of LQG_universal::rv, LQG_universal::rvc and const 1
		RV accept_total;
		accept_total = rv;
		accept_total.add(rvc);
		accept_total.add(RV("1", 1, 0));
		
		int i, j;
		ivec finding1;

		for(i = 0; i < Models.length(); i++){
			finding1 = Models(i).rv.findself(accept_total); 

			bdm_assert( !(STRICT_RV && (finding1.size() <= 0)), "VALIDATION FAILED! Empty RV used.");

			for(j = 0; j < finding1.size(); j++){				
				bdm_assert( ( finding1(j) > (-1) ), "VALIDATION FAILED! Provided input RV for some Models function is unknown, forbidden or recursive.");							
				if(finding1(j) <= (-1) ) return; //rv element is not part of admissible rvs => error
			}			
		}		
		
		//NOT!!! (2) test Models array rv_ret - each array element's rv_ret must be unique (except const 1)
		//RV unique_rv_ret;
		//unique_rv_ret = Models(0).rv_ret;

		//for(i = 1; i < Models.length(); i++){
		//	finding1 = Models(i).rv_ret.findself(unique_rv_ret);

		//	for(j = 0; j < finding1.size(); j++){								
		//		if(Models(i).rv_ret.name(j) == "1") continue; // except const 1

		//		bdm_assert((finding1(j) == (-1) ), "VALIDATION FAILED! Models functions result RV (rv_ret) must be unique.");
		//		if(finding1(j) != (-1) ) return; //rv_ret element not unique
		//	}

		//	unique_rv_ret.add(Models(i).rv_ret);
		//}		
		
		// (3) test Losses array - acceptable rv is only part/composition of LQG_universal::rv, LQG_universal::rvc, Models rv_ret and const 1
		for(i = 0; i < Models.length(); i++) accept_total.add(Models(i).rv_ret); //old accept_total from (1) + all rv_ret from Models
		
		for(i = 0; i < Losses.length(); i++){
			finding1 = Losses(i).rv.findself(accept_total);

			bdm_assert( !(STRICT_RV && (finding1.size() <= 0)), "VALIDATION FAILED! Empty RV used.");

			for(j = 0; j < finding1.size(); j++){
				bdm_assert( ( finding1(j) > (-1) ), "VALIDATION FAILED! Unacceptable RV used in some Losses function.");
				if(finding1(j) <= (-1) ) return; //rv element is not part of admissible rvs => error
			}
		}	

		// same for finalLoss
		finding1 = finalLoss.rv.findself(accept_total);

		bdm_assert( !(STRICT_RV && (finding1.size() <= 0)), "VALIDATION FAILED! Empty RV used.");

		for(j = 0; j < finding1.size(); j++){
			bdm_assert( ( finding1(j) > (-1) ), "VALIDATION FAILED! Unacceptable RV used in finalLoss function.");
			if(finding1(j) <= (-1) ) return; //rv element is not part of admissible rvs => error
		}
	}

private:
	RV trs_crv;

	//! compute complete RV from all of RVs used in Losses array
	/*RV getCompleteRV() {
		RV cRv; //complete RV

		//cRv has form [rv, "other", 1]
		cRv = rv; //add rv

		//add "other"
		for(int i = 0; i < Losses.size(); i++)
			cRv.add(Losses(i).rv);

		cRv.add(rvOne); //add 1

		return cRv;
	}*/

	mat getMatRow (quadraticfn sourceQfn){ //returns row of matrixes crated from quadratic function
		
		mat tmpMatRow; //tmp variable for row of matrixes to be returned	
		tmpMatRow.set_size(sourceQfn.Q.rows(), trs_crv.countsize()); //(rows, cols)
		tmpMatRow.zeros();

		//set data in tmpMatRow - other times then current replace using model
		RV tmpQrv = sourceQfn.rv;
//cout << "Qrv " << tmpQrv << endl << "total crv " << trs_crv << endl;
		ivec j_vec(1);
		vec copycol;
		ivec copysource;
		for(int j = 0; j < tmpQrv.length(); j++){			
			j_vec(0) = j;
//cout << "matrow\n" << tmpMatRow << endl;

			if( /*(tmpQrv.time(j) == 0) &&*/ (sum(tmpQrv(j_vec).findself(trs_crv)) > (-1)) ) {//sum is only formal, summed vector is in fact scalar
				//jth element of tmpQrv is also element of trs_crv with a proper time
//cout << "copy" << endl;
				ivec copytarget = (tmpQrv(j_vec)).dataind(trs_crv); //get target column position in tmpMatRow
				ivec copysource = (tmpQrv(j_vec)).dataind(tmpQrv); //get source column position in Losses(i).Q
				if(copytarget.size() != copysource.size()) {return mat(0); /*error*/}				
				for(int k = 0; k < copysource.size(); k++){
					copycol = sourceQfn.Q._Ch().get_col(copysource(k));
					copycol += tmpMatRow.get_col(copytarget(k));
					tmpMatRow.set_col(copytarget(k), copycol);
				}					
			}
			else {
//cout << "USING MODEL on " << tmpQrv(j,j) << endl;
//cout << "model" << endl;
				//jth tmpQrv element is not in trs_crv -> using Model to teplace it
				// = (tmpQrv(j_vec)).findself(tmpQrv); //get source column position in Losses(i).Q

				//int selectedModel = -1;

				//int k;
				////find first usable replacement in Model				
				//for(k = 0; k < Models.size(); k++){							
				//	if( sum((tmpQrv(j_vec)).findself(Models(k).rv_ret)) > (-1) ){
				//	//TODO is tmpQrv(j) in kth Models RV
				//		//if( (Models(k)).rv_ret.findself_ids(trs_crv) ){//???????????????????
				//		// is kth Models rv_ret subset of trs_crv

				//		selectedModel = k;
				//		break;

				//		//}
				//	}
				//}
				
				//!model is model_rv_ret = sum(Array<linfn>) = sum( A1*rv + B1 + ... + An*rv + Bn)
				
				//if(selectedModel == -1) {cout << "NO MODEL" << endl;return mat("0");}//ERROR - inconsistent model data;				
			//!!TODO!!if(NOT((tmpQrv(j_vec)).findself(model_rv_ret))) {cout << "NO MODEL" << endl;return mat("0");}//ERROR - inconsistent model data;				
				//use kth Model to convert tmpQrv memeber to trs_crv memeber

				//get submatrix from Q which represents jth tmpQrv data
				copysource = (tmpQrv(j_vec)).dataind(tmpQrv); //get source column position in Losses(i).Q				
				mat copysubmat;
				copysubmat.set_size(sourceQfn.Q.rows(), copysource.size()); //(rows, cols)
				copysubmat.zeros();
				vec copycol;
				
				int k;
				for(k = 0; k < copysource.size(); k++){
					copycol = sourceQfn.Q._Ch().get_col(copysource(k));
					copysubmat.set_col(k, copycol);
				}

				//check every Models element if it is a proper substitution: tmpQrv(j_vec) memeber of rv_ret
				for(k = 0; k < Models.size(); k++){
					if( sum((tmpQrv(j_vec)).findself(Models(k).rv_ret)) > (-1) ){ //formal sum, find usable model
						//check if model is correct
						ivec check = (Models(k).rv).findself(trs_crv);
						if(sum(check) <= -check.size()){
							bdm_assert (false , "Incorrect Model: Unusable Models element!" );						 
							continue;
						}

						//create transformed submatrix
						mat transsubmat = copysubmat * ((Models(k)).A);

						//put them on a right place in tmpQrv
						ivec copytarget = (Models(k)).rv.dataind(trs_crv); //get target column position in tmpMatRow
												
						//copy transsubmat into tmpMatRow with adding to current one
						//	tmpMatRow(new) = tmpMatRow(old) + transsubmat /all in proper indices
						int l;
						for(l = 0; l < copytarget.size(); l++){					
							copycol = tmpMatRow.get_col(copytarget(l));					
							copycol += transsubmat.get_col(l);								
							tmpMatRow.set_col(copytarget(l), copycol);								
						}

						//if linear fnc constant element vec B is nonzero
						vec constElB = (Models(k)).B;				
						if(prod(constElB) != 0){
							//copy transformed constant vec into last (1's) col in tmpMatRow
							int lastcol = tmpMatRow.cols() - 1;
							copycol = tmpMatRow.get_col(lastcol);
							copycol += (copysubmat * ((Models(k)).B));
							tmpMatRow.set_col(lastcol, copycol);
						}
					}

				}

				
			}
		}

		return tmpMatRow;
	}

	//! create first(0) or other (1) pre_qr matrix 
	void build_pre_qr(bool next) {
		int i;
		//used fake quadratic function from tolC matrix
		quadraticfn fakeQfn;

		//RV pretrs_crv = getCompleteRV(); // crv before transformation based on Losses array

		//set proper size of pre_qr matrix
		int rows = 0;
		for(i = 0; i < Losses.size(); i++)
			rows += Losses(i).Q.rows();
		if(!next) rows += finalLoss.Q.rows();
		else{
			//used fake quadratic function from tolC matrix
			//setup fakeQfn
			fakeQfn.Q.setCh(tolC);
			RV fakeM1;
			fakeM1 = rvc;
			fakeM1.add(RV("1", 1, 0));
			fakeM1.t_plus(1); //RV in time t+1 => necessary use of Model to get RV in time t
			fakeM1.set_time((RV("1", 1, 1).findself(fakeM1))(0) , 0);
//cout << fakeM1 << endl;
			fakeQfn.rv = fakeM1;
			
			rows += fakeQfn.Q.rows();
		}
//cout << "buildpreqr trscrv: " << trs_crv << " of size " << trs_crv.countsize() << endl;
		pre_qr.set_size(rows, trs_crv.countsize()); //(rows, cols)
		pre_qr.zeros();

		//fill pre_qr matrix for each Losses quadraticfn		
		int rowIndex = 0;
		mat tmpMatRow;
		for(i = 0; i < Losses.size(); i++) {
			rows = Losses(i).Q.rows();
			
			//compute row matrix and insert it on proper place in pre_qr
			tmpMatRow = getMatRow(Losses(i));
//cout << "tmpMatRow no " << i << endl << tmpMatRow << endl;
			//copy tmpMatRow in pre_qr

		/*cout << "submatrix ( " << rowIndex << ", " << 
			(rowIndex + rows - 1) << ", 0, " << (trs_crv.countsize() - 1) << ")" << endl;
		cout << "seting with submatrix of rows " << tmpMatRow.rows() << " and cols " << tmpMatRow.cols() <<
			"and data " << endl << tmpMatRow << endl;*/

			pre_qr.set_submatrix(rowIndex, (rowIndex + rows - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)
			rowIndex += rows;
		}

		if(!next) {			
			tmpMatRow = getMatRow(finalLoss);			
			pre_qr.set_submatrix(rowIndex, (rowIndex + finalLoss.Q.rows() - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)		
		}
		else { //next
			//based on tolC but time must be shifted by one - all implemented in getMatRow method
				
			//get matrix row via getMatRow method
			//cout << "XXXXXXXXXXXXXXX" << endl;
			tmpMatRow = getMatRow(fakeQfn);
			//cout << tmpMatRow << endl;
		/*cout << "submatrix ( " << rowIndex << ", " << 
			(rowIndex + fakeQfn.Q.rows() - 1) << ", 0, " << (trs_crv.countsize() - 1) << ")" << endl;
		cout << "seting with submatrix of rows " << tmpMatRow.rows() << " and cols " << tmpMatRow.cols() <<
			"and data " << endl << tmpMatRow << endl;*/
			pre_qr.set_submatrix(rowIndex, (rowIndex + fakeQfn.Q.rows() - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)		
		//cout << "NEXT 3" << endl;
		}
//cout << "last tmpMatRow  " << endl << tmpMatRow << endl;
	}	

	mat get_qr_submatrix(int submatidx) {
	/*
		|rv||rvc||1|

		AAAABBBBBBBB
		 AAABBBBBBBB
		  AABBBBBBBB
		   ABBBBBBBB
		    CCCCCCCC
			 CCCCCCC
			  CCCCCC
			   CCCCC
			    CCCC
				 CCC
				  CC
				   C
	*/
	/*!
		submatidx | get_submatrix
		----------|--------------
		    0     |      A
			1     |		 B
			2+	  |		 C
	*/
		int sizeA = rv.countsize();
		int colsB = post_qr.cols() -  sizeA;
		//  rowsB = sizeA; 
		//  colsC = colsB;
		//not required whole C - it is triangular
		//=> NOT int rowsC = post_qr.rows() - sizeA;
		//=> int sizeC = colsB;

		mat qr_submat;
		
		if(submatidx == 0) qr_submat		= post_qr.get(0,		(sizeA - 1),			0,		(sizeA - 1));  //(int r1, int r2, int c1, int c2)
		else if(submatidx == 1) qr_submat	= post_qr.get(0,		(sizeA - 1),			sizeA,	(post_qr.cols() - 1));
		else {
			if(post_qr.cols() > post_qr.rows()) { //extend post_qr matrix to be at least square
				post_qr.set_size(post_qr.cols(), post_qr.cols(), true);				
			}
			
			qr_submat						= post_qr.get(sizeA,	(sizeA + colsB - 1),	sizeA,	(post_qr.cols() - 1));
			
		}
	
		return qr_submat;
	}

	void generateLmat(int timestep){
		//! control strategy matrix L is based on loss in time:
		//!		time = horizon			loss = finalLoss
		//!		time = horizon - 1		loss = sum(Losses)(time) + finalLoss
		//!		time = horizon - k > 1	loss = sum(Losses)(time) + tolC time+1 loss

		trs_crv = rv; //transformed crv only in proper times, form [rv, rvc, 1]
		trs_crv.add(rvc);
		trs_crv.add(RV("1", 1, 0));
			
		//!first time, time = horizon - 1
		if(timestep == (horizon-1))		
			build_pre_qr(0);
		
		//!other times		
		else
			build_pre_qr(1);

		mat tmpQ;		
		qr(pre_qr, tmpQ, post_qr);
//cout << "preQR " << pre_qr << endl << "postQR" << post_qr << endl;
		mat qrA = get_qr_submatrix(0);		
		mat qrB = get_qr_submatrix(1);
		mat qrC = get_qr_submatrix(2);
//cout << "A " << qrA << "\n B " << qrB << "\n C " << qrC << endl;

		L = - inv(qrA)*qrB; ///////// INVERSE OF TRIANGLE MATRIX! better?
		tolC = qrC;
	}

};

class LQG_recedinghorizon : public LQG_universal {
protected:
	//!total_curtime is curtime for total_horizon
	int total_curtime;
public:
	//! LQG_universal::horizon means shorter receding horizon for designing control strategy
	//! total_horizon is longer total horizon
	int total_horizon;
	
	//!constructor
	LQG_recedinghorizon() : LQG_universal() {
			total_horizon = 0;
			total_curtime = 0;
		}

	virtual void redesign() {
		if (total_curtime < total_horizon){
			for(int i = 0; i < horizon - 1; i++) LQG_universal::redesign();
			total_curtime++;
		}		
	}

	virtual vec ctrlaction ( const vec &cond ) const {
        //return L * concat(cond, 1.0);
        return empty_vec;
    }

    //! register this controller with given datasource under name "name"
    virtual void log_register ( logger &L, const string &prefix ) { }
    //! write requested values into the logger
    virtual void log_write ( ) const { }	
};

} // namespace
