#include "phytree.h"
#include "grammar.h"
#include "funcs.h"
#include <float.h>
#include "/home/users/elfar/software/CMfinder_03/gsl/include/gsl/gsl_errno.h"
#include "/home/users/elfar/software/CMfinder_03/gsl/include/gsl/gsl_math.h"
#include "numerical_opt.h"


#include "squid.h"		/* general sequence analysis library    */
#include "msa.h"                /* squid's multiple alignment i/o       */
#include "structs.h"		/* data structures, macros, #define's   */
#include "funcs.h"		/* external functions                   */
#include "version.h"            /* versioning info for Infernal         */
#include "prior.h"
#include "../cmfinder/global.h"
#include <stdio.h>
#include <stdlib.h>

/* Sentinel returned by unpack_parameter_Evofold() when a parameter vector is
   rejected; evaluate_pair() turns it into an objective value of FLT_MAX. */
#define OUTBOUND -1


/* extern declaration only; the definition lives outside this file
   (presumably another translation unit). Not referenced by the code here. */
extern int logscale;

/* Training data filled in by main(). Each of the four tables below has room
   for 1000 alignments, indexed 0 .. msa_num-1. */
int   msa_num=0;
MSA* all_msa[1000];          /* parsed Stockholm alignments */
int* all_seqid2leave[1000];  /* [a][seq] -> index into leaves[] of the first leaf whose name is a prefix of the sequence name, or -1 */
int* all_pairtable[1000];    /* [a][col] -> GetPairtable(ss_cons) entry; presumably the partner column of col (TODO confirm convention for unpaired columns) */
int** all_col[1000];         /* [a][col][seq] -> SymbolIndex() of the residue, values above Alphabet_size clamped to Alphabet_size */
EvoModel* mod;               /* pair evolutionary model whose parameters are being fitted */
PhyNode* root;               /* root of the phylogenetic tree read from the tree file */
PhyNode** leaves;            /* tree leaves as returned by PhyLeaves(root) */
int   nleaves;               /* number of leaves, taken from root->size */
int   pair_alphabet_size;    /* Alphabet_size * Alphabet_size: one state per ordered nucleotide pair */

/* Encoded base pairs (first * 4 + second) that get their own substitution
   parameters. Assuming alphabet order A,C,G,U these are AU, UA, CG, GC, GU, UG. */
int pairs[] = {
  0*4 + 3, 3*4 + 0,
  1*4 + 2, 2*4 + 1,
  2*4 + 3, 3*4 + 2
};

/* Look up an encoded pair in pairs[].
   Returns its position (0..5), or -1 when `id` is not one of those pairs. */
int pair_id(int id)
{
  int idx = 0;
  while (idx < (int)(sizeof pairs / sizeof pairs[0])) {
    if (pairs[idx] == id)
      return idx;
    idx++;
  }
  return -1;
}


/* Frequency-parameter index for each ordered nucleotide pair (x, y).
   The table is symmetric, so a pair and its reverse (e.g. AU and UA) share
   one of the 10 frequency parameters. */
int map_pair_freq_param[4][4]={{0, 1, 3, 6},
			      {1, 2, 4, 7},
			      {3, 4, 5, 8},
                              {6, 7, 8, 9}};

/* Exchange-rate parameter index for substitutions between the pairs listed
   in pairs[] (rows/columns in the same order: AU, UA, CG, GC, GU, UG);
   -1 on the diagonal.
   The table must be symmetric because the rate matrix is built as
   rm[i][j] = s(i,j) * freq[j]. A substitution and its strand-reversed
   counterpart share a parameter: e.g. CG<->GU and GC<->UG both use 6. */
int map_pair_sub_param[6][6]={{-1,  0, 1,  2,  4, 5},
			      { 0, -1, 2,  1,  5, 4},
			      { 1,  2,-1,  3,  6, 7},
			      { 2,  1, 3, -1,  7, 6},
			      { 4,  5, 6,  7, -1, 8},
			      { 5,  4, 7,  6,  8,-1}
};


