#include "phytree.h"
#include "grammar.h"
#include "funcs.h"
#include <float.h>
#include "/people/disk1/elfar/software/CMfinder_03/gsl/include/gsl/gsl_errno.h"
#include "/people/disk1/elfar/software/CMfinder_03/gsl/include/gsl/gsl_math.h"
//#include <gsl/gsl_errno.h>
//#include <gsl/gsl_math.h>
#include "numerical_opt.h"


#include "squid.h"		/* general sequence analysis library    */
#include "msa.h"                /* squid's multiple alignment i/o       */
#include "structs.h"		/* data structures, macros, #define's   */
#include "funcs.h"		/* external functions                   */
#include "version.h"            /* versioning info for Infernal         */
#include "prior.h"
#include "../cmfinder/global.h"
#include <stdio.h>
#include <stdlib.h>

/* Sentinel value; not referenced in the visible portion of this file. */
#define OUTBOUND -1


/* Defined in another translation unit; not referenced in this file's visible code. */
extern int logscale;

/* Alignments loaded by main(); the [1000] tables below are indexed by
   alignment number 0..msa_num-1. */
int   msa_num=0;
MSA* all_msa[1000];
/* all_seqid2leave[a][s]: index into `leaves` of sequence s of alignment a,
   or -1 when no leaf name matched. */
int* all_seqid2leave[1000];
/* Incremented by main() while encoding columns; used to size the scratch
   entropy buffer in collect_cols(). */
int   all_col_num=0;
/* all_col[a][c]: symbol indices (one per sequence) of column c of alignment a,
   or NULL for columns whose RF annotation is a gap. */
int** all_col[1000];
/* Entropy-selected subsets of all_col filled by collect_cols()
   (entry == all_col entry when selected, NULL otherwise). */
int** conserved_cols[1000];
int** nonconserved_cols[1000];
/* Model being fitted, and the tree (root, leaf array, leaf count) it is
   evaluated on. */
EvoModel* mod;
PhyNode* root;
PhyNode** leaves;
int   nleaves;

/* Maps a pair of distinct residue indices (i, j) to the offset (0..5) of its
   exchangeability parameter, relative to the start of the exchangeability
   block in the parameter vector (see unpack_single_param). The table is
   symmetric, so (i,j) and (j,i) share one parameter; the diagonal (-1) is
   never read. */
int map_single_sub_param[4][4]={{-1, 0, 1,  3},
				{ 0,-1, 2,  4},
				{ 1, 2,-1,  5},
				{ 3, 4, 5, -1}};

/* Parameter count: 4 base frequencies + 1 gap parameter + 6 exchangeabilities. */
#define SINGLE_PAM_NUM 11
/* NOTE(review): the commented-out alternatives below list only 10 values
   (no gap parameter), so they do not match the current 11-entry layout. */
//double init_single_params[SINGLE_PAM_NUM] = {0.25, 0.25, 0.25, 0.25, 1, 1, 1, 1, 1, 1};
//double init_single_params[SINGLE_PAM_NUM] = {0.25, 0.25, 0.35, 0.15, 0.189, 0.44, 0.26, 0.23, 0.72, 0.21};
//double init_single_params[SINGLE_PAM_NUM] = {0.001, 0.001, 1.685, 1.626, 0.000, 12.117, 4.485, 12.800, 4.725, 0.000};
/* Optimizer starting point: base frequencies, gap parameter, exchangeabilities. */
double init_single_params[SINGLE_PAM_NUM] = {0.239, 0.237, 0.261, 0.269,
					     0.1, 
					     0.387, 0.946, 0.323, 0.452, 0.986, 0.349};

					     

/*
 * Shannon entropy (natural log, i.e. nats) of one encoded alignment column.
 *
 * dsq  holds `num` symbol indices. An index in [0, Alphabet_size) counts as
 *      one observation of that residue. Any other value (gap or ambiguity
 *      code, including a negative index) is spread uniformly as
 *      1/Alphabet_size over every residue.
 * num  number of entries in dsq.
 *
 * Returns the entropy, or 0 for an empty column (num <= 0).
 */
double get_entropy(int *dsq, int num)
{
  double f[Alphabet_size];        /* singlet frequency vector */
  double e = 0.0;
  int    idx, sym;

  if (num <= 0)
    return 0.0;

  for (sym = 0; sym < Alphabet_size; sym++)
    f[sym] = 0.0;

  for (idx = 0; idx < num; idx++) {
    sym = dsq[idx];
    if (sym < 0 || sym >= Alphabet_size) {
      /* Unknown residue: contribute a uniform fractional count. */
      for (sym = 0; sym < Alphabet_size; sym++)
        f[sym] += 1.0 / Alphabet_size;
    } else {
      f[sym] += 1.0;
    }
  }

  for (sym = 0; sym < Alphabet_size; sym++) {
    f[sym] /= num;
    if (f[sym] > 0)
      e -= f[sym] * log(f[sym]);
  }
  return e;
}

