#include "phytree.h"
#include "squid.h"
#include "structs.h"
#include "funcs.h"
#include "msa.h"
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>
#include <ctype.h>
#include <math.h>
#include "evolve.h"
#include "matrix.h"
#include "../cmfinder/global.h"
#include "gsl/gsl_matrix.h"
#include "gsl/gsl_vector.h"
#include "gsl/gsl_linalg.h"

/* Size passed to fgets() when reading tree files (ReadPhyFile) and model
   files (EvoModel_read). */
#define MAXLINE 1000

/* When nonzero, init_node() and the init_leaves_*() helpers convert
   frequencies, transition matrices and leaf likelihoods to logs, and
   calculate_all_likelihood() uses the log-space vector routines. */
int logscale=1;
/* Model tables (capacity MAX_MODEL, defined in a header) and their counts.
   NOTE(review): none of these are referenced in this part of the file;
   presumably populated elsewhere via EvoModel_read() — confirm. */
EvoModel* single_model[MAX_MODEL];
EvoModel* pair_model[MAX_MODEL];
int single_model_num;
int pair_model_num;

/*
 * ParseName - scan one name token starting at *s.
 *
 * Advances *s over every character up to (not including) whitespace,
 * ':', '(', ')' or the end of the string, and copies the scanned text
 * into `name` as a NUL-terminated string.
 *
 * If `name` is NULL a buffer of exactly the right size is allocated with
 * malloc() and ownership passes to the caller.  When `name` is supplied,
 * its capacity is not known here.
 * NOTE(review): the caller must ensure it can hold the token plus NUL.
 *
 * Returns `name` (or the newly allocated buffer), or NULL if that
 * allocation fails.
 */
char* ParseName(char** s, char* name)
{
  char*  begin = *s;
  size_t length;

  /* Stop at NUL as well: a token at the very end of the input used to be
     scanned straight past the terminator. */
  while (**s != '\0' && !isspace((unsigned char) **s) &&
         **s != ':' && **s != '(' && **s != ')')
    (*s)++;
  length = (size_t) (*s - begin);

  if (name == NULL) {
    name = malloc(length + 1);
    if (name == NULL)
      return NULL;
  }

  /* The scanned span contains no NUL, so a plain memcpy is exact. */
  memcpy(name, begin, length);
  name[length] = '\0';
  return name;
}

/*
 * ParseDouble - parse a floating-point number at *s with strtod().
 *
 * On success *s is advanced past the number.  If no number is present,
 * strtod() leaves *s unchanged and returns 0.0.  A warning is then printed
 * to stderr and 0.0 is returned.
 *
 * strtod() never stores NULL through its end pointer.  On failure it stores
 * the start pointer, so the old `*s == NULL` test could never fire.
 */
double ParseDouble(char** s)
{
  char*  start = *s;
  double d     = strtod(start, s);

  if (*s == start)
    fprintf(stderr, "Not double value found!\n");
  return d;
}


/*
 * ParseTree - parse one (sub)tree in Newick-like notation starting at *ptr.
 *
 * A '(' opens an internal node.  Its children are separated by ',' and the
 * list ends at ')'.  After each child the input is skipped to the next ':'
 * and a branch length is read with ParseDouble().  Anything other than '('
 * is read as a leaf name with ParseName().
 *
 * Each child gets its `parent` and `length` set.  An internal node's
 * `name` is the concatenation of its children's names, and its `size` is
 * the sum of their sizes.  A leaf has size 1.
 *
 * On return *ptr points just past a leaf name, or at the closing ')' (or
 * the end of the string) for an internal node.  The caller skips the ')'
 * while searching for the next ':'.
 *
 * Every scanning loop now also stops at the string terminator, so truncated
 * or malformed input no longer reads out of bounds.
 *
 * NOTE(review): children[] and name[] capacities are fixed by the PhyNode
 * declaration (not visible here) and are not checked.
 */
PhyNode* ParseTree(char** ptr)
{
  double   len;
  PhyNode *child, *node;

  while (isspace((unsigned char) **ptr)) (*ptr)++;

  if (**ptr == '(') {
    (*ptr)++;

    node = (PhyNode*) MallocOrDie(sizeof(PhyNode));
    memset(node, 0, sizeof(PhyNode));

    while (**ptr != '\0' && **ptr != ')') {
      while (isspace((unsigned char) **ptr)) (*ptr)++;
      child = ParseTree(ptr);

      /* Skip to the ':' that introduces this child's branch length. */
      while (**ptr != '\0' && **ptr != ':') (*ptr)++;
      if (**ptr == ':') (*ptr)++;
      len = ParseDouble(ptr);

      node->children[node->nchildren++] = child;
      strcat(node->name, child->name);
      child->parent = node;
      child->length = len;

      node->size += child->size;

      /* Skip whitespace up to the next ',' (consumed) or ')' (left). */
      while (**ptr != '\0' && **ptr != ',' && **ptr != ')') (*ptr)++;
      if (**ptr == ',') (*ptr)++;
    }
  }
  else {
    /* Same allocation policy as the internal-node branch. */
    node = (PhyNode*) MallocOrDie(sizeof(PhyNode));
    memset(node, 0, sizeof(PhyNode));
    ParseName(ptr, node->name);
    node->size = 1;
  }
  return node;
}


/*
 * CalcWeight - assign weights to the nodes of the tree rooted at `node`.
 *
 * Each node's weight is its branch length divided by its `size` (the leaf
 * count below it), plus its parent's weight.  A leaf therefore accumulates
 * the shared-out branch lengths along its path from the root.  Leaf weights
 * are then divided by the largest leaf weight, so the heaviest leaf gets
 * 1.0.  Internal-node weights are left unnormalised.
 *
 * The traversal uses an explicit stack and visits a node before its
 * children, so a parent's weight is final before any child reads it.
 * Exits with a message if allocation fails.
 */
