#include "gsl_ext.h" 


/*
 * Add two quantities that are stored as natural logarithms.
 *
 * Returns log(exp(loga) + exp(logb)) without leaving log space, using
 *   log(e^a + e^b) = max + log1p(exp(min - max)),
 * so exp() only ever sees a non-positive argument and cannot overflow.
 *
 * -INF is this file's encoding of log(0) (matrix_log/vector_log produce it,
 * and the log-space products start their accumulators from it). When the
 * smaller operand is -INF the sum is exactly the larger operand; handling
 * that case up front also avoids (-INF) - (-INF), which is NaN when INF is
 * IEEE infinity and would otherwise poison every accumulated sum.
 */
double add_log(double loga, double logb)
{
  double max = (loga < logb) ? logb : loga;
  double min = (loga < logb) ? loga : logb;

  /* Smaller term is log(0), or too small to change the larger one. */
  if (min == -INF || max - min > log(INF))
    return max;

  /* log1p is more accurate than log(1 + x) for the small x seen here. */
  return max + log1p(exp(min - max));
}


/* Allocate a gsl_vector of length `size` and copy d[0 .. size-1] into it.
 * The returned vector is newly allocated and owned by the caller. */
gsl_vector * vector_init(double* d, int size)
{
  gsl_vector *vec = gsl_vector_alloc(size);
  size_t k = 0;

  while (k < vec->size) {
    gsl_vector_set(vec, k, d[k]);
    ++k;
  }
  return vec;
}


/* Allocate a size1 x size2 gsl_matrix and copy the row-pointer array d into
 * it, so that element (r, c) comes from d[r][c]. The caller owns the result. */
gsl_matrix * matrix_init(double** d, int size1,int size2)
{
  gsl_matrix *mat = gsl_matrix_alloc(size1, size2);
  size_t r, c;

  for (r = 0; r < mat->size1; ++r) {
    const double *src_row = d[r];
    for (c = 0; c < mat->size2; ++c)
      gsl_matrix_set(mat, r, c, src_row[c]);
  }
  return mat;
}





/* Return a newly allocated vector holding the ordinary product m * v.
 * v must have m->size2 entries; the result has m->size1 entries. */
gsl_vector * matrix_vector_mul(gsl_matrix *m, gsl_vector *v)
{
  gsl_vector *out;

  assert(v->size == m->size2);
  out = gsl_vector_alloc(m->size1);
  matrix_vector_mul_eq(out, m, v);
  return out;
}

/* Store the product m * v into the caller-supplied vector y.
 * Requires v->size == m->size2 and y->size == m->size1 (checked with assert,
 * matching matrix_vector_mul). y must not be the same vector as v: y_i is
 * written while later rows still read v. */
void matrix_vector_mul_eq(gsl_vector* y, gsl_matrix* m, gsl_vector* v)
{
  size_t i, j;

  assert(v->size == m->size2);
  assert(y->size == m->size1);
  for (i = 0; i < m->size1; i++) {
    double sum = 0;
    for (j = 0; j < m->size2; j++)
      sum += gsl_matrix_get(m, i, j) * gsl_vector_get(v, j);
    gsl_vector_set(y, i, sum);
  }
}


/* Log-space counterpart of matrix_vector_mul: m and v hold logarithms, and the
 * returned (newly allocated) vector holds y_i = log(sum_j exp(m_ij + v_j)). */
gsl_vector * matrix_vector_mul_log(gsl_matrix *m, gsl_vector *v)
{
  gsl_vector *y;

  /* Same shape contract as matrix_vector_mul. */
  assert(v->size == m->size2);
  y = gsl_vector_alloc(m->size1);
  matrix_vector_mul_log_eq(y, m, v);
  return y;
}


/* Store the log-space product y_i = log(sum_j exp(m_ij + v_j)) into y.
 * Each row is accumulated with add_log starting from -INF (log of 0).
 * Requires v->size == m->size2 and y->size == m->size1. */
void matrix_vector_mul_log_eq(gsl_vector *y, gsl_matrix *m, gsl_vector *v)
{
  size_t i, j;

  assert(v->size == m->size2);
  assert(y->size == m->size1);
  for (i = 0; i < m->size1; i++) {
    double sum = -INF;
    for (j = 0; j < m->size2; j++)
      sum = add_log(sum, gsl_matrix_get(m, i, j) + gsl_vector_get(v, j));
    gsl_vector_set(y, i, sum);
  }
}



