#include <itpp/itbase.h>
#include "libDC.h"

using namespace itpp;

using std::endl;

//! Auxiliary function dydr; dyadic reduction
void dydr( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx );

//! Auxiliary function ltuinv; inversion of a triangular matrix;
//TODO can be done via: dtrtri.f from lapack
mat ltuinv( const mat &L );



ldmat::ldmat( const mat &exL, const vec &exD ) {
	D = exD;
	L = exL;
}

ldmat::ldmat() {
	vec D ;
	mat L;
}

ldmat::ldmat( const mat V ) {
//TODO check if correct!! Based on heuristic observation of lu()

	int dim = V.cols();
	it_assert_debug( dim == V.rows(),"ldmat::ldmat matrix V is not square!" );

	mat U( dim,dim );

	L = V; //Allocate space for L
	ivec p = ivec( dim ); //not clear why?

	lu( V,L,U,p );

//Now, if V is symmetric, L is what we seek and D is on diagonal of U
	D = diag( U );

//check if V was symmetric
//TODO How? norm of L-U'?
//it_assert_debug();
}

void ldmat::opupdt( const vec &v,  double w ) {
	int dim = D.length();
	double kr;
	vec r = v;
	//beware! it is potentionally dangerous, if ITpp change _behaviour of _data()!
	double *Lraw = L._data();
	double *Draw = D._data();
	double *rraw = r._data();

	it_assert_debug( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." );

	for ( int i = dim - 1; i >= 0; i-- ) {
		dydr( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim );
	}
}

std::ostream &operator<< ( std::ostream &os,  sqmat &sq ) {
	os << sq.to_mat() << endl;
}

mat ldmat::to_mat() {
	int dim = D.length();
	mat V( dim, dim );
	double sum;
	int r, c, cc;

	for ( r = 0;r < dim;r++ ) { //row cycle
		for ( c = r;c < dim;c++ ) {
			//column cycle, using symmetricity => c=r!
			sum = 0.0;
			for ( cc = c;cc < dim;cc++ ) { //cycle over the remaining part of the vector
				sum += L( cc, r ) * D( cc ) * L( cc, c );
				//here L(cc,r) = L(r,cc)';
			}
			V( r, c ) = sum;
			// symmetricity
			if ( r != c ) {V( c, r ) = sum;};
		}
	}
	return V;
}


void ldmat::add( const ldmat &ld2, double w ) {
	int dim = D.length();

	it_assert_debug( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs;" );

	//Fixme can be done more efficiently either via dydr or ldform
	for ( int r = 0; r < dim; r++ ) {
		// Add columns of ld2.L' (i.e. rows of ld2.L) as dyads weighted by ld2.D
		this->opupdt( ld2.L.get_row( r ), w*ld2.D( r ) );
	}
}

void ldmat::clear(){L.clear(); for ( int i=0;i<L.cols();i++ ){L( i,i )=1;}; D.clear();}

void ldmat::inv( ldmat &Inv ) {
	int dim = D.length();
	Inv.clear();   //Inv = zero in LD
	mat U = ltuinv( L );

	//Fixme can be done more efficiently either via dydr or ldform
	for ( int r = 0; r < dim; r++ ) {
		// Add columns of U as dyads weighted by 1/D
		Inv.opupdt( U.get_col( r ), 1.0 / D( r ) );
	}
}

void ldmat::mult_qform( const mat &C, bool trans ) {

//TODO better

	it_assert_debug( C.cols()==L.cols(), "ldmat::mult_qform wrong input argument" );
	mat Ct=C;

	if ( trans==false ) { // return C*this*C'
		Ct *= this->to_mat();
		Ct *= C.transpose();
	} else {	// return C'*this*C
		Ct = C.transpose();
		Ct *= this->to_mat();
		Ct *= C;
	}

	ldmat Lnew=ldmat( Ct );
	L = Lnew.L;
	D = Lnew.D;
}

double ldmat::logdet() {
	double ldet = 0.0;
	int i;
// sum logarithms of diagobal elements
	for ( i=0; i<D.length(); i++ ){ldet+=log( D( i ) );};
}

double ldmat::qform( vec &v ) {
	double x = 0.0, sum;
	int i,j;

	for ( i=0; i<D.length(); i++ ) { //rows of L
		sum = 0.0;
		for ( j=0; j<=i; j++ ){sum+=L( i,j )*v( j );}
		x +=D( i )*sum*sum;
	};
	return x;
}

ldmat& ldmat::operator *= (double x){
int i;
for(i=0;i<D.length();i++){D(i)*=x;};
}


//////// Auxiliary Functions

mat ltuinv( const mat &L ) {
	int dim = L.cols();
	mat Il = eye( dim );
	int i, j, k, m;
	double s;

//Fixme blind transcription of ltuinv.m
	for ( k = 1; k < ( dim );k++ ) {
		for ( i = 0; i < ( dim - k );i++ ) {
			j = i + k; //change in .m 1+1=2, here 0+0+1=1
			s = L( j, i );
			for ( m = i + 1; m < ( j - 1 ); m++ ) {
				s += L( m, i ) * Il( j, m );
			}
			Il( j, i ) = -s;
		}
	}

	return Il;
}

void dydr( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx )
/********************************************************************

   dydr = dyadic reduction, performs transformation of sum of
          2 dyads r*Dr*r'+ f*Df*f' so that the element of r pointed
          by R is zeroed. This version allows Dr to be NEGATIVE. Hence the name negdydr or dydr_withneg.

   Parameters :
     r ... pointer to reduced dyad
     f ... pointer to reducing dyad
     Dr .. pointer to the weight of reduced dyad
     Df .. pointer to the weight of reducing dyad
     R ... pointer to the element of r, which is to be reduced to
           zero; the corresponding element of f is assumed to be 1.
     jl .. lower index of the range within which the dyads are
           modified
     ju .. upper index of the range within which the dyads are
           modified
     kr .. pointer to the coefficient used in the transformation of r
           rnew = r + kr*f
     m  .. number of rows of modified matrix (part of which is r)
  Remark : Constant mzero means machine zero and should be modified
           according to the precision of particular machine

                                                 V. Peterka 17-7-89

  Added:
     mx .. number of rows of modified matrix (part of which is f)  -PN

********************************************************************/
{
	int j, jm;
	double kD, r0;
	double mzero = 2.2e-16;
	double threshold = 1e-4;

	if ( fabs( *Dr ) < mzero ) *Dr = 0;
	r0 = *R;
	*R = 0.0;
	kD = *Df;
	*kr = r0 * *Dr;
	*Df = kD + r0 * ( *kr );
	if ( *Df > mzero ) {
		kD /= *Df;
		*kr /= *Df;
	} else {
		kD = 1.0;
		*kr = 0.0;
		if ( *Df < -threshold ) it_warning( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." );
		*Df = 0.0;
	}
	*Dr *= kD;
	jm = mx * jl;
	for ( j = m * jl; j < m*jh; j += m ) {
		r[j] -=  r0 * f[jm];
		f[jm] += *kr * r[j];
		jm += mx;
	}
}