void CalcWeight(PhyNode* node)
{
  PhyNode** stack  = (PhyNode**) malloc(sizeof(PhyNode*) * node->size);
  PhyNode** leaves = (PhyNode**) malloc(sizeof(PhyNode*) * node->size);
  PhyNode** sp;
  PhyNode*  curr;
  int       i, leaf_count = 0;
  /* double so the divisor is not rounded to float precision. */
  double    max_weight = 0;

  if (stack == NULL || leaves == NULL) {
    fprintf(stderr, "Malloc error at CalcWeight\n");
    exit(1);
  }

  stack[0] = node;
  sp = stack + 1;
  while (sp > stack) {
    sp--;
    curr = *sp;
    curr->weight = curr->length / curr->size;
    if (curr->parent)
      curr->weight += curr->parent->weight;

    for (i = 0; i < curr->nchildren; i++)
      *(sp++) = curr->children[i];

    if (curr->nchildren == 0) {
      leaves[leaf_count++] = curr;
      if (curr->weight > max_weight)
        max_weight = curr->weight;
    }
    /* Pending entries root disjoint subtrees, so node->size slots suffice. */
    assert(sp <= stack + node->size);
  }
  assert(leaf_count <= node->size);

  /* With all branch lengths zero, normalising would produce 0/0 = NaN. */
  if (max_weight > 0) {
    for (i = 0; i < leaf_count; i++)
      leaves[i]->weight /= max_weight;
  }
  free(stack);
  free(leaves);
}


/*
 * PrintTree - print `node` and its subtree to stdout, one node per line.
 *
 * Each line is "<indent> <name> (<length> <weight>)".  Children are printed
 * recursively with one extra tab of indentation.
 *
 * The child indent used to be built in a fixed char[100] with strcpy and
 * strcat, which overflowed for trees deeper than about 98 levels.  It is
 * now sized from the parent's indent.  Exits with a message if allocation
 * fails.
 */
void PrintTree(PhyNode* node, char* indent)
{
  int    i;
  size_t n;
  char*  ind;

  if (node == NULL)
    return;

  printf("%s %s (%1.3f %1.3f)\n", indent, node->name, node->length, node->weight);
  if (node->nchildren == 0)
    return;

  n = strlen(indent);
  ind = (char*) malloc(n + 2);
  if (ind == NULL) {
    fprintf(stderr, "Malloc error at PrintTree\n");
    exit(1);
  }
  memcpy(ind, indent, n);
  ind[n] = '\t';
  ind[n + 1] = '\0';

  for (i = 0; i < node->nchildren; i++)
    PrintTree(node->children[i], ind);
  free(ind);
}


/*
 * ReadPhyFile - read a whole tree file and parse it with ParseTree().
 *
 * The file is read in a single pass into a growing heap buffer.  The old
 * version opened the file twice and concatenated lines with strcat, which
 * was quadratic in the file size.  The buffer is freed before returning;
 * ParseTree copies names into the nodes.
 *
 * Prints a message and exits(1) if the file cannot be opened or memory
 * cannot be allocated.  Returns the parsed tree root.
 */
PhyNode* ReadPhyFile(char* filename)
{
  FILE*    fin;
  char*    str;
  char*    temp;
  size_t   cap = MAXLINE + 1;
  size_t   len = 0;
  size_t   nread;
  PhyNode* tree;

  if ((fin = fopen(filename, "r")) == NULL) {
    printf("Fail to open file %s", filename);
    exit(1);
  }
  if ((str = malloc(sizeof(char) * cap)) == NULL) {
    printf("Malloc error at ReadPhyFile \n");
    exit(1);
  }

  /* One byte is always kept free for the terminating NUL; the buffer is
     doubled whenever the readable space runs out. */
  while ((nread = fread(str + len, 1, cap - 1 - len, fin)) > 0) {
    len += nread;
    if (len == cap - 1) {
      char* bigger = realloc(str, cap * 2);
      if (bigger == NULL) {
        printf("Malloc error at ReadPhyFile \n");
        exit(1);
      }
      str = bigger;
      cap *= 2;
    }
  }
  fclose(fin);
  str[len] = '\0';

  temp = str;
  tree = ParseTree(&temp);
  free(str);
  return tree;
}


/*
 * PhyLeaves - collect the leaves (nodes with no children) of the tree
 * rooted at `node`.
 *
 * Returns a newly malloc'ed array with room for node->size pointers; the
 * caller frees it.  The leaf count is not returned.  Presumably callers use
 * node->size, which ParseTree makes equal to the leaf count.
 *
 * Leaves appear in depth-first order, with each node's children visited
 * from last to first (a consequence of the explicit stack).  Exits with a
 * message if allocation fails.
 */
PhyNode** PhyLeaves(PhyNode* node)
{
  PhyNode** stack  = (PhyNode**) malloc(sizeof(PhyNode*) * node->size);
  PhyNode** leaves = (PhyNode**) malloc(sizeof(PhyNode*) * node->size);
  PhyNode** sp;
  PhyNode*  curr;
  int       leaf_count = 0;
  int       i;

  /* Both buffers are written immediately, so a failed malloc used to
     crash on a NULL store. */
  if (stack == NULL || leaves == NULL) {
    fprintf(stderr, "Malloc error at PhyLeaves\n");
    exit(1);
  }

  stack[0] = node;
  sp = stack + 1;
  while (sp > stack) {
    sp--;
    curr = *sp;
    for (i = 0; i < curr->nchildren; i++)
      *(sp++) = curr->children[i];

    if (curr->nchildren == 0)
      leaves[leaf_count++] = curr;
  }
  free(stack);
  return leaves;
}

/*
 * PrintNodes - write one tab-separated line per node in nodes[0..size-1]
 * to stdout: right-aligned name, size, branch length, weight.
 */
void PrintNodes(PhyNode** nodes, int size)
{
  int k = 0;

  while (k < size) {
    PhyNode* nd = nodes[k++];
    printf("%30s\t%d\t%f\t%f\n", nd->name, nd->size, nd->length, nd->weight);
  }
}