/*
 * Count the entries of dsq[0..num-1] that are not plain residues, i.e. whose
 * symbol index is >= Alphabet_size. The count is returned as a double.
 */
double get_gapcount(int *dsq, int num)
{
  int n_gaps = 0;
  int pos = num;

  /* Walk backwards; order does not matter for a count. */
  while (pos-- > 0) {
    if (dsq[pos] >= Alphabet_size) {
      n_gaps++;
    }
  }
  return (double) n_gaps;
}


/*
 * Load an optimizer parameter vector into the global model `mod`.
 *
 * Layout of params:
 *   [0, Alphabet_size)          unnormalized base frequencies
 *   [Alphabet_size]             gap parameter -> mod->gap_params[0]
 *   [Alphabet_size + 1, ...)    one exchangeability per unordered residue
 *                               pair, located via map_single_sub_param
 *
 * Frequencies are renormalized to sum to 1. The rate matrix is rebuilt as
 * rm[i][j] = freq[j] * s(i,j) for i != j, and each diagonal entry is set so
 * its row sums to 0. mod->scale is reset to 1 and EvoModel_init_single_gap()
 * is called. The resulting frequencies and rate matrix are printed to stdout.
 *
 * Returns 1 on success. Returns 0, without modifying `mod`, when:
 *  - the vector is too short for this layout;
 *  - the alphabet is larger than map_single_sub_param;
 *  - the frequencies do not sum to a positive value.
 * The last case is reachable because the optimizer's lower bound is 0, and
 * dividing by such a sum would produce inf/NaN. evaluate() maps a 0 return
 * to INF.
 */
int unpack_single_param(gsl_vector* params)
{
  const int max_dim = (int) (sizeof map_single_sub_param / sizeof map_single_sub_param[0]);
  const int n_pairs = Alphabet_size * (Alphabet_size - 1) / 2;
  double tot_freq = 0;
  int i, j, k;
  size_t r, c;

  if (Alphabet_size > max_dim ||
      params->size < (size_t) (Alphabet_size + 1 + n_pairs))
    return 0;

  /* Validate the frequency sum before writing anything into mod. */
  for (i = 0; i < Alphabet_size; i++)
    tot_freq += gsl_vector_get(params, i);
  if (!(tot_freq > 0))          /* also rejects NaN */
    return 0;

  for (i = 0; i < Alphabet_size; i++)
    gsl_vector_set(mod->freq, i, gsl_vector_get(params, i));
  gsl_vector_scale(mod->freq, 1 / tot_freq);

  k = Alphabet_size;
  gsl_vector_set(mod->gap_params, 0, gsl_vector_get(params, k++));

  /* k now indexes the first exchangeability parameter. */
  gsl_matrix_set_all(mod->rm, 0);
  for (i = 0; i < Alphabet_size; i++) {
    double tot_subs = 0;
    for (j = 0; j < Alphabet_size; j++) {
      if (i == j) continue;
      gsl_matrix_set(mod->rm, i, j,
                     gsl_vector_get(mod->freq, j) * gsl_vector_get(params, k + map_single_sub_param[i][j]));
      tot_subs += gsl_matrix_get(mod->rm, i, j);
    }
    gsl_matrix_set(mod->rm, i, i, -tot_subs);
  }
  mod->scale = 1;
  EvoModel_init_single_gap(mod);

  printf("\nInit freq :\n");
  for (r = 0; r < mod->freq->size; r++) {
    printf("%.3f ", gsl_vector_get(mod->freq, r));
  }
  printf("\nRM\n");
  for (r = 0; r < mod->rm->size1; r++) {
    for (c = 0; c < mod->rm->size2; c++) {
      printf("%.3f ", gsl_matrix_get(mod->rm, r, c));
    }
    printf("\n");
  }
  return 1;
}



/*
 * Objective function for opt_bfgs: the summed per-column log-likelihood over
 * all alignments, negated, multiplied by 100 and divided by the number of
 * columns evaluated (lower is better).
 *
 * params        candidate parameter vector (layout as in unpack_single_param)
 * other_params  int*** indexed [alignment][column]; NULL columns are skipped
 *
 * Side effects: rewrites the global model `mod` and the node state under
 * `root`, and prints progress to stdout.
 *
 * Returns INF when unpack_single_param rejects the parameters, or when no
 * column was evaluated (the mean would otherwise be 0/0 = NaN).
 */