/* Layout of the parameter vector optimized by main() (26 entries):
     PAIR_FREQ_PAM_NUM pair frequencies,
     4 mixed / non-canonical rates (subs1, subs2, non_subs1, non_subs2),
     PAIR_SUB_PAM_NUM canonical<->canonical exchange rates,
     and the remaining 3 entries, which are copied into the gap parameters. */
#define PAIR_PAM_NUM 26
#define PAIR_FREQ_PAM_NUM 10
#define PAIR_SUB_PAM_NUM 9


/* Starting point for the optimizer, in the layout described above. */
double init_pair_params[PAIR_PAM_NUM] = {
  /* pair frequencies (indexed through map_pair_freq_param) */
  0.0048, 0.0102, 0.0054, 0.0050, 0.1701,
  0.0059, 0.1124, 0.0054, 0.0407, 0.0083,
  /* subs1, subs2, non_subs1, non_subs2 */
  4.3296, 0.2979, 20.4977, 3.2271,
  /* canonical exchange rates (indexed through map_pair_sub_param) */
  0.2096, 0.0314, 0.1926, 0.0175, 2.6891,
  0.0129, 0.0009, 2.2260, 0.0422,
  /* gap parameters */
  0.020, 0.011, 0.088
};

               
/*
 * Load an optimizer parameter vector into the global pair model `mod`.
 *
 * Layout of params:
 *   [0, PAIR_FREQ_PAM_NUM)   unnormalized pair frequencies; a pair and its
 *                            reverse share one entry (map_pair_freq_param,
 *                            so AU = UA);
 *   next 4                   subs1 / subs2: canonical <-> non-canonical rate
 *                            for one / two nucleotide changes;
 *                            non_subs1 / non_subs2: non-canonical <->
 *                            non-canonical rate for one / two changes;
 *   next PAIR_SUB_PAM_NUM    canonical <-> canonical exchange rates
 *                            (map_pair_sub_param);
 *   following entries        copied into mod->gap_params.
 *
 * Side effects: fills mod->freq (normalized to sum to 1), mod->rm with
 * rm[i][j] = rate(i,j) * freq[j] and each diagonal set so its row sums to 0,
 * and mod->gap_params. Also prints all three to stdout.
 *
 * Returns OUTBOUND if any parameter is negative or the frequencies sum to
 * zero, and 1 otherwise.
 */