/*
 * FreePhytree - release `node`, every descendant, and each node's GSL
 * buffers (tm, likelihood, freq) that are non-NULL.  A NULL tree is a
 * no-op.
 */
void FreePhytree(PhyNode* node)
{
  int c;

  if (!node)
    return;

  for (c = 0; c < node->nchildren; c++)
    FreePhytree(node->children[c]);

  if (node->freq != NULL)
    gsl_vector_free(node->freq);
  if (node->likelihood != NULL)
    gsl_vector_free(node->likelihood);
  if (node->tm != NULL)
    gsl_matrix_free(node->tm);

  free(node);
}


/*
 * get_root - follow parent links from `node` to the topmost ancestor.
 * Returns NULL for a NULL argument, otherwise the node with no parent.
 */
PhyNode* get_root(PhyNode* node)
{
  if (node == NULL)
    return NULL;
  for (; node->parent != NULL; node = node->parent)
    ;
  return node;
}

/*
 * EvoModel_alloc - allocate a zeroed EvoModel and its GSL storage.
 *
 * mode == SINGLE gives dim = Alphabet_size.  Any other mode gives
 * dim = Alphabet_size^2 (pair model).  rm (dim x dim) and freq (dim) are
 * allocated but not initialised.  scale is set to 1.
 *
 * With use_gap set, gap storage is also allocated:
 *   SINGLE: gap_rm is (Alphabet_size+1)^2 square, with 1 gap parameter;
 *   pair:   gap_rm is ((Alphabet_size+1)^2)^2 square, with 3 gap parameters.
 * The gap parameters are zeroed.  A model file without a "Gap" section
 * used to leave them uninitialised, and EvoModel_init_single_gap then read
 * garbage.
 *
 * Exits with a message if the struct allocation fails.
 */
EvoModel* EvoModel_alloc(int mode,  int use_gap)
{
  int       dim     = (mode == SINGLE) ? Alphabet_size : Alphabet_size * Alphabet_size;
  int       gap_dim = Alphabet_size + 1;
  EvoModel* m       = (EvoModel*) malloc(sizeof(EvoModel));

  if (m == NULL) {
    fprintf(stderr, "Malloc error at EvoModel_alloc\n");
    exit(1);
  }
  memset(m, 0, sizeof(EvoModel));

  m->dim   = dim;
  m->scale = 1;
  m->freq  = gsl_vector_alloc(dim);
  m->rm    = gsl_matrix_alloc(dim, dim);

  if (use_gap) {
    if (mode == SINGLE) {
      m->gap_rm     = gsl_matrix_alloc(gap_dim, gap_dim);
      m->gap_params = gsl_vector_alloc(1);
    }
    else {
      m->gap_params = gsl_vector_alloc(3);
      m->gap_rm     = gsl_matrix_alloc(gap_dim * gap_dim, gap_dim * gap_dim);
    }
    gsl_vector_set_all(m->gap_params, 0);
  }
  return m;
}


/*
 * marginize - derive a single-site rate matrix `rm` and frequency vector
 * `freq` (both Alphabet_size wide) from the pair model `mod_pair`.
 *
 * freq[x] is the mean of the left and right marginals of mod_pair->freq.
 * For every pair transition k = (i1,i2) -> l = (j1,j2) with rate u,
 * rm[i1][j1] gains u * freq_pair[k] / freq[i1], and rm[i2][j2] gains
 * u * freq_pair[k] / freq[i2].
 *
 * Both outputs are fully overwritten.  rm used to be accumulated into
 * without being cleared.  EvoModel_init_pair_gap passes a freshly
 * gsl_matrix_alloc'ed (uninitialised) matrix, so the result was garbage.
 * The temporary right-marginal vector is now freed.
 */
void marginize(EvoModel* mod_pair, gsl_matrix* rm, gsl_vector* freq)
{
  int         i, j, k, l, i1, i2, j1, j2;
  double      v, u, frac;
  gsl_vector* freq_right = gsl_vector_alloc(Alphabet_size);

  /* Both outputs are accumulated below, so start them from zero. */
  gsl_vector_set_all(freq, 0);
  gsl_vector_set_all(freq_right, 0);
  gsl_matrix_set_all(rm, 0);

  /* freq temporarily holds the left marginal. */
  for (i = 0; i < Alphabet_size; i++) {
    for (j = 0; j < Alphabet_size; j++) {
      v = gsl_vector_get(mod_pair->freq, Alphabet_size * i + j);
      gsl_vector_set(freq, i, v + gsl_vector_get(freq, i));
      gsl_vector_set(freq_right, j, v + gsl_vector_get(freq_right, j));
    }
  }
  gsl_vector_add(freq, freq_right);
  gsl_vector_scale(freq, 0.5);
  gsl_vector_free(freq_right);

  for (k = 0; k < mod_pair->dim; k++) {
    i1 = k / Alphabet_size;
    i2 = k % Alphabet_size;
    for (l = 0; l < mod_pair->dim; l++) {
      j1 = l / Alphabet_size;
      j2 = l % Alphabet_size;
      u  = gsl_matrix_get(mod_pair->rm, k, l);

      v    = gsl_matrix_get(rm, i1, j1);
      frac = gsl_vector_get(mod_pair->freq, k) / gsl_vector_get(freq, i1);
      gsl_matrix_set(rm, i1, j1, v + u * frac);

      /* Read after the first update: (i1,j1) may equal (i2,j2). */
      v    = gsl_matrix_get(rm, i2, j2);
      frac = gsl_vector_get(mod_pair->freq, k) / gsl_vector_get(freq, i2);
      gsl_matrix_set(rm, i2, j2, v + u * frac);
    }
  }
}

/*
 * EvoModel_init_single_gap - build mod->gap_rm, of size (dim+1) x (dim+1),
 * from the single-site rate matrix mod->rm and the gap rate
 * beta = gap_params[0].
 *
 * For every residue state the row copies mod->rm, with beta subtracted on
 * the diagonal and beta placed in the gap column (index dim).  The gap row
 * itself stays all zero.
 */
