#include <itpp/itbase.h>
#include "libDC.h"

using namespace itpp;

using std::endl;

//! Auxiliary function dydr; dyadic reduction
void dydr( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx );

//! Auxiliary function ltuinv; inversion of a triangular matrix;
//TODO can be done via: dtrtri.f from lapack
mat ltuinv( const mat &L );



ldmat::ldmat( const mat &exL, const vec &exD ) {
	D = exD;
	L = exL;
	dim = exD.length();
}

ldmat::ldmat() {
	vec D ;
	mat L;
	dim = 0;
}

ldmat::ldmat( const mat V ) {
//TODO check if correct!! Based on heuristic observation of lu()

	dim = V.cols();
	mat F;
	vec D0;
	it_assert_debug( dim == V.rows(),"ldmat::ldmat matrix V is not square!" );
	
	//decompose V in cholesky
	D0 = ones(dim);
	

	using std::cout;
	cout<<V;
	chol(V,F);
	// L and D will be allocated by ldform()
	this->ldform(F,D0);	
}

void ldmat::opupdt( const vec &v,  double w ) {
	int dim = D.length();
	double kr;
	vec r = v;
	//beware! it is potentionally dangerous, if ITpp change _behaviour of _data()!
	double *Lraw = L._data();
	double *Draw = D._data();
	double *rraw = r._data();

	it_assert_debug( v.length() == dim, "LD::ldupdt vector v is not compatible with this ld." );

	for ( int i = dim - 1; i >= 0; i-- ) {
		dydr( rraw, Lraw + i, &w, Draw + i, rraw + i, 0, i, &kr, 1, dim );
	}
}

std::ostream &operator<< ( std::ostream &os,  ldmat &ld ) {
	os << "L:" << ld.L << endl;
	os << "D:" << ld.D << endl;
}

mat ldmat::to_mat() {
	int dim = D.length();
	mat V( dim, dim );
	double sum;
	int r, c, cc;

	for ( r = 0;r < dim;r++ ) { //row cycle
		for ( c = r;c < dim;c++ ) {
			//column cycle, using symmetricity => c=r!
			sum = 0.0;
			for ( cc = c;cc < dim;cc++ ) { //cycle over the remaining part of the vector
				sum += L( cc, r ) * D( cc ) * L( cc, c );
				//here L(cc,r) = L(r,cc)';
			}
			V( r, c ) = sum;
			// symmetricity
			if ( r != c ) {V( c, r ) = sum;};
		}
	}
	mat V2 = L.transpose()*diag( D )*L;
	return V2;
}


void ldmat::add( const ldmat &ld2, double w ) {
	int dim = D.length();

	it_assert_debug( ld2.D.length() == dim, "LD.add() incompatible sizes of LDs;" );

	//Fixme can be done more efficiently either via dydr or ldform
	for ( int r = 0; r < dim; r++ ) {
		// Add columns of ld2.L' (i.e. rows of ld2.L) as dyads weighted by ld2.D
		this->opupdt( ld2.L.get_row( r ), w*ld2.D( r ) );
	}
}

void ldmat::clear(){L.clear(); for ( int i=0;i<L.cols();i++ ){L( i,i )=1;}; D.clear();}

void ldmat::inv( ldmat &Inv ) {
	int dim = D.length();
	Inv.clear();   //Inv = zero in LD
	mat U = ltuinv( L );

	//Fixme can be done more efficiently either via dydr or ldform
	for ( int r = 0; r < dim; r++ ) {
		// Add columns of U as dyads weighted by 1/D
		Inv.opupdt( U.get_col( r ), 1.0 / D( r ) );
	}
}

void ldmat::mult_sym( const mat &C, bool trans ) {

//TODO better

	it_assert_debug( C.cols()==L.cols(), "ldmat::mult_sym wrong input argument" );
	mat Ct=C;

	if ( trans==false ) { // return C*this*C'
		Ct *= this->to_mat();
		Ct *= C.transpose();
	} else {	// return C'*this*C
		Ct = C.transpose();
		Ct *= this->to_mat();
		Ct *= C;
	}

	ldmat Lnew=ldmat( Ct );
	L = Lnew.L;
	D = Lnew.D;
}

void ldmat::mult_sym( const mat &C, ldmat &U, bool trans ) {

//TODO better

//TODO input test

	mat Ct=C;

	if ( trans==false ) { // return C*this*C'
		Ct *= U.to_mat();
		Ct *= C.transpose();
	} else {	// return C'*this*C
		Ct = C.transpose();
		Ct *= U.to_mat();
		Ct *= C;
	}

	ldmat Lnew=ldmat( Ct );
	L = Lnew.L;
	D = Lnew.D;
}

double ldmat::logdet() {
	double ldet = 0.0;
	int i;
// sum logarithms of diagobal elements
	for ( i=0; i<D.length(); i++ ){ldet+=log( D( i ) );};
}

double ldmat::qform( vec &v ) {
	double x = 0.0, sum;
	int i,j;

	for ( i=0; i<D.length(); i++ ) { //rows of L
		sum = 0.0;
		for ( j=0; j<=i; j++ ){sum+=L( i,j )*v( j );}
		x +=D( i )*sum*sum;
	};
	return x;
}

ldmat& ldmat::operator *= ( double x ) {
	int i;
	for ( i=0;i<D.length();i++ ){D( i )*=x;};
}

vec ldmat::sqrt_mult( vec &x ) {
	int i,j;
	vec res( dim );
	double sum;
	for ( i=0;i<dim;i++ ) {//for each element of result
		res( i ) = 0.0;
		for ( j=i;j<dim;j++ ) {//sum D(j)*L(:,i).*x
			res( i ) += sqrt( D( j ) )*L( j,i )*x( j );
		}
	}
	vec res2 = L.transpose()*diag( sqrt( D ) )*x;
	return res2;
}