/* Return a newly allocated row-vector product v^T * m, i.e.
 * y_j = sum_i v_i * m_ij. v must have m->size1 entries; y has m->size2. */
gsl_vector * vector_matrix_mul(gsl_matrix *m, gsl_vector *v)
{
  gsl_vector *y;
  size_t col;

  assert(v->size == m->size1);
  y = gsl_vector_alloc(m->size2);
  for (col = 0; col < m->size2; ++col) {
    double acc = 0;
    size_t row;
    for (row = 0; row < m->size1; ++row)
      acc += gsl_matrix_get(m, row, col) * gsl_vector_get(v, row);
    gsl_vector_set(y, col, acc);
  }
  return y;
}

/* Log-space counterpart of vector_matrix_mul: returns a newly allocated y with
 * y_j = log(sum_i exp(m_ij + v_i)), each column accumulated with add_log
 * starting from -INF (log of 0). v must have m->size1 entries. */
gsl_vector * vector_matrix_mul_log(gsl_matrix *m, gsl_vector *v)
{
  gsl_vector *y;
  size_t col;

  assert(v->size == m->size1);
  y = gsl_vector_alloc(m->size2);
  for (col = 0; col < m->size2; ++col) {
    double acc = -INF;
    size_t row;
    for (row = 0; row < m->size1; ++row)
      acc = add_log(acc, gsl_matrix_get(m, row, col) + gsl_vector_get(v, row));
    gsl_vector_set(y, col, acc);
  }
  return y;
}


/* Dot product of two vectors of equal length. */
double vector_vector_mul(gsl_vector *u, gsl_vector *v)
{
  double dot = 0;
  size_t k;

  assert(u->size == v->size);
  for (k = 0; k < u->size; ++k) {
    dot += gsl_vector_get(u, k) * gsl_vector_get(v, k);
  }
  return dot;
}


/* Log-space dot product: log(sum_i exp(u_i + v_i)), accumulated with add_log
 * starting from -INF (log of 0). u and v must have the same length, as in
 * vector_vector_mul; otherwise v would be indexed past its end. */
double vector_vector_mul_log(gsl_vector *u, gsl_vector *v)
{
  size_t i;
  double sum = -INF;

  assert(u->size == v->size);
  for (i = 0; i < u->size; i++)
    sum = add_log(sum, gsl_vector_get(u, i) + gsl_vector_get(v, i));
  return sum;
}

/* Return a newly allocated m * diag(D): column j of m scaled by D_j. */
gsl_matrix * matrix_diag_mul(gsl_matrix *m, gsl_vector *D)
{
  gsl_matrix *scaled;

  scaled = gsl_matrix_alloc(m->size1, m->size2);
  matrix_diag_mul_eq(scaled, m, D);
  return scaled;
}


/* Store m * diag(D) into ret_m: column j of m is scaled by D_j.
 * Each entry is read and written at the same (i, j), so ret_m may be m. */
void matrix_diag_mul_eq(gsl_matrix *ret_m, gsl_matrix *m, gsl_vector *D)
{
  size_t row = 0;

  assert(D->size == m->size2);
  while (row < m->size1) {
    size_t col;
    for (col = 0; col < m->size2; ++col) {
      const double scaled = gsl_matrix_get(m, row, col) * gsl_vector_get(D, col);
      gsl_matrix_set(ret_m, row, col, scaled);
    }
    ++row;
  }
}

/* Return a newly allocated diag(D) * m: row i of m scaled by D_i.
 * This must delegate to diag_matrix_mul_eq (row scaling), not
 * matrix_diag_mul_eq (column scaling), which computes m * diag(D). */
gsl_matrix * diag_matrix_mul(gsl_matrix *m, gsl_vector *D)
{
  gsl_matrix *n = gsl_matrix_alloc(m->size1, m->size2);
  diag_matrix_mul_eq(n, m, D);
  return n;
}

/* Store diag(D) * m into ret_m: row i of m is scaled by D_i.
 * D must have one entry per row of m (asserted, mirroring the column check in
 * matrix_diag_mul_eq). Each entry is read and written at the same (i, j), so
 * ret_m may be m itself. */
void diag_matrix_mul_eq(gsl_matrix *ret_m, gsl_matrix *m, gsl_vector *D)
{
  size_t i, j;

  assert(D->size == m->size1);
  for (i = 0; i < m->size1; i++) {
    for (j = 0; j < m->size2; j++) {
      gsl_matrix_set(ret_m, i, j, gsl_matrix_get(m, i, j) * gsl_vector_get(D, i));
    }
  }
}