void EvoModel_init_single_gap(EvoModel* mod)
{
  int         row, col;
  int         n    = mod->dim;
  gsl_matrix* Q    = mod->gap_rm;
  double      beta;

  gsl_matrix_set_all(Q, 0);
  beta = gsl_vector_get(mod->gap_params, 0);

  for (row = 0; row < n; row++) {
    for (col = 0; col < n; col++)
      gsl_matrix_set(Q, row, col, gsl_matrix_get(mod->rm, row, col));

    /* Leaving to the gap state removes beta from the diagonal. */
    gsl_matrix_set(Q, row, row, gsl_matrix_get(mod->rm, row, row) - beta);
    gsl_matrix_set(Q, row, n, beta);
  }
}


/*
 * EvoModel_init_pair_gap - build the gap-augmented pair rate matrix
 * mod->gap_rm over the (L+1)^2 states (c1,c2), where L = Alphabet_size
 * and the index L denotes a gap.  The gap parameters are
 * beta1..beta3 = gap_params[0..2].
 *
 *   (c1,c2) -> (d1,d2), no gaps       : mod->rm entry; the diagonal also
 *                                       loses the gap outflow
 *   (c1,c2) -> (gap,gap)              : beta1
 *   (c1,c2) -> (c1,gap) or (gap,c2)   : beta2
 *   (gap,c2) -> (gap,d2), and the
 *   mirror (c1,gap) -> (d1,gap)       : marginal single-site rate from
 *                                       marginize(); diagonal also loses beta3
 *   (gap,c2) / (c1,gap) -> (gap,gap)  : beta3
 *   (gap,gap) row, everything else    : 0
 *
 * The temporary marginal matrix and vector are freed before returning.
 */
void EvoModel_init_pair_gap(EvoModel* mod)
{
  int         i, j, c1, c2, d1, d2;
  int         L  = Alphabet_size;
  int         L5 = Alphabet_size + 1;
  gsl_matrix* gap_rm = mod->gap_rm;
  double      beta1, beta2, beta3;
  gsl_matrix* single_rm   = gsl_matrix_alloc(L, L);
  gsl_vector* single_freq = gsl_vector_alloc(L);

  gsl_matrix_set_all(gap_rm, 0);
  beta1 = gsl_vector_get(mod->gap_params, 0);
  beta2 = gsl_vector_get(mod->gap_params, 1);
  beta3 = gsl_vector_get(mod->gap_params, 2);

  /* gsl_matrix_alloc does not initialise; make sure marginize() never
     accumulates into garbage. */
  gsl_matrix_set_all(single_rm, 0);
  marginize(mod, single_rm, single_freq);

  for (i = 0; i < L5 * L5; i++) {
    c1 = i / L5;
    c2 = i % L5;
    for (j = 0; j < L5 * L5; j++) {
      double p = 0;
      d1 = j / L5;
      d2 = j % L5;

      if (c1 == L && c2 == L) {
        p = 0;  /* (gap,gap) row stays zero */
      }
      else if (c1 < L && c2 < L && d1 < L && d2 < L) {
        p = gsl_matrix_get(mod->rm, c1 * L + c2, d1 * L + d2);
        /* Each paired state has three gap targets: (gap,gap) at beta1 plus
           (c1,gap) and (gap,c2) at beta2 each.  The diagonal must subtract
           all of it.  It previously did `-= beta1 - 2*beta2`, which ADDED
           2*beta2. */
        if (c1 == d1 && c2 == d2)
          p -= beta1 + 2 * beta2;
      }
      else if (c1 < L && c2 < L && d1 == L && d2 == L) {
        p = beta1;
      }
      else if (c1 < L && c2 < L && (c1 == d1 || c2 == d2)) {
        p = beta2;
      }
      else if (c1 == L && d1 == L) {
        if (d2 == L)
          p = beta3;
        else {
          p = gsl_matrix_get(single_rm, c2, d2);
          if (c2 == d2) p -= beta3;
        }
      }
      else if (c2 == L && d2 == L) {
        if (d1 == L)
          p = beta3;
        else {
          p = gsl_matrix_get(single_rm, c1, d1);
          if (c1 == d1) p -= beta3;
        }
      }
      gsl_matrix_set(gap_rm, i, j, p);
    }
  }

  gsl_matrix_free(single_rm);
  gsl_vector_free(single_freq);
}



/*
 * EvoModel_read - read evolutionary models from the text file `filename`
 * into models[].
 *
 * Section headers are matched at the start of a line:
 *   "Model"       starts a new model and increments the model counter;
 *   "DIM"         the next line holds the dimension.  Alphabet_size
 *                 allocates a SINGLE model, anything else a PAIR model
 *                 (both with gap storage);
 *   "Rate Matrix" the rate matrix follows (gsl_matrix_fscanf);
 *   "Frequency"   the frequencies follow, normalised to sum to 1;
 *   "Gap"         the gap parameter vector follows;
 *   "Scale"       the next line holds the scale.  If the matrix and the
 *                 frequencies have both been read, the model is finalised
 *                 (single models get their gap rate matrix built) and
 *                 stored in models[nmodel - 1].
 *
 * Returns the number of "Model" headers seen.  Dies if the file cannot be
 * opened, if a DIM/Scale value cannot be parsed, or if model data appears
 * before any DIM line or "Model" header.  Previously those cases
 * dereferenced an uninitialised pointer or wrote models[-1].
 *
 * NOTE(review): the capacity of models[] is not checked, and pair models'
 * gap rate matrices are not built here.
 */