int unpack_parameter_Evofold(const gsl_vector* params)
{
  size_t n;
  int    i, j;
  int    k = 0;
  double tot_freq = 0;
  double subs1, subs2, non_subs1, non_subs2;

  /* Every parameter must be non-negative, not just the first one. */
  for (n = 0; n < params->size; n++) {
    if (gsl_vector_get(params, n) < 0) return OUTBOUND;
  }

  for (i = 0; i < Alphabet_size; i++) {
    for (j = 0; j < Alphabet_size; j++) {
      double f = gsl_vector_get(params, map_pair_freq_param[i][j]);
      gsl_vector_set(mod->freq, i * Alphabet_size + j, f);
      tot_freq += f;
    }
  }
  /* Normalizing by zero would fill freq (and hence rm) with inf/NaN. */
  if (tot_freq <= 0) return OUTBOUND;
  gsl_vector_scale(mod->freq, 1 / tot_freq);

  k += PAIR_FREQ_PAM_NUM;
  subs1     = gsl_vector_get(params, k++);
  subs2     = gsl_vector_get(params, k++);
  non_subs1 = gsl_vector_get(params, k++);
  non_subs2 = gsl_vector_get(params, k++);

  for (i = 0; i < pair_alphabet_size; i++) {
    double tot = 0;
    for (j = 0; j < pair_alphabet_size; j++) {
      int    l1, l2;
      double rate;

      if (i == j) continue;
      l1 = pair_id(i);
      l2 = pair_id(j);
      if (l1 >= 0 && l2 >= 0) {
        /* between two pairs listed in pairs[]: dedicated exchange parameter */
        rate = gsl_vector_get(params, k + map_pair_sub_param[l1][l2]);
      }
      else {
        int c1 = i / Alphabet_size, c2 = i % Alphabet_size;
        int d1 = j / Alphabet_size, d2 = j % Alphabet_size;
        /* i != j, so an unchanged position means exactly one change */
        int one_change = (c1 == d1 || c2 == d2);

        if (l1 >= 0 || l2 >= 0)
          rate = one_change ? subs1 : subs2;
        else
          rate = one_change ? non_subs1 : non_subs2;
      }
      gsl_matrix_set(mod->rm, i, j, rate * gsl_vector_get(mod->freq, j));

      /* Sanity report for suspiciously large rates; deliberately not fatal. */
      if (gsl_matrix_get(mod->rm, i, j) > 10) {
        printf("Error\n");
      }
      tot += gsl_matrix_get(mod->rm, i, j);
    }
    gsl_matrix_set(mod->rm, i, i, -tot);
  }

  k += PAIR_SUB_PAM_NUM;
  for (n = 0; n < mod->gap_params->size; n++) {
    gsl_vector_set(mod->gap_params, n, gsl_vector_get(params, k + n));
  }

  printf("Freq\n");
  vector_fprintf(stdout, mod->freq);
  printf("RM\n");
  matrix_fprintf(stdout, mod->rm);
  printf("Gap\n");
  vector_fprintf(stdout, mod->gap_params);

  return 1;
}


/*
 * Objective function minimized by opt_bfgs() in main().
 *
 * Loads `params` into the global model (unpack_parameter_Evofold followed by
 * EvoModel_init_pair_gap), then sums root->loglikelihood over every consensus
 * base pair of every loaded alignment. Prints the parameters and the total
 * log-likelihood as progress output.
 *
 * Returns FLT_MAX if the parameters are rejected, 0.0 if the alignments
 * contain no pairs, and otherwise -100 * (total log-likelihood) / (number of
 * pairs). `other_params` is unused.
 */
double evaluate_pair(gsl_vector* params, void* other_params)
{
  size_t i;
  int    a, col, partner;
  int    tot_pair = 0;
  double loglikelihood = 0;

  (void) other_params;

  if (unpack_parameter_Evofold(params) == OUTBOUND) {
    return FLT_MAX;
  }
  EvoModel_init_pair_gap(mod);

  printf("params :\n");
  for (i = 0; i < params->size; i++) {
    printf("%.3f ", gsl_vector_get(params, i));
  }
  printf("\n");

  init_node(root, mod, 1);

  for (a = 0; a < msa_num; a++) {
    for (col = 0; col < all_msa[a]->alen; col++) {
      partner = all_pairtable[a][col];
      /* Score each pair once, from its left column. Entries below col are
         skipped; presumably this also covers unpaired columns marked with a
         negative value (TODO confirm GetPairtable's convention). */
      if (partner < col) continue;
      clear_node(root);
      tot_pair++;
      init_leaves_pair(leaves, nleaves, all_seqid2leave[a], all_msa[a]->nseq,
                       all_col[a][col], all_col[a][partner]);
      calculate_all_likelihood(leaves, nleaves);
      loglikelihood += root->loglikelihood;
    }
  }
  printf("Loglikelihood %.3f\n", loglikelihood);

  /* No pairs means nothing to score; avoid dividing by zero. */
  if (tot_pair == 0) {
    return 0.0;
  }
  return -loglikelihood * 100 / tot_pair;
}

