#include <NumCalc/MatrixTools.h>
#include <NumCalc/VectorTools.h>
#include <NumCalc/EigenValue.h>

#include "BasePairModel.h"
#include "Log.h"

namespace ptr
{

BasePairModel::BasePairModel(BasePairAlphabet * a)  : bpp::AbstractReversibleSubstitutionModel(a, "BasePairModel.") {
	// Builds the slot <-> spelling lookup maps and fills exchangeability_ with a
	// fixed, symmetric table of pair-to-pair rates.
	//
	// Ownership: `a` is kept in bpa and deleted by ~BasePairModel(), so this
	// object takes ownership of the alphabet passed in.
	bpa = a;

	// Lookup tables between two-letter pair spellings and table slots 0..15:
	//   row  : spelling -> slot; both the T and the U spelling of a pair are accepted.
	//   rrow : slot -> spelling; only the T spelling is recorded.
	// Entries are listed in alphabetical order, which is the order the maps are filled.
	struct PairSlot {
		const char * pair;  // two-letter spelling
		int slot;           // row/column index into the rate table below
		bool primary;       // true if this spelling is the one stored in rrow
	};
	static const PairSlot kPairSlots[] = {
		{ "AA",  6, true  }, { "AC",  7, true  }, { "AG",  8, true  },
		{ "AT",  0, true  }, { "AU",  0, false },
		{ "CA",  9, true  }, { "CC", 10, true  }, { "CG",  5, true  },
		{ "CT", 11, true  }, { "CU", 11, false },
		{ "GA", 12, true  }, { "GC",  2, true  }, { "GG", 13, true  },
		{ "GT",  1, true  }, { "GU",  1, false },
		{ "TA",  3, true  }, { "TC", 14, true  }, { "TG",  4, true  }, { "TT", 15, true  },
		{ "UA",  3, false }, { "UC", 14, false }, { "UG",  4, false }, { "UU", 15, false }
	};
	const unsigned int pairCount = sizeof(kPairSlots) / sizeof(kPairSlots[0]);
	for (unsigned int k = 0; k < pairCount; ++k) {
		const PairSlot & entry = kPairSlots[k];
		row[entry.pair] = entry.slot;
		if (entry.primary) {
			rrow[entry.slot] = entry.pair;
		}
	}

	LOG(lTRACE) << exchangeability_.getNumberOfRows() << " " << exchangeability_.getNumberOfColumns();

	// Lower-triangular exchangeability rates (eukaryotes). Row i / column j are table
	// slots, in this order:
	//           AU     GU     GC     UA     UG     CG     AA     AC     AG     CA     CC     CU     GA     GG     UC     UU
	double rates[16 * 16] = {
		/* AU */   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* GU */   6.4,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* GC */   4.5,   6.2,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* UA */   3.1,   1.2,   1.7,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* UG */   1.4,   1.3,   1.3,   5.2,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* CG */   2.1,   1.1,   1.7,   5.6,   6.4,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* AA */   5.8,   3.5,   2.1,   6.4,   9.2,   1.5,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* AC */  13.3,  13.7,   7.5,   1.6,   8.3,   0.1,  50.2,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* AG */  10.7,   6.6,   4.0,   3.3,  15.3,   7.0, 110.5,  32.5,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* CA */   0.8,   1.2,   1.6,  13.5,   5.8,  11.2,   9.1,   6.9,   2.5,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* CC */   2.8,   3.2,   6.1,   1.4,   3.9,   4.4,   1.9,  42.9,   6.2,   3.1,   0.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* CU */   4.8,   5.0,   1.4,   3.1,   8.9,   3.8,   4.1,  34.0,  12.1,  81.8,  46.0,   0.0,   0.0,   0.0,   0.0,   0.0,
		/* GA */   1.4,   5.0,   1.8,   2.2,   1.4,   0.7,  26.4,   5.7,  13.4,   7.9,   3.2,   2.6,   0.0,   0.0,   0.0,   0.0,
		/* GG */   3.3,  12.4,   6.6,   3.3,   9.2,   7.0,  37.0,  18.6,  32.3,   7.7,  12.8,  10.3,  15.0,   0.0,   0.0,   0.0,
		/* UC */   9.0,   6.1,   5.2,   6.6,  11.0,   2.6,  69.2,  81.0,  18.6,  10.5,  55.5,  54.5,   0.7,  39.6,   0.0,   0.0,
		/* UU */   6.5,   5.1,   2.0,   3.1,   3.9,   2.5,   3.6,  15.0,   8.2,   4.4,  31.6,  58.3,   1.3,   7.7,  50.4,   0.0
	};

	// Mirror the lower triangle into the upper one so that rates[j][i] == rates[i][j].
	for (int i = 1; i < 16; ++i) {
		for (int j = 0; j < i; ++j) {
			rates[j * 16 + i] = rates[i * 16 + j];
		}
	}

	// Copy the symmetric table into exchangeability_, translating table slots into the
	// alphabet's state indices via rrow + charToInt. Each diagonal entry becomes minus
	// the sum of its column (the table's own diagonal is all zeros).
	for (int c = 0; c < 16; ++c) {
		const int stateCol = a->charToInt(rrow[c]);
		double columnTotal = 0;
		for (int r = 0; r < 16; ++r) {
			const double rate = rates[r * 16 + c];
			columnTotal += rate;
			const int stateRow = a->charToInt(rrow[r]);
			exchangeability_(stateRow, stateCol) = rate;
			exchangeability_(stateCol, stateRow) = rate;
		}
		exchangeability_(stateCol, stateCol) = -columnTotal;
	}
}

void BasePairModel::setFreq(float * f) {
	// Copies size_ frequencies from `f` into freq_. `f` is indexed by table slot
	// (the order recorded in rrow). Each value lands at the alphabet state index
	// returned by bpa->charToInt for that slot's spelling.
	//
	// Assumes `f` points to at least size_ floats; no bounds or null check is done.
	// NOTE(review): only freq_ is written here. If the base model caches anything
	// derived from freq_, callers presumably must refresh it themselves; confirm.

	LOG(lTRACE) << "BasePairModel::setFreq (pseudoCount) start";

	unsigned int slot = 0;
	while (slot < size_) {
		freq_[bpa->charToInt(rrow[slot])] = f[slot];
		++slot;
	}

	LOG(lTRACE) << "BasePairModel::setFreq (pseudoCount) end";
}

BasePairModel::~BasePairModel() {
	// Releases the alphabet this model took ownership of in its constructor.
	// `delete` on a null pointer has no effect, so no explicit null test is needed.
	//
	// NOTE(review): bpa is a raw owning pointer, and the same alphabet was also handed
	// to the base class. Two things should be confirmed:
	//   - the base class does not free it as well;
	//   - no copy of this object (copy constructor or clone, not visible here) shares
	//     bpa. Either case would lead to a double delete.

	LOG(lTRACE2) << "BasePairModel destructor start";
	delete bpa;
	LOG(lTRACE2) << "BasePairModel destructor end";
}




}