int EvoModel_read(char* filename, EvoModel* models[])
{
  FILE*     fin;
  char      buffer[MAXLINE];
  char      rm_tag[]    = "Rate Matrix";
  char      freq_tag[]  = "Frequency";
  char      model_tag[] = "Model";
  char      scale_tag[] = "Scale";
  char      gap_tag[]   = "Gap";
  char      dim_tag[]   = "DIM";
  int       read_matrix = 0;
  int       read_freq   = 0;
  int       nmodel      = 0;
  float     scale;
  int       dim;
  EvoModel* m = NULL;

  if ((fin = fopen(filename, "r")) == NULL)
    Die("Fail to read file %s", filename);

  while (fgets(buffer, sizeof(buffer), fin) != NULL) {
    if (strncmp(buffer, model_tag, strlen(model_tag)) == 0) {
      read_matrix = 0;
      read_freq   = 0;
      m           = NULL;
      nmodel++;
    }

    if (strncmp(buffer, dim_tag, strlen(dim_tag)) == 0) {
      if (fgets(buffer, sizeof(buffer), fin) == NULL) continue;
      if (sscanf(buffer, "%d", &dim) != 1)
        Die("%s: cannot parse %s value", filename, dim_tag);
      if (dim == Alphabet_size)
        m = EvoModel_alloc(SINGLE, 1);
      else
        m = EvoModel_alloc(PAIR, 1);
    }

    if (strncmp(buffer, rm_tag, strlen(rm_tag)) == 0) {
      if (m == NULL) Die("%s: \"%s\" section before %s", filename, rm_tag, dim_tag);
      gsl_matrix_fscanf(fin, m->rm);
      read_matrix = 1;
    }

    if (strncmp(buffer, freq_tag, strlen(freq_tag)) == 0) {
      if (m == NULL) Die("%s: \"%s\" section before %s", filename, freq_tag, dim_tag);
      gsl_vector_fscanf(fin, m->freq);
      gsl_vector_scale(m->freq, 1.0 / vector_sum(m->freq));
      read_freq = 1;
    }

    if (strncmp(buffer, gap_tag, strlen(gap_tag)) == 0) {
      if (m == NULL) Die("%s: \"%s\" section before %s", filename, gap_tag, dim_tag);
      gsl_vector_fscanf(fin, m->gap_params);
    }

    if (strncmp(buffer, scale_tag, strlen(scale_tag)) == 0) {
      if (fgets(buffer, sizeof(buffer), fin) == NULL) continue;
      if (sscanf(buffer, "%f", &scale) != 1)
        Die("%s: cannot parse %s value", filename, scale_tag);
      if (read_matrix && read_freq) {
        /* read_matrix implies m != NULL; nmodel may still be 0 if DIM
           came before any "Model" header. */
        if (nmodel == 0)
          Die("%s: model data before any \"%s\" header", filename, model_tag);
        m->scale = scale;
        if (m->dim == Alphabet_size)
          EvoModel_init_single_gap(m);
        models[nmodel - 1] = m;
      }
    }
  }
  fclose(fin);
  return nmodel;
}


/*
 * EvoModel_write - write models[0..nmodel-1] to `fout` in the layout parsed
 * by EvoModel_read(): "Model i", "DIM" and the dimension, "Rate Matrix"
 * with one row per line, "Frequency" on one line, then "Scale" and its
 * value.  Gap parameters are not written.
 *
 * `format` is the printf conversion used for every matrix/frequency entry.
 * It must consume exactly one double.  Entries are printed back to back,
 * so it presumably includes its own separator (e.g. "%f ").
 *
 * Dies if fout is NULL.
 */
void EvoModel_write(FILE* fout, char* format, EvoModel* models[], int nmodel)
{
  int i, j, k;

  /* The old message contained "%s" with no matching argument (UB). */
  if (fout == NULL) Die("Can't open file for writing models\n");

  for (i = 0; i < nmodel; i++) {
    fprintf(fout, "Model %d\n", i);
    fprintf(fout, "DIM\n");
    fprintf(fout, "%d\n", models[i]->dim);
    fprintf(fout, "Rate Matrix\n");
    for (j = 0; j < models[i]->dim; j++) {
      for (k = 0; k < models[i]->dim; k++)
        fprintf(fout, format, gsl_matrix_get(models[i]->rm, j, k));
      fprintf(fout, "\n");
    }
    fprintf(fout, "Frequency\n");
    for (j = 0; j < models[i]->dim; j++)
      fprintf(fout, format, gsl_vector_get(models[i]->freq, j));
    fprintf(fout, "\n");
    fprintf(fout, "Scale\n");
    /* EvoModel_read parses scale with "%f".  Printing it with "%d" was
       undefined if the field is floating point.  The cast is correct for
       either an int or a floating-point field. */
    fprintf(fout, "%.10g\n\n", (double) models[i]->scale);
  }
}


/*
 * alloc_node - give `node` and every descendant an L x L transition matrix
 * (tm) plus length-L freq and likelihood vectors.  The contents are not
 * initialised.
 *
 * NOTE(review): existing tm/freq/likelihood pointers are overwritten
 * without being freed, so calling this twice on the same tree leaks.
 */
void alloc_node(PhyNode* node, int L)
{
  int c = 0;

  node->tm         = gsl_matrix_alloc(L, L);
  node->freq       = gsl_vector_alloc(L);
  node->likelihood = gsl_vector_alloc(L);

  while (c < node->nchildren)
    alloc_node(node->children[c++], L);
}


/*
 * clear_node - reset `flag` to 0 on `node` and all of its descendants.
 * calculate_all_likelihood() uses flag as a per-node count of finished
 * children.
 */
void clear_node(PhyNode* node)
{
  int c = node->nchildren;

  node->flag = 0;
  while (c-- > 0)
    clear_node(node->children[c]);
}


/*
 * EvoModel_free - release the GSL storage owned by `e`: rm and freq, plus
 * gap_rm, gap_Q0 and gap_params when present.  gap_params is allocated by
 * EvoModel_alloc and was previously never freed.  The pointers are reset
 * to NULL, so a second call is harmless.  A NULL `e` is a no-op.
 *
 * NOTE(review): the EvoModel struct itself is NOT freed, although
 * EvoModel_alloc mallocs it.  Confirm whether callers free(e) themselves.
 */