/* Return a newly allocated product m * n, computed with a plain triple loop.
 * Requires m->size2 == n->size1; the result is m->size1 x n->size2. */
gsl_matrix * matrix_matrix_mul(gsl_matrix *m, gsl_matrix *n)
{
  gsl_matrix *r;
  size_t row, col, k;

  assert(m->size2 == n->size1);
  r = gsl_matrix_alloc(m->size1, n->size2);
  for (row = 0; row < m->size1; ++row) {
    for (col = 0; col < n->size2; ++col) {
      double acc = 0;
      for (k = 0; k < n->size1; ++k) {
        acc += gsl_matrix_get(m, row, k) * gsl_matrix_get(n, k, col);
      }
      gsl_matrix_set(r, row, col, acc);
    }
  }
  return r;
}



/*
 * Rebuild a matrix from a diagonal factorisation scaled by e:
 *   result = Q * diag(exp(e * D)) * Q^T
 * D holds the diagonal entries and Q the matching vectors as columns.
 * Presumably used so that, when Q is orthogonal and D holds the eigenvalues of
 * a symmetric S = Q diag(D) Q^T, the result is the matrix exponential exp(e*S).
 * The returned matrix is newly allocated and owned by the caller; all
 * temporaries are released here.
 */
gsl_matrix* matrix_exp(gsl_vector* D, gsl_matrix* Q, double e)
{
  size_t i;
  gsl_vector *D_exp = gsl_vector_alloc(D->size);
  gsl_matrix *temp, *Q_t, *result;

  for (i = 0; i < D->size; i++)
    gsl_vector_set(D_exp, i, exp(gsl_vector_get(D, i) * e));

  /* temp = Q * diag(exp(e * D)): column j of Q scaled by exp(e * D_j). */
  temp = matrix_diag_mul(Q, D_exp);
  gsl_vector_free(D_exp);  /* not needed once temp is built */

  Q_t = gsl_matrix_alloc(Q->size2, Q->size1);
  gsl_matrix_transpose_memcpy(Q_t, Q);
  result = matrix_matrix_mul(temp, Q_t);

  gsl_matrix_free(Q_t);
  gsl_matrix_free(temp);
  return result;
}

/* In-place elementwise natural log of m. Entries below PRECISION (zero,
 * negatives, tiny values) are replaced by -INF, the log-zero sentinel. */
void matrix_log(gsl_matrix* m)
{
  size_t r, c;

  for (r = 0; r < m->size1; ++r) {
    for (c = 0; c < m->size2; ++c) {
      const double x = gsl_matrix_get(m, r, c);
      gsl_matrix_set(m, r, c, (x < PRECISION) ? -INF : log(x));
    }
  }
}

/*
 * In-place elementwise natural log of v:
 *  - entries below PRECISION (zero, negatives, tiny values) become -INF,
 *    the log-zero sentinel;
 *  - entries within PRECISION of 1 are snapped to exactly 0;
 *  - all other entries become log(x).
 * If a result is NaN (e.g. from a NaN input entry), Die() is called with the
 * offending result and the original value.
 */
void vector_log(gsl_vector* v)
{
  size_t i;

  for (i = 0; i < v->size; i++) {
    const double orig = gsl_vector_get(v, i);
    double d;

    if (orig < PRECISION)
      d = -INF;
    else if (orig < PRECISION + 1 && orig > 1 - PRECISION)
      d = 0;
    else
      d = log(orig);

    /* Explicit NaN check; `!d == d` would parse as `(!d) == d` and never fire. */
    if (isnan(d)) {
      Die("Error d %f orig %f\n", d, orig);
    }
    gsl_vector_set(v, i, d);
  }
}

/* In-place elementwise exponential of m (the reverse direction of matrix_log). */
void matrix_unlog(gsl_matrix* m)
{
  size_t r, c;

  for (r = 0; r < m->size1; ++r)
    for (c = 0; c < m->size2; ++c)
      gsl_matrix_set(m, r, c, exp(gsl_matrix_get(m, r, c)));
}


/* Print m to stream, one row per line, each entry formatted as " %.3f". */
void matrix_fprintf (FILE * stream,  gsl_matrix * m)
{
  size_t r = 0;

  while (r < m->size1) {
    size_t c;
    for (c = 0; c < m->size2; ++c)
      fprintf(stream, " %.3f", gsl_matrix_get(m, r, c));
    fprintf(stream, "\n");
    ++r;
  }
}