double evaluate(gsl_vector* params, void* other_params)
{
  int*** cols = (int***)other_params;
  double loglikelihood = 0;
  double func;
  int tot_col = 0;
  int i, j;
  size_t p;

  if (!unpack_single_param(params)){
    return INF;
  }

  printf("params :\n");
  for(p = 0; p < params->size; p++){
    printf("%.3f ", gsl_vector_get(params, p));
  }
  printf("\n");

  init_node(root, mod, 1);
  for(i = 0; i < msa_num; i++){
    for(j = 0; j < all_msa[i]->alen; j++){
      if (cols[i][j] == NULL) continue;
      tot_col++;
      clear_node(root);
      init_leaves_single(leaves, nleaves, all_seqid2leave[i], all_msa[i]->nseq, cols[i][j]);
      calculate_all_likelihood(leaves, nleaves);
      loglikelihood += root->loglikelihood;
    }
  }

  if (tot_col == 0){
    /* Nothing to average over; report an unusable point instead of NaN. */
    printf("tot_col 0: no columns to evaluate\n");
    return INF;
  }

  func = -loglikelihood * 100 / tot_col;
  printf("tot_col %d Loglikelihood %.3f func %.3f\n", tot_col, loglikelihood, func);
  return func;
}

/*
 * qsort() comparator for doubles. Returns -1, 0 or 1 as *a is less than,
 * equal to (or unordered with, e.g. NaN) or greater than *b.
 */
int doublecmp(const void* a, const void* b)
{
  const double lhs = *(const double*) a;
  const double rhs = *(const double*) b;

  return (lhs > rhs) - (lhs < rhs);
}

void collect_cols(double conserved_quantile, double nonconserved_quantile)
{  
  int i,j;
  double* entropy= (double*)MallocOrDie(sizeof(double) * all_col_num);
  int tot_col=0;
  int k=0;
  double conserved_freq[Alphabet_size];
  double nonconserved_freq[Alphabet_size];
  memset(conserved_freq, 0, sizeof(double)*Alphabet_size);
  memset(nonconserved_freq, 0, sizeof(double)*Alphabet_size);  

  double** all_entropy=(double**)MallocOrDie(sizeof(double*) * msa_num);
  double conserved_thresh, nonconserved_thresh;
  for(i=0; i < msa_num; i++){
    conserved_cols[i] = (int**)MallocOrDie(sizeof(int*) * all_msa[i]->alen);
    nonconserved_cols[i] = (int**)MallocOrDie(sizeof(int*) * all_msa[i]->alen);
    all_entropy[i] = (double*)MallocOrDie(sizeof(double) * all_msa[i]->alen);
    for(j=0; j < all_msa[i]->alen; j++){
      all_entropy[i][j] = -1;
      conserved_cols[i][j] = NULL;
      nonconserved_cols[i][j] = NULL;
      if (isgap(all_msa[i]->rf[j])) continue;
      all_entropy[i][j] = get_entropy(all_col[i][j], all_msa[i]->nseq);
      entropy[k++] = all_entropy[i][j];
    }
  }
  qsort(entropy, k, sizeof(double), doublecmp);  
  
  conserved_thresh = 0.4;
  //conserved_thresh = entropy[(int) (k*conserved_quantile)];
  //nonconserved_thresh = entropy[(int) (k* nonconserved_quantile)];
  nonconserved_thresh = 0.6;
  
  int num_conserved_col=0;
  int num_nonconserved_col=0;
  int num_conserved_nt=0;
  int num_nonconserved_nt=0;

  for(i=0; i < msa_num; i++){
    for(j=0; j < all_msa[i]->alen; j++){            
      if (all_entropy[i][j] >= 0 && all_entropy[i][j] <= conserved_thresh){
	conserved_cols[i][j] = all_col[i][j];
	num_conserved_col++;
	num_conserved_nt += all_msa[i]->nseq;
	for(k=0; k < all_msa[i]->nseq; k++){
	  if (all_col[i][j][k] < Alphabet_size){
	    conserved_freq[all_col[i][j][k]]++;
	  }
	  else{
	    int l=0;
	    for(l=0; l <Alphabet_size; l++){
	      nonconserved_freq[l] += 0.25;
	    }
	  }
	}
      }
      else if (all_entropy[i][j] > nonconserved_thresh){
	nonconserved_cols[i][j] = all_col[i][j];
	num_nonconserved_col++;
	num_nonconserved_nt += all_msa[i]->nseq;
	for(k=0; k < all_msa[i]->nseq; k++){
	  if (all_col[i][j][k] < Alphabet_size){
	    nonconserved_freq[all_col[i][j][k]]++;
	  }
	  else{
	    int l=0;
	    for(l=0; l <Alphabet_size; l++){
	      nonconserved_freq[l] += 0.25;
	    }
	  }
	}
      }            
    }
  }  

  printf("conserved col %d threshold %f \t Freq ", num_conserved_col, conserved_thresh);
  for(i=0; i < Alphabet_size; i++){
    printf("%4f ", conserved_freq[i]/num_conserved_nt);
  }
  printf("\n");
  printf("nonconserved col %d treshold %f\t Freq ", num_nonconserved_col, nonconserved_thresh);
  for(i=0; i < Alphabet_size; i++){
    printf("%4f ", nonconserved_freq[i]/num_nonconserved_nt);
  }
  printf("\n");
}