void EvoModel_free(EvoModel *e)
{
  if (e == NULL)
    return;

  if (e->rm)         gsl_matrix_free(e->rm);
  if (e->freq)       gsl_vector_free(e->freq);
  if (e->gap_rm)     gsl_matrix_free(e->gap_rm);
  if (e->gap_Q0)     gsl_matrix_free(e->gap_Q0);
  if (e->gap_params) gsl_vector_free(e->gap_params);

  e->rm         = NULL;
  e->freq       = NULL;
  e->gap_rm     = NULL;
  e->gap_Q0     = NULL;
  e->gap_params = NULL;
}




/*
 * get_tm_freq - fill `tm` and `freq` for the branch above `node` under
 * model `m`.
 *
 * tm   receives the matrix produced by Condi_From_Rate() for time
 *      node->length * m->scale, from m->gap_rm when use_gap is set and
 *      from m->rm otherwise.  The identity from Cal_Id() is its Q_0
 *      argument.
 * freq is zeroed, then receives m->freq.  With use_gap the gapped
 *      alphabet is larger than m->dim: single models copy the first
 *      Alphabet_size entries, and pair models re-index from an
 *      Alphabet_size-wide to an (Alphabet_size+1)-wide layout.  Entries
 *      involving a gap stay 0.
 *
 * NOTE(review): tm/freq must already match the chosen rate matrix's size,
 * and tm must be contiguous (tda == size2) for the memcpy.  Neither is
 * checked.
 */
void get_tm_freq(PhyNode* node, EvoModel* m, gsl_matrix* tm, gsl_vector* freq, int use_gap)
{
  int         i, j;
  double*     R;
  double*     Q;
  gsl_matrix* rm  = use_gap ? m->gap_rm : m->rm;
  int         dim = rm->size1;

  Q = Cal_Id(dim);
  Condi_From_Rate(&R, rm->data, Q, node->length * m->scale, dim, 0, 0);
  memcpy(tm->data, R, sizeof(double) * dim * dim);
  free(R);
  /* Q was never released, leaking dim*dim doubles per branch.
     NOTE(review): assumes Cal_Id() returns caller-owned heap memory, as
     Condi_From_Rate does for R — confirm against matrix.c/evolve.c. */
  free(Q);

  gsl_vector_set_all(freq, 0);
  if (!use_gap) {
    gsl_vector_memcpy(freq, m->freq);
  }
  else if (m->dim == Alphabet_size) {
    for (i = 0; i < Alphabet_size; i++)
      gsl_vector_set(freq, i, gsl_vector_get(m->freq, i));
  }
  else {
    for (i = 0; i < Alphabet_size; i++)
      for (j = 0; j < Alphabet_size; j++)
        gsl_vector_set(freq, i * (Alphabet_size + 1) + j,
                       gsl_vector_get(m->freq, i * Alphabet_size + j));
  }
}

/*
 * init_node - fill tm and freq for `node` and every descendant via
 * get_tm_freq().  When the global `logscale` is set, both are converted in
 * place with vector_log()/matrix_log().
 */
void init_node(PhyNode* node, EvoModel* m, int use_gap)
{
  PhyNode** child;

  get_tm_freq(node, m, node->tm, node->freq, use_gap);
  if (logscale) {
    vector_log(node->freq);
    matrix_log(node->tm);
  }

  for (child = node->children; child < node->children + node->nchildren; child++)
    init_node(*child, m, use_gap);
}


/*
 * init_leaves_single - set the leaf likelihood vectors for one alignment
 * column.
 *
 * Every leaf starts at 1 for all states.  For each sequence s mapped to a
 * leaf (seqid2leave[s] != -1), that leaf becomes an indicator vector with
 * 1 at state coll[s] and 0 elsewhere.  When `logscale` is set, the
 * vectors are converted with vector_log().
 */
void init_leaves_single(PhyNode** leaves, int nleaves, int* seqid2leave, int nseq, int* coll)
{
  int leaf, seq;

  for (leaf = 0; leaf < nleaves; leaf++)
    gsl_vector_set_all(leaves[leaf]->likelihood, 1.0);

  for (seq = 0; seq < nseq; seq++) {
    int target = seqid2leave[seq];
    if (target == -1)
      continue;
    gsl_vector_set_all(leaves[target]->likelihood, 0);
    gsl_vector_set(leaves[target]->likelihood, coll[seq], 1.0);
  }

  if (logscale) {
    for (leaf = 0; leaf < nleaves; leaf++)
      vector_log(leaves[leaf]->likelihood);
  }
}

/*
 * init_leaves_pair - like init_leaves_single, but for a pair of columns.
 *
 * A mapped leaf becomes an indicator on state
 * coll[s] * (Alphabet_size + 1) + colr[s] (the gap-augmented pair layout).
 * Unmapped leaves stay at 1 everywhere.  When `logscale` is set, the
 * vectors are converted with vector_log().
 */
void init_leaves_pair(PhyNode** leaves, int nleaves, int* seqid2leave, int nseq, int* coll, int* colr)
{
  int leaf, seq;
  int width = Alphabet_size + 1;

  for (leaf = 0; leaf < nleaves; leaf++)
    gsl_vector_set_all(leaves[leaf]->likelihood, 1);

  for (seq = 0; seq < nseq; seq++) {
    int target = seqid2leave[seq];
    if (target == -1)
      continue;
    gsl_vector_set_all(leaves[target]->likelihood, 0);
    gsl_vector_set(leaves[target]->likelihood, coll[seq] * width + colr[seq], 1.0);
  }

  if (logscale) {
    for (leaf = 0; leaf < nleaves; leaf++)
      vector_log(leaves[leaf]->likelihood);
  }
}