/* Print v to stream on a single line, each entry formatted as " %.4f",
 * followed by a newline. */
void vector_fprintf (FILE * stream, gsl_vector * v)
{
  size_t k = 0;

  while (k < v->size) {
    fprintf(stream,  " %.4f", gsl_vector_get(v, k));
    ++k;
  }
  fprintf(stream, "\n");
}

/* Print diag(v) as a square grid. Each line starts with its row index; the
 * diagonal cell shows v_i as " %.2f" and every other cell is blank padding. */
void diag_fprintf (FILE * stream, gsl_vector * v)
{
  size_t row, col;

  for (row = 0; row < v->size; ++row) {
    fprintf(stream, "%d", (int)row);
    for (col = 0; col < v->size; ++col) {
      if (col != row) {
        fprintf(stream, "    ");
      } else {
        fprintf(stream, " %.2f", gsl_vector_get(v, row));
      }
    }
    fprintf(stream, "\n");
  }
}

/* Euclidean (L2) norm of v.
 * Elements are read through gsl_vector_get rather than v->data[i], so strided
 * vectors (e.g. a column view of a matrix) are handled correctly. */
double vector_norm(gsl_vector *v) {
  double ss = 0;
  size_t i;

  for (i = 0; i < v->size; i++) {
    const double x = gsl_vector_get(v, i);
    ss += x * x;
  }
  return sqrt(ss);
}

/* Sum of all entries of v, accumulated in index order. */
double vector_sum(gsl_vector *v)
{
  double total = 0;
  size_t k;

  for (k = 0; k < v->size; ++k) {
    total += gsl_vector_get(v, k);
  }
  return total;
}

/* Sum of all entries of the matrix, accumulated in row-major order. */
double matrix_sum(gsl_matrix *v)
{
  double total = 0;
  size_t r, c;

  for (r = 0; r < v->size1; ++r) {
    for (c = 0; c < v->size2; ++c) {
      total += gsl_matrix_get(v, r, c);
    }
  }
  return total;
}

/* Return a newly allocated v1->size x v2->size matrix holding the outer
 * product v1 * v2^T. */
gsl_matrix* vector_outer_prod(gsl_vector* v1, gsl_vector* v2)
{
  gsl_matrix *outer;

  outer = gsl_matrix_alloc(v1->size, v2->size);
  vector_outer_prod_eq(outer, v1, v2);
  return outer;
}


/* Fill m with the outer product m_ij = v1_i * v2_j. m must have at least
 * v1->size rows and v2->size columns; other entries are left untouched. */
void vector_outer_prod_eq(gsl_matrix* m, gsl_vector* v1, gsl_vector* v2)
{
  size_t r, c;

  for (r = 0; r < v1->size; ++r) {
    for (c = 0; c < v2->size; ++c) {
      const double prod = gsl_vector_get(v1, r) * gsl_vector_get(v2, c);
      gsl_matrix_set(m, r, c, prod);
    }
  }
}


/*
 * Scratch driver exercising the symmetrisation / matrix-exponential helpers.
 * It prints each intermediate to stdout:
 *   M            a fixed 4x4 matrix (rows labelled A, U, G, C; presumably a
 *                substitution-rate matrix),
 *   S            PI^(1/2) * M * PI^(-1/2), with PI the equilibrium frequencies,
 *   D, Q         from an SVD of S (D's sign flipped by a heuristic below),
 *   Exp 3        Q * diag(exp(3 D)) * Q^T,
 *   Transition   PI^(-1/2) * Exp 3 * PI^(1/2).
 * Returns 0. Every allocation made here is released before returning.
 */