int main(int argc, char* argv[])
{
  char* tree_file = argv[1];
  char* ali_files = argv[2];

  char  buffer[MAXLINE];

  root = ReadPhyFile(tree_file);
  leaves = PhyLeaves(root);  
  nleaves = root->size;
  int   format =  MSAFILE_STOCKHOLM;  
  FILE* file_list;
  if (!(file_list = fopen(ali_files, "r"))) Die("");
  MSAFILE     *afp = NULL;        /* file handle of initial alignment */
  MSA  *msa;
  int  i,j;

  while(fgets(buffer, MAXLINE, file_list)){    
    char filename[MAXLINE];
    sscanf(buffer, "%s", filename);
    if ((afp = MSAFileOpen(filename, format, NULL)) == NULL)
      Die("Alignment file %s could not be opened for reading",filename);
    if ((msa = MSAFileRead(afp)) != NULL){
      MSAFileClose(afp);
    }
    else{
      Die("Error reading Alignment file %s", filename);
    }    
    all_msa[msa_num]= msa;
    all_seqid2leave[msa_num] = (int*) MallocOrDie(sizeof(int) * msa->nseq);    
    for(i=0; i < msa->nseq; i++){
      all_seqid2leave[msa_num][i] = -1;
      for(j=0; j < nleaves; j++){      
	if ( strncmp(msa->sqname[i], leaves[j]->name, strlen(leaves[j]->name)) == 0 ){
	  all_seqid2leave[msa_num][i]=j;
	  break;
	}
      }
    }  
    all_col[msa_num] = (int**)MallocOrDie(sizeof(int*) * msa->alen);    
    memset(all_col[msa_num], 0, sizeof(int*) * msa->alen);    
    for(i=0; i < msa->alen;i++){
      all_col[msa_num][i] = NULL;
      if (isgap(msa->rf[i])) {
	continue;
      }
      all_col[msa_num][i] = (int*) MallocOrDie(sizeof(int) * msa->nseq);
      for(j=0; j < msa->nseq; j++){
	all_col[msa_num][i][j] = SymbolIndex(toupper(msa->aseq[j][i]));
	if (all_col[msa_num][i][j] > Alphabet_size){
	  all_col[msa_num][i][j] = Alphabet_size;
	}
      }
      all_col_num += msa->alen;
    }
    msa_num++;
  }
  
  mod = EvoModel_alloc(SINGLE, 1);
  mod->dim = Alphabet_size;
  gsl_vector* v= vector_init(init_single_params, SINGLE_PAM_NUM);
  /*
  for(i=0; i < nleaves; i++){
    leaves[i]->likelihood = gsl_vector_alloc(Alphabet_size);
  }
  for(i=0; i < msa_num; i++)
    for(j=0; j < all_msa[i]->alen; j++)      
      init_col_leaves(leaves, nleaves, all_seqid2leave[i], all_msa[i]->nseq, all_col[i][j]);
  */
  //collect_cols(0.3, 0.7);
  double loglikelihood;
  gsl_vector* ub = gsl_vector_alloc(SINGLE_PAM_NUM);
  gsl_vector* lb = gsl_vector_alloc(SINGLE_PAM_NUM);
  gsl_vector_set_all(ub, 0.5);
  gsl_vector_set_all(lb, 0);
  alloc_node(root, Alphabet_size+1);

  opt_bfgs(evaluate, v, all_col, &loglikelihood, lb, ub, stdout, NULL,  OPT_HIGH_PREC, NULL);    
  
  printf("params :\n");
  for(i=0; i < v->size; i++){
    printf("%.3f ", gsl_vector_get(v, i));	   
  }
  printf("\nfreq :\n");
  for(i=0; i < mod->freq->size; i++){
    printf("%.3f ", gsl_vector_get(mod->freq, i));	   
  }
  printf("\n");
  for(i=0; i < mod->rm->size1; i++){
    for(j=0; j < mod->rm->size2; j++){
	printf("%.3f ", gsl_matrix_get(mod->rm, i,j));	   	
    }
    printf("\n");
  }

  printf("Loglikelihood %f\n", loglikelihood); 
}



      