void ldmat::ldform( mat &A,vec &D0 ) {
	int m = A.rows();
	int n = A.cols();
	int mn = (m<n) ? m :n ;

	it_assert_debug( A.cols()==dim,"ldmat::ldform A is not compatible" );
	it_assert_debug( D0.length()==A.rows(),"ldmat::ldform Vector D must have the length as row count of A" );

	L=concat_vertical( zeros( n,n ), diag( sqrt( D0 ) )*A );
	D=zeros( n+m );
	
	//unnecessary big L and D will be made smaller at the end of file
	
	vec w=zeros( n );
	vec v=zeros(n);
	double sum, beta, pom;

	int cc=0;
	int i=n; // +1 in .m
	int ii,j,jj;
	while (( i> ( n-mn+1-cc ) )&&( i>1 ) ) {
		i--;
		sum = 0.0;
		v.set_size( m+i-( n-cc ) ); //prepare v
		for ( ii=n-cc;ii<m+i;i++ ) {
			sum+= L( ii,i )*L( ii,i );
			v( ii )=L( ii,i ); //assign v
		}
		if ( L( m+i,i )==0 ) {
			beta = sqrt( sum );
		} else {
			beta = L( m+i,i )+sign( L( m+i,i ) )*sqrt( sum ) ;
		}
		if ( std::fabs( beta )<eps ) {
			cc++;
			L.set_row( n+1-cc, L.get_row( m+i ) );
			L.set_row( m+i,0 );
			D( m+i )=0; L( m+i,i )=1;
			L.set_submatrix( n+1-cc,m+i,i,i,0 );
			continue;
		}

		sum-=v( v.length()-1 )*v( v.length()-1 ); //
		sum/=beta*beta;
		sum++;

		v/=beta;
		v( v.length()-1 )=1;

		pom=-2/sum;
		for ( j=i;i>=0;i-- ) {
			w( j )=0.0;
			for ( ii=n-cc;ii<m+i;ii++ ) {
				w( j )+= v( ii )*L( ii,j );
			}
			w( j )*=pom;
		}

		for ( ii=n-cc;ii<m+i;ii++ ) {
			for ( jj=0;jj<i-1;jj++ ) {
				L( ii,jj )+= v( ii )*w( jj );
			}
		}
		for ( ii=n-cc;ii<m+i;ii++ ) {
			L( ii,i )= 0;
		}
		L( m+i,i )+=w( i );

		D( m+i )=L( m+i,i )*L( m+i,i );
		for ( ii=0;ii<i;ii++ ) {
			L( m+i,ii )/=L( m+i,i );
		}
	}
	if ( i>0 ) {
		for ( ii=0;ii<i-1;ii++ ) {
			jj = D.length()-1-n+ii;
			L.set_row(jj,0);
			L(jj,jj)=1;
		}
	}

	//cut-out L and D;
	L.del_rows(0,m-1);
	D.del(0,m-1);
}

//////// Auxiliary Functions

mat ltuinv( const mat &L ) {
	int dim = L.cols();
	mat Il = eye( dim );
	int i, j, k, m;
	double s;

//Fixme blind transcription of ltuinv.m
	for ( k = 1; k < ( dim );k++ ) {
		for ( i = 0; i < ( dim - k );i++ ) {
			j = i + k; //change in .m 1+1=2, here 0+0+1=1
			s = L( j, i );
			for ( m = i + 1; m < ( j - 1 ); m++ ) {
				s += L( m, i ) * Il( j, m );
			}
			Il( j, i ) = -s;
		}
	}

	return Il;
}

void dydr( double * r, double *f, double *Dr, double *Df, double *R, int jl, int jh, double *kr, int m, int mx )
/********************************************************************

   dydr = dyadic reduction, performs transformation of sum of
          2 dyads r*Dr*r'+ f*Df*f' so that the element of r pointed
          by R is zeroed. This version allows Dr to be NEGATIVE. Hence the name negdydr or dydr_withneg.

   Parameters :
     r ... pointer to reduced dyad
     f ... pointer to reducing dyad
     Dr .. pointer to the weight of reduced dyad
     Df .. pointer to the weight of reducing dyad
     R ... pointer to the element of r, which is to be reduced to
           zero; the corresponding element of f is assumed to be 1.
     jl .. lower index of the range within which the dyads are
           modified
     ju .. upper index of the range within which the dyads are
           modified
     kr .. pointer to the coefficient used in the transformation of r
           rnew = r + kr*f
     m  .. number of rows of modified matrix (part of which is r)
  Remark : Constant mzero means machine zero and should be modified
           according to the precision of particular machine

                                                 V. Peterka 17-7-89

  Added:
     mx .. number of rows of modified matrix (part of which is f)  -PN

********************************************************************/
{
	int j, jm;
	double kD, r0;
	double mzero = 2.2e-16;
	double threshold = 1e-4;

	if ( fabs( *Dr ) < mzero ) *Dr = 0;
	r0 = *R;
	*R = 0.0;
	kD = *Df;
	*kr = r0 * *Dr;
	*Df = kD + r0 * ( *kr );
	if ( *Df > mzero ) {
		kD /= *Df;
		*kr /= *Df;
	} else {
		kD = 1.0;
		*kr = 0.0;
		if ( *Df < -threshold ) {
			it_warning( "Problem in dydr: subraction of dyad results in negative definitness. Likely mistake in calling function." );
		}
		*Df = 0.0;
	}
	*Dr *= kD;
	jm = mx * jl;
	for ( j = m * jl; j < m*jh; j += m ) {
		r[j] -=  r0 * f[jm];
		f[jm] += *kr * r[j];
		jm += mx;
	}
}


