
#include "bdmbase.h"

//! Space of basic BDM structures
namespace bdm {

const int RV::BUFFER_STEP = 1;

Array<string> RV::NAMES ( RV::BUFFER_STEP );

ivec RV::SIZES ( RV::BUFFER_STEP );

RV::str2int_map RV::MAP;

void RV::clear_all() {
	MAP.clear();
	SIZES.clear();
	NAMES = Array<string> ( BUFFER_STEP );
}

string RV::show_all(){
  ostringstream os;
  for(str2int_map::const_iterator iter=MAP.begin(); iter!=MAP.end(); iter++){
	  os << "key: " << iter->first << " val: " << iter->second <<endl;
  }
  return os.str();
};

int RV::init ( const string &name, int size ) {
	//Refer
	int id;
	str2int_map::const_iterator iter = MAP.find ( name );
	if ( iter == MAP.end() || name.length()==0) { //add new RV
		id = MAP.size() + 1;
		//debug
		/*		{
					cout << endl;
					str2int_map::const_iterator iter = MAP.begin();
					for(str2int_map::const_iterator iter=MAP.begin(); iter!=MAP.end(); iter++){
						cout << "key: " << iter->first << " val: " << iter->second <<endl;
					}
				}*/

		MAP.insert ( make_pair ( name, id ) ); //add new rv
		if ( id >= NAMES.length() ) {
			NAMES.set_length ( id + BUFFER_STEP, true );
			SIZES.set_length ( id + BUFFER_STEP, true );
		}
		NAMES ( id ) = name;
		SIZES ( id ) = size;
		bdm_assert(size>0, "RV "+ name +" does not exists. Default size (-1) can not be assigned ");
	} else {
		id = iter->second;
		if (size>0 && name.length()>0){
			bdm_assert ( SIZES ( id ) == size, "RV " + name + " of size " + num2str(SIZES(id)) + " exists, requested size " + num2str(size) + "can not be assigned" );
		}
	}
	return id;
};

int RV::countsize() const {
	int tmp = 0;
	for ( int i = 0; i < len; i++ ) {
		tmp += SIZES ( ids ( i ) );
	}
	return tmp;
}

ivec RV::cumsizes() const {
	ivec szs ( len );
	int tmp = 0;
	for ( int i = 0; i < len; i++ ) {
		tmp += SIZES ( ids ( i ) );
		szs ( i ) = tmp;
	}
	return szs;
}

void RV::init ( const Array<std::string> &in_names, const ivec &in_sizes, const ivec &in_times ) {
	len = in_names.length();
	bdm_assert ( in_names.length() == in_times.length(), "check \"times\" " );
	bdm_assert ( in_names.length() == in_sizes.length(), "check \"sizes\" " );

	times.set_length ( len );
	ids.set_length ( len );
	int id;
	for ( int i = 0; i < len; i++ ) {
		id = init ( in_names ( i ), in_sizes ( i ) );
		ids ( i ) = id;
	}
	times = in_times;
	dsize = countsize();
}

RV::RV ( string name, int sz, int tm ) {
	Array<string> A ( 1 );
	A ( 0 ) = name;
	init ( A, vec_1 ( sz ), vec_1 ( tm ) );
}

bool RV::add ( const RV &rv2 ) {
	// TODO
	if ( rv2.len > 0 ) { //rv2 is nonempty
		ivec ind = rv2.findself ( *this ); //should be -1 all the time
		ivec index = itpp::find ( ind == -1 );

		if ( index.length() < rv2.len ) { //conflict
			ids = concat ( ids, rv2.ids ( index ) );
			times = concat ( times, rv2.times ( index ) );
		} else {
			ids = concat ( ids, rv2.ids );
			times = concat ( times, rv2.times );
		}
		len = ids.length();
		dsize = countsize();
		return ( index.length() == rv2.len ); //conflict or not
	} else { //rv2 is empty
		return true; // no conflict
	}
};

RV RV::subselect ( const ivec &ind ) const {
	RV ret;
	ret.ids = ids ( ind );
	ret.times = times ( ind );
	ret.len = ind.length();
	ret.dsize = ret.countsize();
	return ret;
}

RV RV::operator() ( int di1, int di2 ) const {
	ivec sz = cumsizes();
	int i1 = 0;
	while ( sz ( i1 ) < di1 ) i1++;
	int i2 = i1;
	while ( sz ( i2 ) < di2 ) i2++;
	return subselect ( linspace ( i1, i2 ) );
}

void RV::t_plus ( int delta ) {
	times += delta;
}

bool RV::equal ( const RV &rv2 ) const {
	return ( ids == rv2.ids ) && ( times == rv2.times );
}

shared_ptr<pdf> epdf::condition ( const RV &rv ) const {
	bdm_warning ( "Not implemented" );
	return shared_ptr<pdf>();
}

shared_ptr<epdf> epdf::marginal ( const RV &rv ) const {
	bdm_warning ( "Not implemented" );
	return shared_ptr<epdf>();
}

mat epdf::sample_mat ( int N ) const {
	mat X = zeros ( dim, N );
	for ( int i = 0; i < N; i++ ) X.set_col ( i, this->sample() );
	return X;
}

vec epdf::evallog_mat ( const mat &Val ) const {
	vec x ( Val.cols() );
	for ( int i = 0; i < Val.cols(); i++ ) {
		x ( i ) = evallog ( Val.get_col ( i ) );
	}

	return x;
}

vec epdf::evallog_mat ( const Array<vec> &Avec ) const {
	vec x ( Avec.size() );
	for ( int i = 0; i < Avec.size(); i++ ) {
		x ( i ) = evallog ( Avec ( i ) );
	}

	return x;
}

mat pdf::samplecond_mat ( const vec &cond, int N ) {
	mat M ( dimension(), N );
	for ( int i = 0; i < N; i++ ) {
		M.set_col ( i, samplecond ( cond ) );
	}

	return M;
}

void pdf::from_setting ( const Setting &set ) {
	shared_ptr<RV> r = UI::build<RV> ( set, "rv", UI::optional );
	if ( r ) {
		set_rv ( *r );
	}

	r = UI::build<RV> ( set, "rvc", UI::optional );
	if ( r ) {
		set_rvc ( *r );
	}
}

void datalink::set_connection ( const RV &rv, const RV &rv_up ) {
	downsize = rv._dsize();
	upsize = rv_up._dsize();
	v2v_up = rv.dataind ( rv_up );
	bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
}

void datalink::set_connection ( int ds, int us, const ivec &upind ) {
	downsize = ds;
	upsize = us;
	v2v_up = upind;
	bdm_assert_debug ( v2v_up.length() == downsize, "rv is not fully in rv_up" );
}

void datalink_part::set_connection ( const RV &rv, const RV &rv_up ) {
	rv.dataind ( rv_up, v2v_down, v2v_up );
	downsize = v2v_down.length();
	upsize = v2v_up.length();
}

void datalink_m2e::set_connection ( const RV &rv, const RV &rvc, const RV &rv_up ) {
	datalink::set_connection ( rv, rv_up );
	condsize = rvc._dsize();
	//establish v2c connection
	rvc.dataind ( rv_up, v2c_lo, v2c_up );
}

vec datalink_m2e::get_cond ( const vec &val_up ) {
	vec tmp ( condsize );
	set_subvector ( tmp, v2c_lo, val_up ( v2c_up ) );
	return tmp;
}

void datalink_m2e::pushup_cond ( vec &val_up, const vec &val, const vec &cond ) {
	bdm_assert_debug ( downsize == val.length(), "Wrong val" );
	bdm_assert_debug ( upsize == val_up.length(), "Wrong val_up" );
	set_subvector ( val_up, v2v_up, val );
	set_subvector ( val_up, v2c_up, cond );
}

std::ostream &operator<< ( std::ostream &os, const RV &rv ) {
	int id;
	for ( int i = 0; i < rv.len ; i++ ) {
		id = rv.ids ( i );
		os << id << "(" << RV::SIZES ( id ) << ")" <<  // id(size)=
		"=" << RV::NAMES ( id )  << "_{"  << rv.times ( i ) << "}; "; //name_{time}
	}
	return os;
}

str RV::tostr() const {
	ivec idlist ( dsize );
	ivec tmlist ( dsize );
	int i;
	int pos = 0;
	for ( i = 0; i < len; i++ ) {
		idlist.set_subvector ( pos, pos + size ( i ) - 1, ids ( i ) );
		tmlist.set_subvector ( pos, pos + size ( i ) - 1, times ( i ) );
		pos += size ( i );
	}
	return str ( idlist, tmlist );
}

ivec RV::dataind ( const RV &rv2 ) const {
	ivec res ( 0 );
	if ( rv2._dsize() > 0 ) {
		str str2 = rv2.tostr();
		ivec part;
		int i;
		for ( i = 0; i < len; i++ ) {
			part = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
			res = concat ( res, part );
		}
	}

	bdm_assert_debug ( res.length() == dsize, "this rv is not fully present in crv!" );
	return res;

}

void RV::dataind ( const RV &rv2, ivec &selfi, ivec &rv2i ) const {
	//clean results
	selfi.set_size ( 0 );
	rv2i.set_size ( 0 );

	// just in case any rv is empty
	if ( ( len == 0 ) || ( rv2.length() == 0 ) ) {
		return;
	}

	//find comon rv
	ivec cids = itpp::find ( this->findself ( rv2 ) >= 0 );

	// index of
	if ( cids.length() > 0 ) {
		str str1 = tostr();
		str str2 = rv2.tostr();

		ivec part1;
		ivec part2;
		int i, j;
		// find common rv in strs
		for ( j = 0; j < cids.length(); j++ ) {
			i = cids ( j );
			part1 = itpp::find ( ( str1.ids == ids ( i ) ) & ( str1.times == times ( i ) ) );
			part2 = itpp::find ( ( str2.ids == ids ( i ) ) & ( str2.times == times ( i ) ) );
			selfi = concat ( selfi, part1 );
			rv2i = concat ( rv2i, part2 );
		}
	}
	bdm_assert_debug ( selfi.length() == rv2i.length(), "this should not happen!" );
}

RV RV::subt ( const RV &rv2 ) const {
	ivec res = this->findself ( rv2 ); // nonzeros
	ivec valid;
	if ( dsize > 0 ) {
		valid = itpp::find ( res == -1 );    //-1 => value not found => it remains
	}
	return ( *this ) ( valid ); //keep those that were not found in rv2
}

ivec RV::findself ( const RV &rv2 ) const {
	int i, j;
	ivec tmp = -ones_i ( len );
	for ( i = 0; i < len; i++ ) {
		for ( j = 0; j < rv2.length(); j++ ) {
			if ( ( ids ( i ) == rv2.ids ( j ) ) & ( times ( i ) == rv2.times ( j ) ) ) {
				tmp ( i ) = j;
				break;
			}
		}
	}
	return tmp;
}

ivec RV::findself_ids ( const RV &rv2 ) const {
	int i, j;
	ivec tmp = -ones_i ( len );
	for ( i = 0; i < len; i++ ) {
		for ( j = 0; j < rv2.length(); j++ ) {
			if ( ( ids ( i ) == rv2.ids ( j ) ) ) {
				tmp ( i ) = j;
				break;
			}
		}
	}
	return tmp;
}

void RV::from_setting ( const Setting &set ) {
	Array<string> A;
	UI::get ( A, set, "names" );

	ivec szs;
	if ( !UI::get ( szs, set, "sizes" ) )
		szs = ones_i ( A.length() );

	ivec tms;
	if ( !UI::get ( tms, set, "times" ) )
		tms = zeros_i ( A.length() );

	// TODO tady se bude plnit primo do jeho promennych, a pak se zavola validacnni metoda, takze cele prepsat, ano?
	init ( A, szs, tms );
}

RV concat ( const RV &rv1, const RV &rv2 ) {
	RV pom = rv1;
	pom.add ( rv2 );
	return pom;
}

RV get_composite_rv ( const Array<shared_ptr<pdf> > &pdfs,
                      bool checkoverlap ) {
	RV rv; //empty rv
	bool rvaddok;
	for ( int i = 0; i < pdfs.length(); i++ ) {
		rvaddok = rv.add ( pdfs ( i )->_rv() ); //add rv to common rvs.
		// If rvaddok==false, pdfs overlap => assert error.
		bdm_assert_debug ( rvaddok || !checkoverlap, "mprod::mprod() input pdfs overlap in rv!" );
	}

	return rv;
}

void BM::bayes_batch ( const mat &Data, const vec &cond ) {
	for ( int t = 0; t < Data.cols(); t++ ) {
		bayes ( Data.get_col ( t ), cond );
	}
}
void BM::bayes_batch ( const mat &Data, const mat &Cond ) {
	for ( int t = 0; t < Data.cols(); t++ ) {
		bayes ( Data.get_col ( t ), Cond.get_col(t) );
	}
}
}