int main(int argc, char* argv[])
{
  char* tree_file = argv[1];
  char* ali_files = argv[2];
  
  char  buffer[MAXLINE];
  EvoModel* tmp_model[1];

  root = ReadPhyFile(tree_file);
  leaves = PhyLeaves(root);  
  nleaves = root->size;
  pair_alphabet_size = Alphabet_size * Alphabet_size;
  int   format =  MSAFILE_STOCKHOLM;  
  FILE* file_list;
  if (!(file_list = fopen(ali_files, "r"))) Die("");
  MSAFILE     *afp = NULL;        /* file handle of initial alignment */
  MSA  *msa;
  int  i,j;

  gsl_vector* pair_count= gsl_vector_alloc(pair_alphabet_size);
  gsl_vector* tmp_count= gsl_vector_alloc(pair_alphabet_size);
  gsl_vector_set_all(pair_count, 0);
  int total_pair=0;
  while(fgets(buffer, MAXLINE, file_list)){    
    char filename[MAXLINE];
    sscanf(buffer, "%s", filename);
    if ((afp = MSAFileOpen(filename, format, NULL)) == NULL)
      Die("Alignment file %s could not be opened for reading",filename);
    if ((msa = MSAFileRead(afp)) != NULL){
      MSAFileClose(afp);
    }
    else{
      Die("Error reading Alignment file %s", filename);
    }    
    all_msa[msa_num]= msa;
    all_seqid2leave[msa_num] = (int*) MallocOrDie(sizeof(int) * msa->nseq);    
    for(i=0; i < msa->nseq; i++){
      all_seqid2leave[msa_num][i] = -1;
      for(j=0; j < nleaves; j++){      
	if ( strncmp(msa->sqname[i], leaves[j]->name, strlen(leaves[j]->name)) == 0 ){
	  all_seqid2leave[msa_num][i]=j;
	  break;
	}
      }
    }  
    all_pairtable[msa_num] = GetPairtable(all_msa[msa_num]->ss_cons);
    all_col[msa_num] = (int**)MallocOrDie(sizeof(int*) * msa->alen);    
    memset(all_col[msa_num], 0, sizeof(int*) * msa->alen);    
    for(i=0; i < msa->alen;i++){
      all_col[msa_num][i] = (int*) MallocOrDie(sizeof(int) * msa->nseq);
      for(j=0; j < msa->nseq; j++){
	all_col[msa_num][i][j] = SymbolIndex(msa->aseq[j][i]);
	if (all_col[msa_num][i][j] > Alphabet_size){
	  all_col[msa_num][i][j] = Alphabet_size;
	}
      }
    }
    msa_num++;
  }

  mod = EvoModel_alloc(PAIR, 1);
  
  gsl_vector_scale(pair_count, 1.0/total_pair);
  
  gsl_vector* v = vector_init(init_pair_params, PAIR_PAM_NUM);
  gsl_vector* ub = gsl_vector_alloc(PAIR_PAM_NUM);
  gsl_vector* lb = gsl_vector_alloc(PAIR_PAM_NUM);

  gsl_vector_set_all(ub, 100);
  for(i=0; i < PAIR_FREQ_PAM_NUM; i++){
    gsl_vector_set(ub, i, 0.5);
  }
  gsl_vector_set_all(lb, 0.0001);
  
  
  double loglikelihood;

  alloc_node(root, (Alphabet_size + 1) * ( Alphabet_size + 1));
  opt_bfgs(evaluate_pair, v, NULL, &loglikelihood, lb, ub, stdout, NULL,  OPT_HIGH_PREC, NULL);	   
  
  printf("params :\n");
  for(i=0; i < v->size; i++){
    printf("%.4f ", gsl_vector_get(v, i));	   
  }
  printf("\nfreq :\n");
  for(i=0; i < mod->freq->size; i++){
    printf("%.4f ", exp(gsl_vector_get(mod->freq, i)));	   
  }
  printf("\n");
  for(i=0; i < mod->rm->size1; i++){
      for(j=0; j < mod->rm->size2; j++){
	printf("%.4f ", gsl_matrix_get(mod->rm, i,j));	   	
      }
      printf("\n");
  }

}



      