/*
 * calculate_all_likelihood - compute log likelihoods bottom-up over the
 * tree whose leaves are leaves[0..nleaves-1].
 *
 * Each leaf's loglikelihood comes from its likelihood and freq vectors.
 * An internal node is queued once all of its children are done (counted in
 * node->flag).  Its likelihood vector then combines, for every child,
 * child->tm with child->likelihood via matrix_vector_mul(_log)_eq.  The
 * results are added when `logscale` is set and multiplied otherwise.  Its
 * loglikelihood is computed from that vector and its freq.
 *
 * Preconditions, not checked: every node's flag is 0 on entry (clear_node()
 * resets them), and all vectors have the size of leaves[0]->freq.
 *
 * nleaves <= 0 is a no-op.  A single-leaf tree, whose leaf has no parent,
 * is now handled; it used to dereference NULL.  Exits with a message if
 * allocation fails.
 */
void calculate_all_likelihood(PhyNode** leaves, int nleaves)
{
  PhyNode**   stack;
  PhyNode**   sp;
  PhyNode*    curr;
  PhyNode*    p;
  gsl_vector* L;
  int         i;

  if (nleaves <= 0)
    return;

  /* Initially queued nodes are distinct parents of leaves, and each pop
     queues at most one node, so nleaves slots suffice. */
  stack = (PhyNode**) malloc(sizeof(PhyNode*) * nleaves);
  if (stack == NULL) {
    fprintf(stderr, "Malloc error at calculate_all_likelihood\n");
    exit(1);
  }
  sp = stack;
  L  = gsl_vector_alloc(leaves[0]->freq->size);

  for (i = 0; i < nleaves; i++) {
    if (logscale)
      leaves[i]->loglikelihood = vector_vector_mul_log(leaves[i]->likelihood, leaves[i]->freq);
    else
      leaves[i]->loglikelihood = log(vector_vector_mul(leaves[i]->likelihood, leaves[i]->freq));

    p = leaves[i]->parent;
    if (p == NULL)
      continue;  /* lone leaf: it is the root */
    p->flag++;
    if (p->flag == p->nchildren)
      *(sp++) = p;
  }

  while (sp > stack) {
    sp--;
    curr = *sp;

    if (logscale) gsl_vector_set_all(curr->likelihood, 0);
    else          gsl_vector_set_all(curr->likelihood, 1.0);

    for (i = 0; i < curr->nchildren; i++) {
      PhyNode* child = curr->children[i];
      if (logscale) {
        matrix_vector_mul_log_eq(L, child->tm, child->likelihood);
        gsl_vector_add(curr->likelihood, L);
      }
      else {
        matrix_vector_mul_eq(L, child->tm, child->likelihood);
        gsl_vector_mul(curr->likelihood, L);
      }
    }

    if (logscale) curr->loglikelihood = vector_vector_mul_log(curr->likelihood, curr->freq);
    else          curr->loglikelihood = log(vector_vector_mul(curr->likelihood, curr->freq));

    p = curr->parent;
    if (p) {
      p->flag++;
      if (p->flag == p->nchildren)
        *(sp++) = p;
    }
  }
  gsl_vector_free(L);
  free(stack);
}


/*
 * null_loglikelihood - sum of log(freq[col[s]]) over sequences s whose
 * symbol code indexes into freq.  Out-of-range codes are skipped,
 * including negative ones and gap codes >= freq->size.
 */
double null_loglikelihood(int nseq, int* col, gsl_vector* freq)
{
  double total = 0;
  int    s;

  for (s = 0; s < nseq; s++) {
    /* The explicit < 0 test matches the old signed/unsigned comparison,
       which also rejected negative codes. */
    if (col[s] < 0 || (size_t) col[s] >= freq->size)
      continue;
    total += log(gsl_vector_get(freq, col[s]));
  }
  return total;
}



/*
void init_node_pair(PhyNode* node, EvoModel* pair)
{  
  int i=0;
  get_tm_freq_pair(node, pair, node->tm, node->freq);  
  if (logscale){
    vector_log(node->freq);
    matrix_log(node->tm);
  }
  for(i=0; i < node->nchildren; i++){
    init_node_pair(node->children[i], pair);
  }
}

void init_node_pair_gap(PhyNode* node, EvoModel* mod)
{  
  int i;
  get_tm_freq_pair_gap(node, mod, node->tm, node->freq);
  
  if (logscale){
    vector_log(node->freq);
    matrix_log(node->tm);
  }
  for(i=0; i < node->nchildren; i++){
    init_node_pair_gap(node->children[i], gap_params, single_model_idx, pair_model_idx);
  }
}


void init_node_single_gap_pair_gap(PhyNode* node, EvoModel* single[], int nsingle, EvoModel* pair[], int npair)
{
  int i;
  int L = Alphabet_size;
  for(i=0; i < nsingle; i++){
    if (node->single_tm[i]== NULL){
      node->single_tm[i]= gsl_matrix_alloc(L + 1, L + 1);
    }
    if (node->single_freq[i]== NULL){
      node->single_freq[i]= gsl_vector_alloc(L + 1);
    }
    get_tm_freq_single_gap(node, single[i], node->single_tm[i], node->single_freq[i]);
  }
  for(i=0; i < npair; i++){  
    if (node->pair_tm[i]== NULL){
      node->pair_tm[i]= gsl_matrix_alloc(L * L, L * L);
    }
    if (node->pair_freq[i]== NULL){
      node->pair_freq[i]= gsl_vector_alloc(L * L);
    }
    get_tm_freq_pair(node, pair[i], node->pair_tm[i], node->pair_freq[i]);  
  }      

  for(i=0; i < node->nchildren; i++){
    init_node_single_gap_pair(node->children[i], single, nsingle, pair, npair);
  }
}
*/


