#define BDMLIB
#include "../mat_checks.h"
#include "design/lq_ctrl.h"

using namespace bdm;

TEST ( LQG_test ) {
    LQG reg;
    shared_ptr<StateSpace<chmat> > stsp = new StateSpace<chmat>;
    // 2 x 1 x 1
    stsp-> set_parameters ( eye ( 2 ), ones ( 2, 1 ), ones ( 1, 2 ), ones ( 1, 1 ), /* Q,R */ eye ( 2 ), eye ( 1 ) );
    reg.set_system ( stsp ); // A, B, C
    reg.set_control_parameters ( eye ( 1 ), eye ( 1 ),  vec_1 ( 1.0 ), 6 ); //Qy, Qu, horizon
    reg.validate();

    reg.redesign();
    double reg_apply = reg.ctrlaction ( "0.5, 1.1", "0.0" ) ( 0 ); /*convert vec to double*/
    CHECK_CLOSE ( reg_apply, -0.248528137234392, 0.0001 );
}

TEST ( to_state_test ) {
    mlnorm<fsqmat> ml;
    mat A = "1.1, 2.3";
    ml.set_parameters ( A, vec_1 ( 1.3 ), eye ( 1 ) );
    RV yr = RV ( "y", 1 );
    RV ur = RV ( "u", 1 );
    ml.set_rv ( yr );
    yr.t_plus ( -1 );
    ml.set_rvc ( concat ( yr, ur ) );

    shared_ptr<StateCanonical > Stsp = new StateCanonical;
    Stsp->connect_mlnorm ( ml );

    /* results from
    [A,B,C,D]=tf2ss([2.3 0],[1 -1.1])
    */
    CHECK_CLOSE_EX ( Stsp->_A().get_row ( 0 ), vec ( "1.1" ), 0.0001 );
    CHECK_CLOSE_EX ( Stsp->_C().get_row ( 0 ), vec ( "2.53" ), 0.0001 );
    CHECK_CLOSE_EX ( Stsp->_D().get_row ( 0 ), vec ( "2.30" ), 0.0001 );
}

TEST ( to_state_arx_test ) {
    mlnorm<chmat> ml;
    mat A = "1.1, 2.3, 3.4";
    ml.set_parameters ( A, vec_1 ( 1.3 ), eye ( 1 ) );
    RV yr = RV ( "y", 1 );
    RV ur = RV ( "u", 1 );
    ml.set_rv ( yr );
    ml.set_rvc ( concat ( yr.copy_t ( -1 ), concat ( ur, ur.copy_t ( -1 ) ) ) );

    shared_ptr<StateFromARX> Stsp = new StateFromARX;
    RV xrv;
    RV urv;
    Stsp->connect_mlnorm ( ml, xrv, urv );

    /* results from
    [A,B,C,D]=tf2ss([2.3 0],[1 -1.1])
    */
    CHECK_CLOSE_EX ( Stsp->_A().get_row ( 0 ), vec ( "1.1, 3.4, 1.3" ), 0.0001 );
    CHECK_CLOSE_EX ( Stsp->_B().get_row ( 0 ), vec ( "2.3" ), 0.0001 );
    CHECK_CLOSE_EX ( Stsp->_C().get_row ( 0 ), vec ( "1, 0, 0" ), 0.0001 );
}

TEST ( arx_LQG_test ) {
    mlnorm<chmat> ml;
    mat A = "1.81, -.81, .00468, .00438";
    ml.set_parameters ( A, vec_1 ( 0.0 ), 0.00001*eye ( 1 ) );
    RV yr = RV ( "y", 1 );
    RV ur = RV ( "u", 1 );
    RV rgr = yr.copy_t ( -1 );
    rgr.add ( yr.copy_t ( -2 ) );
    rgr.add ( yr.copy_t ( -2 ) );
    rgr.add ( ur.copy_t ( -1 ) );
    rgr.add ( ur );

    ml.set_rv ( yr );
    ml.set_rvc ( rgr );
    ml.validate();

    shared_ptr<StateFromARX> Stsp = new StateFromARX;
    RV xrv;
    RV urv;
    Stsp->connect_mlnorm ( ml, xrv, urv );

    LQG L;
    L.set_system ( Stsp );
    L.set_control_parameters ( eye ( 1 ), sqrt ( 1.0 / 1000 ) *eye ( 1 ), vec_1 ( 0.0 ), 100 );
    L.validate();

    L.redesign();
    cout << L.to_string() << endl;
}