int try_main()
{
  enum { DIM = 4 };
  const int dim = DIM;
  int i;
  double d[DIM][DIM] = {
    {-0.75,  0.26,  0.32,  0.16},  //A
    {0.35,  -1.05,  0.18,  0.51},  //U
    {0.55,  0.24,  -0.96,  0.17},  //G
    {0.40,  0.93,   0.24,  -1.56}  //C
  };
  double equiFreq[DIM] = {
    0.36, 0.27, 0.21, 0.15
  };
  /* matrix_init takes an array of row pointers; point them at d's rows so no
     heap copy is needed. */
  double *rows[DIM];

  for (i = 0; i < dim; i++)
    rows[i] = d[i];

  gsl_matrix *m = matrix_init(rows, dim, dim);
  gsl_vector *PI_Sqrt = gsl_vector_alloc(dim);
  gsl_vector *PI_MinusSqrt = gsl_vector_alloc(dim);
  for (i = 0; i < dim; i++) {
    gsl_vector_set(PI_Sqrt, i, sqrt(equiFreq[i]));
    gsl_vector_set(PI_MinusSqrt, i, 1.0 / sqrt(equiFreq[i]));
  }
  printf("M \n");
  matrix_fprintf(stdout, m);
  printf("PI_Sqrt\n");
  vector_fprintf(stdout, PI_Sqrt);
  printf("PI_MinusSqrt\n");
  vector_fprintf(stdout, PI_MinusSqrt);

  gsl_matrix *temp = diag_matrix_mul(m, PI_Sqrt);
  printf("PI_Sqrt * M \n");
  matrix_fprintf(stdout, temp);
  gsl_matrix *S = matrix_diag_mul(temp, PI_MinusSqrt);
  gsl_matrix_free(temp);  /* temp is reused below */
  temp = NULL;
  printf("S \n");
  matrix_fprintf(stdout, S);

  /* gsl_linalg_SV_decomp overwrites its first argument with U, so decompose
     a copy and keep S intact. */
  gsl_matrix *S_clone = gsl_matrix_alloc(S->size1, S->size2);
  gsl_matrix_memcpy(S_clone, S);
  gsl_vector *D = gsl_vector_alloc(S->size1);
  gsl_vector *work = gsl_vector_alloc(S->size1);
  gsl_matrix *Q = gsl_matrix_alloc(S->size1, S->size2);
  gsl_linalg_SV_decomp(S_clone, Q, D, work);
  /* Singular values are non-negative; negate D when U(0,0) and V(0,0) differ
     in sign (presumably to recover negative eigenvalues of S). */
  if (gsl_matrix_get(S_clone, 0, 0) * gsl_matrix_get(Q, 0, 0) < 0) {
    gsl_vector_scale(D, -1);
  }

  printf("D\n");
  diag_fprintf(stdout, D);
  printf("Q\n");
  matrix_fprintf(stdout, Q);

  /* Named exp_m so it does not shadow the libm exp() function. */
  gsl_matrix *exp_m = matrix_exp(D, Q, 3);
  printf("Exp 3\n");
  matrix_fprintf(stdout, exp_m);
  temp = diag_matrix_mul(exp_m, PI_MinusSqrt);
  gsl_matrix *trans = matrix_diag_mul(temp, PI_Sqrt);
  printf("Transition \n");
  matrix_fprintf(stdout, trans);

  gsl_matrix_free(trans);
  gsl_matrix_free(temp);
  gsl_matrix_free(exp_m);
  gsl_matrix_free(Q);
  gsl_vector_free(work);
  gsl_vector_free(D);
  gsl_matrix_free(S_clone);
  gsl_matrix_free(S);
  gsl_vector_free(PI_MinusSqrt);
  gsl_vector_free(PI_Sqrt);
  gsl_matrix_free(m);
  return 0;
}


/*
int main()
{
  int dim = 4;
  int i,j;
  double d[][4] ={ 
    {0.75,  0.26,  0.32,  0.16},  //A
    {0.35,  1.05,  0.18,  0.51},  //U
    {0.55,  0.24,  0.96,  0.17},  //G
    {0.40,  0.93,  0.24,  1.56}  //C
  };
  double ** t = malloc(sizeof(double*)* dim);
  for(i=0; i < dim; i++){
    t[i] = malloc(sizeof(double)* dim);
    for(j=0; j < dim; j++){
      t[i][j]= d[i][j];
    }
  }
  double u[4] ={0.55,  0.26,  0.32,  0.16};
  gsl_matrix* m = matrix_init(t, dim,dim);  
  gsl_vector* v = vector_init(u, dim);
  gsl_vector* result1 = vector_matrix_mul(m, v);
  printf("result1 ");
  vector_fprintf(stdout, result1);
  matrix_log(m);
  vector_log(v);
  gsl_vector* result2 = vector_matrix_mul_log(m,v);
  printf("result2 ");
  vector_fprintf(stdout, result2);
  //vector_unlog(result2);
  //vector_fprintf(stdout, result2);
}
*/