/*

void get_tm_freq_pair_gap1(PhyNode* node, gsl_vector* gap_params, )
{
  
  int L = Alphabet_size;  
  int L5 = Alphabet_size+1;
  gsl_matrix_set_all(pair_gap_tm, 0);
  gsl_vector_set_all(pair_gap_freq, 0);

  gsl_matrix* pair_joint = diag_matrix_mul(pair_tm, pair_freq);
  gsl_matrix* single_joint = diag_matrix_mul(single_gap_tm, single_gap_freq);
    
  gsl_matrix_scale(pair_joint, 1.0/matrix_sum(pair_joint));
  gsl_matrix_scale(single_joint, 1.0/matrix_sum(single_joint));

  double beta0=gsl_vector_get(gap_params, 0);
  double q0=gsl_vector_get(gap_params, 1);
  double beta1=gsl_vector_get(gap_params, 2);
  double q1=gsl_vector_get(gap_params, 3);

  double e0 = exp(-beta0 * node->length);
  double p0 = (1 - e0);
  double r0 = e0;
  double e1 = exp(-beta1 * node->length);
  double delta1 = (1 - e1)/ ( 1 - q1* e1);  
  double p1 = delta1 * (1 - e1 + q1 * e1);      
  double r1 = (1 - 2 * delta1 + p1);

  delta1 *= r0;
  p1 *= r0;
  r1 *= r0;
  
  int i,j;
  int c1, c2, d1, d2;
  double unpair_sum =0 ;
  double p;
  //Fill the joint probability matrix
  for(i=0; i < L5 * L5; i++)
    for(j=0; j < L5 * L5; j++){
      c1 = i/L5;
      c2 = i%L5;
      d1 = j/L5;
      d2 = j%L5;
      int pair1 = (c1 == L && c2 == L) || (c1 < L && c2 < L);
      int pair2 = (d1 == L && d2 == L) || (d1 < L && d2 < L);

      if (pair1 && pair2) {
	if (c1 == L && c2 == L && d1 == L && d2 == L){
	  p = p1;
	}
	else if (c1 < L && c2 < L && d1 < L && d2 < L){
	  p = gsl_matrix_get(pair_joint, c1 * L + c2, d1 * L + d2) * r1;
	}
	else if (c1 < L && c2 < L){
	  p = gsl_vector_get(pair_freq, c1 * L + c2) * (delta1 - p1);
	}
	else if (d1 < L && d2 < L){
	  p = gsl_vector_get(pair_freq, d1 * L + d2) * (delta1 - p1);
	}	
      }
      else{
       p = gsl_matrix_get(single_joint, c1, d1) * gsl_matrix_get(single_joint, c2, d2); 
       unpair_sum += p;
      }
      gsl_matrix_set(pair_gap_tm, c1 *L5 + c2, d1 * L5 + d2, p);	      
    }

  //Normalize
  for(i=0; i < L5 * L5; i++)
    for(j=0; j < L5 * L5; j++){
      c1 = i/L5;
      c2 = i%L5;
      d1 = j/L5;
      d2 = j%L5;
      int pair1 = (c1 == L && c2 == L) || (c1 < L && c2 < L);
      int pair2 = (d1 == L && d2 == L) || (d1 < L && d2 < L);
      if (!pair1 || !pair2){
	p = gsl_matrix_get(pair_gap_tm, c1 *L5 + c2, d1 * L5 + d2);	      
	if (unpair_sum >0) p *= (p0/unpair_sum) ;	  
	gsl_matrix_set(pair_gap_tm, c1 *L5 + c2, d1 * L5 + d2, p);	      
      }
    }  

  //Get conditional matrix, and marginize frequency.
  double sum=0;
  for(i=0; i < L5 * L5; i++){
    p =0;
    for(j=0; j < L5 * L5; j++){
      p += gsl_matrix_get(pair_gap_tm, i, j);      
    }
    if (p > 0){
      for(j=0; j < L5 * L5; j++){
	gsl_matrix_set(pair_gap_tm, i,j, gsl_matrix_get(pair_gap_tm, i, j)/p);
      }
    }
    sum += p;
    gsl_vector_set(pair_gap_freq, i, p);    
  }
  if (fabs(node->length - 0.2)  < 0.05){
    printf("tree %s\t branch %f\n", node->name, node->length);    
    printf("single freq \n");
    vector_fprintf(stdout, single_gap_freq);
    printf("beta0 %f q0 %f \t beta1 %f q1 %f\n", beta0, q0, beta1, q0);
    printf("p0 %f r0 %f\t p1 %f delta1 %f r1 %f unpair %f\n",
	   p0, r0, p1, delta1,r1, p0/unpair_sum);
    
    printf("Freq: sum %f\n", sum);
    vector_fprintf(stdout, pair_gap_freq);  
  }
}


void get_tm_freq_single_gap1(PhyNode* node, EvoModel* m, gsl_matrix* tm, gsl_vector* freq)
{
  int i=0;
  //Set transition matrix  
  double* R;
  Condi_From_Rate(&R, m->gap_rm->data, m->gap_Q0->data, node->length * m->scale, Alphabet_size+1, 0, 0);  
  memcpy(tm->data, R, sizeof(double) * (Alphabet_size+1)*(Alphabet_size+1));  
  free(R);  
  //Set frequency
  double gap_beta= gsl_vector_get(m->gap_params,0);
  double gap_q0= gsl_vector_get(m->gap_params,1);
  double t = node->length * m->scale;

  double margin_delta_sum = 0;
  for(i=0; i < Alphabet_size; i++)
    margin_delta_sum += gsl_vector_get(m->freq, i)*gsl_matrix_get(tm, i, Alphabet_size);    
  double tmp = margin_delta_sum / (margin_delta_sum + 1 - gsl_matrix_get(tm, Alphabet_size, Alphabet_size));    
  
  double delta =  ( 1 - exp(-gap_beta * t))/ (1 - gap_q0 * exp(- gap_beta * t));

  //printf("Branch %f Single gap delta_tmp %f delta  %f\n", t, tmp, delta);  
  
  int L = m->dim;
  for(i=0; i < L; i++){
    gsl_vector_set(freq, i, gsl_vector_get(m->freq,i) * (1- delta));		   
  }
  gsl_vector_set(freq, L, delta);    
}

*/