TEST (quadratic_test){
  /*quadraticfn qf;
  qf.Q = chmat(2);
  qf.Q._Ch() = mat("1 -1 0; 0 0 0; 0 0 0");
  CHECK_CLOSE_EX(qf.eval(vec("1 2")), vec_1(1.0), 0.0000000000000001);
  
  LQG_universal lq;
  lq.Losses = Array<quadraticfn>(1);
  lq.Losses(0) = quadraticfn();
  lq.Losses(0).Q._Ch() = mat("1 -1 0; 0 0 0; 0 0 0");
  lq.Losses(0).rv = RV("{u up }");
  
  lq.Models = Array<linfnEx>(1);
  lq.Models(0) = linfnEx(mat("1"),vec("1"));
  lq.Models(0).rv = RV("{x }");
  
  lq.rv = RV("u",1);
  
  lq.redesign();
  CHECK_CLOSE_EX(lq.ctrlaction(vec("1,0")), vec("1.24, -5.6"), 0.0000001);*/
}

TEST (lqguniversal_test){
  //test of universal LQG controller
  LQG_universal lq;
  lq.rv = RV("u", 2, 0);  
  lq.set_rvc(RV("x", 2, 0));
  lq.horizon = 10;

  /*
		model:      x = Ax + Bu			time: 0..horizon
		loss:       l = x'Q'Qx + u'R'Ru		time: 0..horizon-1
		final loss: l = x'S'Sx			time: horizon

		dim:	x: 2
				u: 2

		A = [	2	-1	 ]
			[	0	0.5	 ]
		
		B = [	1		-0.1	]	
			[	-0.2	2		]

		Q = [	5	0	]
			[	0	1	]

		R = [	0.01	0	 ]
			[	0		0.1	 ]

		S = Q
  */

  //mat mA("2 -1;0 0.5"); 
  //mat mB("1 -0.1;-0.2 2"); 
  //mat mQ("5 0;0 1"); 
  //mat mR("0.01 0;0 0.1"); 
  //mat mS = mQ;

  ////starting configuration
  //vec x0("6 3");

  //uniform random generator
  Uniform_RNG urng;
  double maxmult = 10;
  vec tmpdiag;
  
  mat mA;   
	urng.sample_matrix(2, 2, mA);
	mA *= maxmult;
  mat mB; 
	urng.sample_matrix(2, 2, mB);
	mB *= maxmult;
  mat mQ(2, 2); 
	urng.sample_vector(2, tmpdiag);
	tmpdiag *= maxmult;
	mQ.zeros();
  	mQ.set(0, 0, tmpdiag.get(0));
	mQ.set(1, 1, tmpdiag.get(1));
  mat mR(2, 2); 
	urng.sample_vector(2, tmpdiag);
	tmpdiag *= maxmult;
	mR.zeros();
	mR.set(0, 0, tmpdiag.get(0));
	mR.set(1, 1, tmpdiag.get(1));
  mat mS(2, 2); 
	urng.sample_vector(2, tmpdiag);
	tmpdiag *= maxmult;
	mS.zeros();
	mS.set(0, 0, tmpdiag.get(0));
	mS.set(1, 1, tmpdiag.get(1));

  //starting configuration
  vec x0;
	urng.sample_vector(2, x0);
	x0 *= maxmult;
 
	/*cout << "configuration:" << endl 
		<< "mA:" << endl << mA << endl
		<< "mB:" << endl << mB << endl
		<< "mQ:" << endl << mQ << endl
		<< "mR:" << endl << mR << endl
		<< "mS:" << endl << mS << endl
		<< "x0:" << endl << x0 << endl;*/

  //model
  Array<linfnEx> model(2);
  //model Ax part
  model(0).A = mA;
  model(0).B = vec("0 0");
  model(0).rv = RV("x", 2, 0);
  model(0).rv_ret = RV("x", 2, 1);
  //model Bu part
  model(1).A = mB;
  model(1).B = vec("0 0");
  model(1).rv = RV("u", 2, 0);
  model(1).rv_ret = RV("x", 2, 1);	
  //setup
  lq.Models = model;

  //loss
  Array<quadraticfn> loss(2);
  //loss x'Qx part
  loss(0).Q.setCh(mQ);
  loss(0).rv = RV("x", 2, 0);
  //loss u'Ru part
  loss(1).Q.setCh(mR);
  loss(1).rv = RV("u", 2, 0);
  //setup
  lq.Losses = loss;

  //finalloss setup
  lq.finalLoss.Q.setCh(mS);
  lq.finalLoss.rv = RV("x", 2, 1);
  
  //default L
  //cout << "default L matrix:" << endl << lq.getL() << endl;

  //produce last control matrix L
  lq.redesign();
  
  //verification via Riccati LQG version
  mat mK = mS;
  mat mL = - inv(mR + mB.transpose() * mK * mB) * mB.transpose() * mK * mA;

  //cout << "L matrix LQG_universal:" << endl << lq.getL() << endl << 
	  //"L matrix LQG Riccati:" << endl << mL << endl;

  //checking L matrix (universal vs Riccati), tolerance is high, but L is not main criterion
  //more important is reached loss compared in the next part
  double tolerr = 1;//0.01; //0.01 OK x 0.001 NO OK

  //check last time L matrix
  CHECK_CLOSE_EX(lq.getL().get_cols(0,1), mL, tolerr);
  
  mat oldK;
  int i;

  //produce next control matrix L
  for(i = 0; i < lq.horizon - 1; i++) {
	  lq.redesign();
	  oldK = mK;
	  mK = mA.transpose() * (oldK - oldK * mB * inv(mR + mB.transpose() * oldK * mB) * mB.transpose() * oldK) * mA + mQ;
	  mL = - inv(mR + mB.transpose() * mK * mB) * mB.transpose() * mK * mA;

	  //cout << "L matrix LQG_universal:" << endl << lq.getL() << endl << 
		  //"L matrix LQG Riccati:" << endl << mL << endl;

	  //check other times L matrix
	  CHECK_CLOSE_EX(lq.getL().get_cols(0,1), mL, tolerr);
  }

 //check losses of LQG control - compare LQG_universal and Riccati version, no noise
    
  //loss of LQG_universal
  /*double*/vec loss_uni("0");

  //setup
  vec x = x0;
  vec xold = x;
  vec u;
  //vec tmploss;

  //iteration
  for(i = 0; i < lq.horizon - 1; i++){
	u = lq.getL().get_cols(0,1) * xold;
	x = mA * xold + mB * u;
	/*tmploss*/loss_uni = x.transpose() * mQ * x + u.transpose() * mR * u;
	//loss_uni += tmploss.get(0);
	xold = x;
  }
  /*tmploss*/ loss_uni = x.transpose() * mS * x;
  //loss_uni += tmploss.get(0);

  //loss of LQG Riccati version
  /*double*/ vec loss_rct("0");

  //setup
  x = x0;
  xold = x;

  //iteration
  for(i = 0; i < lq.horizon - 1; i++){
	u = mL * xold;
	x = mA * xold + mB * u;
	/*tmploss*/loss_rct = x.transpose() * mQ * x + u.transpose() * mR * u;
	//loss_rct += tmploss.get(0);
	xold = x;
  }
  /*tmploss*/loss_rct = x.transpose() * mS * x;
  //loss_rct += tmploss.get(0);

  //cout << "Loss LQG_universal: " << loss_uni << " vs Loss LQG Riccati: " << loss_rct << endl;
  CHECK_CLOSE_EX(loss_uni, loss_rct, 0.0001);
}