
#include "chmat.h"

//using std::endl;

namespace bdm {

void chmat::add ( const chmat &A2, double w ) {
    bdm_assert_debug ( dim == A2.dim, "Matrices of unequal dimension" );
    mat pre = concat_vertical ( Ch, sqrt ( w ) * A2.Ch );
    mat post = zeros ( pre.rows(), pre.cols() );
    if ( !qr ( pre, post ) ) {
        bdm_warning ( "Unstable QR in chmat add" );
    }
    Ch = post ( 0, dim - 1, 0, dim - 1 );
};

void chmat::opupdt ( const vec &v, double w ) {
//TODO see cholupdt in lhotse
    mat Z;
    mat R;
    mat V ( 1, v.length() );
    V.set_row ( 0, v*sqrt ( w ) );
    Z = concat_vertical ( Ch, V );
    qr ( Z, R );
    Ch = R ( 0, Ch.rows() - 1, 0, Ch.cols() - 1 );
}

mat chmat::to_mat() const {
    mat F = Ch.T() * Ch;
    return F;
}

void chmat::mult_sym ( const mat &C ) {
    bdm_assert_debug ( C.rows() == dim, "Wrong dimension of U" );
    if ( !qr ( Ch*C.T(), Ch ) ) {
        bdm_warning ( "QR unstable in chmat mult_sym" );
    }
}

void chmat::mult_sym ( const mat &C , chmat &U ) const {
    bdm_assert_debug ( C.rows() == U.dim, "Wrong dimension of U" );
    if ( !qr ( Ch*C.T(), U.Ch ) ) {
        bdm_warning ( "QR unstable in chmat mult_sym" );
    }
}

void chmat::mult_sym_t ( const mat &C ) {
    bdm_assert_debug ( C.cols() == dim, "Wrong dimension of U" );
    if ( !qr ( Ch*C, Ch ) ) {
        bdm_warning ( "QR unstable in chmat mult_sym" );
    }
}

void chmat::mult_sym_t ( const mat &C, chmat &U ) const {
    bdm_assert_debug ( C.cols() == U.dim, "Wrong dimension of U" );
    if ( !qr ( Ch*C, U.Ch ) ) {
        bdm_warning ( "QR unstable in chmat mult_sym" );
    }
}

double chmat::logdet() const {
    double ldet = 0.0;
    int i;
    //sum of logs of (possibly negative!) diagonal entries
    for ( i = 0; i < Ch.rows(); i++ ) {
        ldet += log ( std::fabs ( Ch ( i, i ) ) );
    }
    return 2*ldet; //compensate for Ch being sqrt()
}

//TODO can be done more efficiently using BLAS, see triangular matrices
vec chmat::sqrt_mult ( const vec &v ) const {
    vec pom;
    pom = Ch.T() * v;
    return pom;
}

double chmat::qform ( const vec &v ) const {
    vec pom;
    pom = Ch * v;
    return pom*pom;
}

double chmat::invqform ( const vec &v ) const {
    vec pom ( v.length() );
    forward_substitution ( Ch.T(), v, pom );
    return pom*pom;
}

void chmat::clear() {
    Ch.clear();
}

}
