#include "phytree.h"
#include "grammar.h"
#include "funcs.h"
#include <float.h>

#include "squid.h"		/* general sequence analysis library    */
#include "msa.h"                /* squid's multiple alignment i/o       */
#include "structs.h"		/* data structures, macros, #define's   */
#include "funcs.h"		/* external functions                   */
#include "version.h"            /* versioning info for Infernal         */
#include "prior.h"
#include "../cmfinder/global.h"

/*
int IsBasePair(char l, char r){
  if (l=='A' && r == 'U' ||
      l=='U' && r == 'A' ||
      l=='G' && r == 'C' ||
      l=='C' && r == 'G' ||
      l=='G' && r == 'U' ||
      l=='U' && r == 'G'){
    return 1;
  }
  return 0;
}
*/

static double my_log(double d)
{
  return d < exp(-10) ? -10: log(d);
}

/* mxy: mutual information (in bits) of the column pairs that are base-paired
 * in the consensus structure of an alignment.
 *
 * msa     - alignment; msa->ss_cons is read unconditionally, so it must be
 *           set before calling.
 *
 * Returns a freshly allocated msa->alen x msa->alen matrix.  All entries are
 * zero except mat[i][j] for columns i < j that GetPairtable() reports as
 * paired, which hold the mutual information of the two columns.  The caller
 * owns the matrix (rows and row-pointer array come from MallocOrDie).
 *
 * A position that is a gap contributes the uniform prior (0.25 per symbol)
 * instead of a count.  Calls Die() if a computed value is clearly negative.
 */
double** mxy(MSA* msa)
{
  int     i, j, idx;
  double  fx[Alphabet_size];        /* singlet frequency vector            */
  double  fy[Alphabet_size];        /* another singlet frequency vector    */
  double  fxy[Alphabet_size][Alphabet_size]; /* pairwise frequency 2D array */
  int     symi, symj;               /* counters for symbols                */
  double  mi;                       /* mutual information of one column pair */

  double  prior[4]={0.25, 0.25, 0.25, 0.25};
  double** mat= (double**)MallocOrDie(sizeof(double*) * msa->alen);
  for(i=0; i < msa->alen; i++){
    mat[i] = (double*)MallocOrDie(sizeof(double) * msa->alen);
    memset(mat[i], 0, sizeof(double) * msa->alen);
  }
  /* NOTE(review): dsq and pt are never released in this function; presumably
   * both are heap-allocated by their helpers -- confirm ownership before
   * adding the matching frees. */
  char** dsq = DigitizeAlignment(msa->aseq, msa->nseq, msa->alen);
  char*  cs = msa->ss_cons;
  int* pt = GetPairtable(cs);

  for (i = 0; i < msa->alen; i++) {
    if (pt[i] <= i) continue;       /* only visit each pair once, from its left column */
    j = pt[i];
    /* zero counter array */
    for (symj = 0; symj < Alphabet_size; symj++)
      {
        fx[symj] = fy[symj] = 0.0;
        for (symi = 0; symi < Alphabet_size; symi++)
          fxy[symj][symi] = 0.0;
      }
    /* count symbols in a column (digitized columns are offset by one) */
    for (idx = 0; idx < msa->nseq; idx++)
      {
        int ci = dsq[idx][i+1];
        int cj = dsq[idx][j+1];
        /* Any code that is not a plain alphabet index is handled like a gap.
         * Without this, such a code would index past the end of fx/fy/fxy;
         * main() clamps SymbolIndex() results in the same spirit, so codes
         * outside [0, Alphabet_size) are presumably possible here too. */
        int missing_i = (ci == DIGITAL_GAP || ci < 0 || ci >= Alphabet_size);
        int missing_j = (cj == DIGITAL_GAP || cj < 0 || cj >= Alphabet_size);

        if (missing_i && missing_j) {
          for (symi = 0; symi < Alphabet_size; symi++){
            for (symj = 0; symj < Alphabet_size; symj++)
              fxy[symj][symi] += prior[symj] * prior[symi] ;
            fx[symi] += prior[symi] ;
            fy[symi] += prior[symi] ;
          }
        }
        else if (missing_i) {
          symj = cj;
          fy[symj] += 1;
          for (symi = 0; symi < Alphabet_size; symi++){
            fxy[symj][symi] += prior[symi] ;
            fx[symi] += prior[symi] ;
          }
        }
        else if (missing_j) {
          symi = ci;
          fx[symi] += 1;
          for (symj = 0; symj < Alphabet_size; symj++){
            fxy[symj][symi] += prior[symj] ;
            fy[symj] += prior[symj] ;
          }
        }
        else {
          symi = ci;
          symj = cj;
          fx[symi] += 1;
          fxy[symj][symi] += 1;
          fy[symj] += 1;
        }
      }
    /* convert to frequencies */
    for (symi = 0; symi < Alphabet_size; symi++)
      {
        fx[symi] /=  msa->nseq;
        fy[symi] /=  msa->nseq;
        for (symj = 0; symj < Alphabet_size; symj++)
          fxy[symj][symi] /=  msa->nseq;
      }
    /* calculate the mutual information.  1.44269504 is 1/ln(2), i.e. the
     * conversion of natural logs into bits.
     */
    mi = 0;
    for (symi = 0; symi < Alphabet_size; symi++)
      for (symj = 0; symj < Alphabet_size; symj++)
        {
          if (fxy[symj][symi] > 0.0)
            mi +=  1.44269504 * fxy[symj][symi] *
              log((fxy[symj][symi] / (fx[symi] * fy[symj])));
        }
    /* small negative values are tolerated as rounding noise */
    if (mi < -0.00001){
      Die("Error ! Column %d  %d mxy = %f", i, j, mi);
    }
    mat[i][j] = mi;
  }
  return mat;
}



/* geo_mean_bppr: weighted average, over all sequences, of the log base-pair
 * probability of every pair of alignment columns (i.e. the log of a weighted
 * geometric mean).
 *
 * For 0-based columns i < j the result is stored at bppr[j+1][i+1].  Each
 * sequence k adds log(max(p, 0.05)) * weights[k] / sum(weights), where p is
 * bp_pr[k][TriIndex(i1,j1)] and i1, j1 are the idx_map entries produced by
 * remove_gap() for the two columns.  Column pairs for which either idx_map
 * entry is negative contribute nothing for that sequence.
 *
 * bp_pr may be NULL, or may contain NULL rows; missing rows are computed
 * with bppr_seq().  Rows computed into a caller-supplied table stay in that
 * table.  A sequence whose row is still NULL after that is skipped.
 *
 * NOTE(review): when bp_pr is NULL the table allocated here, and the rows
 * bppr_seq() puts into it, are not released; how bppr_seq() allocates is not
 * visible from this function, so confirm before adding frees.
 */
void geo_mean_bppr(char    **aseq,            /* array of aligned sequences, flushed right  */
		   int       nseq,		/* number of aligned sequences */
		   int       alen,		/* length of each sequence (all the same)  */
		   double  **bp_pr,             /* per-sequence pair probabilities, or NULL */
		   float*    weights,           /* per-sequence weights */
		   double ***ret_bppr)        /* RETURN: bppr array           */
{
  double **bppr;
  char   *nogap_seq;
  int    *idx_map;
  int    k,i,j, i1, j1;
  double tot_weight=0;
  double p;
  bppr = DoubleAlloc2DArray(alen+1);

  /* clear the part of the table below the diagonal, which is all that is used */
  for(j=1; j < alen+1; j++){
    for(i=0; i <j ; i++)
      bppr[j][i] = 0;
  }
  for(k=0;k< nseq; k++) tot_weight+= weights[k];

  if (bp_pr==NULL){
    bp_pr = (double**)MallocOrDie(sizeof(double*) * nseq);
    memset(bp_pr, 0, sizeof(double*) * nseq);
  }
  for (k = 0; k < nseq; k++) {
    nogap_seq = remove_gap(aseq[k], &idx_map);
    if (bp_pr[k]==NULL){
      bp_pr[k]= bppr_seq(nogap_seq);
    }
    /* A sequence without probabilities is skipped, but its temporaries must
     * still be released below (they used to leak on this path). */
    if (bp_pr[k] != NULL) {
      for(j = 1; j < alen; j++)
	for(i = 0; i < j; i++){
	  i1 = idx_map[i];
	  j1 = idx_map[j];
	  if (i1 >=0 && j1 >= 0){
	    p = bp_pr[k][TriIndex(i1,j1)];
	    if (p < 0.05) p = 0.05;     /* floor, so log(p) stays bounded */
	    bppr[j+1][i+1] +=  log(p) * weights[k]/tot_weight;
	  }
	}
    }
    free(nogap_seq);
    free(idx_map);
  }
  *ret_bppr = bppr;
}



/* trim_gap_cols: remove gap-rich columns from an alignment, in place.
 *
 * A column is kept when its number of gap characters is at most
 * (int)(gapthreshold * msa->nseq), or when msa->ss_cons exists and is not a
 * gap character at that column.  Kept columns are shifted left in aseq and,
 * where present, in the per-sequence ss lines, rf and ss_cons.  The strings
 * are re-terminated if anything was removed, and msa->alen is set to the
 * number of columns kept.
 */
void trim_gap_cols(MSA* msa, double gapthreshold)
{
  int col;                      /* column being examined          */
  int seq;                      /* sequence index                 */
  int kept = 0;                 /* columns retained so far        */

  for (col = 0; col < msa->alen; col++) {
    int ngaps = 0;

    for (seq = 0; seq < msa->nseq; seq++) {
      if (isgap(msa->aseq[seq][col])) {
        ngaps++;
      }
    }

    /* drop the column: too many gaps and no consensus structure to protect it */
    if (ngaps > (int)( gapthreshold * msa->nseq) &&
        !(msa->ss_cons && !isgap(msa->ss_cons[col]))) {
      continue;
    }

    /* nothing has been dropped yet while kept == col, so no copy is needed */
    if (kept != col) {
      for (seq = 0; seq < msa->nseq; seq++) {
        msa->aseq[seq][kept] = msa->aseq[seq][col];
        if (msa->ss && msa->ss[seq]) {
          msa->ss[seq][kept] = msa->ss[seq][col];
        }
      }
      if (msa->rf) {
        msa->rf[kept] = msa->rf[col];
      }
      if (msa->ss_cons) {
        msa->ss_cons[kept] = msa->ss_cons[col];
      }
    }
    kept++;
  }

  /* col equals the original alignment length here */
  if (kept != col) {
    for (seq = 0; seq < msa->nseq; seq++) {
      msa->aseq[seq][kept] = '\0';
      if (msa->ss && msa->ss[seq]) {
        msa->ss[seq][kept] = '\0';
      }
    }
    if (msa->rf) {
      msa->rf[kept] = '\0';
    }
    if (msa->ss_cons) {
      msa->ss_cons[kept] = '\0';
    }
  }
  msa->alen = kept;
}

/* Command-line options handed to Getopt().
 *
 * main() selects the grammar file with "-g", but only "-m" used to be
 * registered here, so "-g" could never reach main()'s option loop.  "-g" is
 * now registered; "-m" stays so command lines that already pass it are still
 * accepted.
 */
static struct opt_s OPTIONS[] = {
  { "-t", TRUE, sqdARG_STRING},           /* tree file                        */
  { "-s", TRUE, sqdARG_STRING},           /* single model file                */
  { "-p", TRUE, sqdARG_STRING},           /* pair model file                  */
  { "-g", TRUE, sqdARG_STRING},           /* grammar file                     */
  { "-m", TRUE, sqdARG_STRING},           /* kept for backward compatibility  */
  { "--partition", FALSE, sqdARG_NONE},   /* turn on the partition-based term */
  { "--informat", FALSE, sqdARG_STRING}   /* alignment file format            */
};
#define NOPTIONS (sizeof(OPTIONS) / sizeof(struct opt_s))

/* Usage banner passed to Getopt() */
static char usage[] = "\
Usage: posterior [options] <alignment> \n\
where options are:\n";
  

int main(int argc, char* argv[])
{
  char path[100];
  char tree_file[100];
  char single_file[100] ;
  char pair_file[100];
  char grammar_file[100];
  char ali_file[100];

  strcpy(path, getenv("Models"));
  strcat(path, "/data/");
  strcpy(tree_file, path);
  strcat(tree_file,"tree.newick");
  strcpy(single_file, path);
  strcat(single_file,"single_model");
  strcpy(pair_file, path);
  strcat(pair_file,"pair_model");
  strcpy(grammar_file, path);
  strcat(grammar_file,"scfg");
  
  char    *optname;                /* name of option found by Getopt()        */
  char    *optarg;                 /* argument found by Getopt()              */
  int     optind;                 /* index in argv[]                         */	

  int     use_partition =0;
  int   format =  MSAFILE_STOCKHOLM;

  while (Getopt(argc, argv, OPTIONS, NOPTIONS, usage,
                &optind, &optname, &optarg))  {
    if      (strcmp(optname, "-t") == 0) {
      strcpy(tree_file, optarg);
    }  
    else if (strcmp(optname, "-s")== 0){
      strcpy(single_file,optarg);
    }    
    else if (strcmp(optname, "-p")== 0){
      strcpy(pair_file,optarg);
    }    
    else if (strcmp(optname, "-g")== 0){
      strcpy(grammar_file, optarg);
    }    
    else if (strcmp(optname, "--partition")==0){
      use_partition = 1;
    }
    else if (strcmp(optname, "--informat")==0){
      format = String2SeqfileFormat(optarg);
      if (format == MSAFILE_UNKNOWN)
        Die("unrecognized sequence file format \"%s\"", optarg);
      if (! IsAlignmentFormat(format))
        Die("%s is an unaligned format, can't read as an alignment", optarg);
    }
  }  
  strcpy(ali_file,argv[argc-1]);

  MSAFILE     *afp = NULL;        /* file handle of initial alignment */
  MSA  *msa;
  int** cols;
  PhyNode* root;
  PhyNode** leaves;
  int   nleaves;
  int*  seqid2leave;
  int   i,j,k;
  Grammar* g;
  char* rf;
  
  g = read_grammar(grammar_file);

  if ((afp = MSAFileOpen(ali_file, format, NULL)) == NULL)
    Die("Alignment file %s could not be opened for reading", ali_file);
  if ((msa = MSAFileRead(afp)) != NULL){
    MSAFileClose(afp);
  }
  else{
    Die("Error reading Alignment file %s", ali_file);
  }

  for(k=0; k < msa->nseq; k++){
    if (msa->ss && msa->ss[k]){
      int* pt = GetPairtable(msa->ss[k]);      
      for (i = 0; i < msa->alen; i++) {
	if (pt[i] <= i) continue;
	j = pt[i];
	if (!IsBasePair(msa->aseq[k][i], msa->aseq[k][j])){
	  msa->ss[k][i]='.';
	  msa->ss[k][j]='.';
	}
      }
    }
  }
  
  root = ReadPhyFile(tree_file);
  leaves = PhyLeaves(root);  
  nleaves = root->size;
  seqid2leave = (int*)MallocOrDie(sizeof(int) * msa->nseq);  
  for(i=0; i < msa->nseq; i++){
    seqid2leave[i] = -1;
    for(j=0; j < nleaves; j++){      
      if ( strncmp(msa->sqname[i], leaves[j]->name,strlen(leaves[j]->name)) == 0 ){
	seqid2leave[i]=j;
	break;
      }
    }
  }  
  
  trim_gap_cols(msa, 0.7);
  cols = (int**) MallocOrDie(sizeof(int*) * msa->alen);
  for(i=0; i < msa->alen;i++){
    cols[i] = (int*) MallocOrDie(sizeof(int) * msa->nseq);
    for(j=0; j < msa->nseq; j++){
      int c= SymbolIndex(msa->aseq[j][i]);
      if (c > Alphabet_size) c= Alphabet_size;
      cols[i][j] = c < Alphabet_size ? c: Alphabet_size ;
    }
  }

  
  single_model_num = EvoModel_read(single_file, single_model);
  pair_model_num = EvoModel_read(pair_file, pair_model);


  for(i=0; i < single_model_num ; i++){
    single_emission[i]= (double*)MallocOrDie(sizeof(double)*msa->alen);  
  }
  for(i=0; i < pair_model_num; i++){
    pair_emission[i]= (double**)MallocOrDie(sizeof(double*)*msa->alen);  
    for(j=0; j < msa->alen; j++){
      pair_emission[i][j] = malloc(sizeof(double)*msa->alen);  
    }
  }    

  //Conserved model
  alloc_node(root, Alphabet_size + 1);
  for(k = 0; k < single_model_num; k++){
    EvoModel_init_single_gap(single_model[k]);
    init_node(root, single_model[k],1);
    for(i=0; i < msa->alen; i++){      
      clear_node(root);
      init_leaves_single(leaves, nleaves, seqid2leave, msa->nseq, cols[i]);      
      calculate_all_likelihood(leaves, nleaves);
      single_emission[k][i]=root->loglikelihood;
    }
  }
  
  //NULL model
  //single_emission[single_model_num]= (double*)MallocOrDie(sizeof(double)*msa->alen);  
  for(i=0; i < msa->alen;i++){
    double l = null_loglikelihood(msa->nseq, cols[i], single_model[0]->freq);
    if (l > single_emission[single_model_num-1][i]) single_emission[single_model_num-1][i] = l;
  }   
  for(k=0; k < single_model_num; k++){
    EvoModel_free(single_model[k]);  
  }
  //single_model_num++;  


  //Pair model
  alloc_node(root, (Alphabet_size+1) * (Alphabet_size+1));	     

  for(k = 0; k < pair_model_num; k++){
    EvoModel_init_pair_gap(pair_model[k]);
    init_node(root, pair_model[k], 1);
    for(i=0; i < msa->alen; i++){
      for(j=i+1; j < msa->alen; j++){
	clear_node(root);      
	init_leaves_pair(leaves, nleaves, seqid2leave, msa->nseq, cols[i],cols[j]);      
	calculate_all_likelihood(leaves, nleaves);
	pair_emission[k][i][j]=root->loglikelihood;	
      }  
    }
  }
  for(k=0; k < pair_model_num; k++){
    EvoModel_free(pair_model[k]);    
  }  
  FreePhytree(root);

  double** logpxy;
  if (use_partition){
    geo_mean_bppr(msa->aseq,msa->nseq, msa->alen, NULL, msa->wgt, &logpxy);
    for(k = 0; k < pair_model_num; k++){
      for(i=0; i < msa->alen; i++){
	for(j=i+1; j < msa->alen; j++){
	  pair_emission[k][i][j] += logpxy[j+1][i+1];
	}
      }
    }
  }
  
  Cube_table* trans_inside_sc = init_Cube_table(0, msa->alen, g->ntransitions, 0);
  Cube_table* inside_sc = init_Cube_table(0, msa->alen, g->nstates, 0);
  Cube_table* outside_sc = init_Cube_table(0, msa->alen, g->nstates, 0);
  inside(0, msa->alen-1, g, inside_sc,trans_inside_sc);
  double total_prob = Cube_table_get(inside_sc, 0, msa->alen-1, g->start);

  outside(0, msa->alen-1,g, outside_sc, inside_sc);  

  double* single_posterior[MAX_MODEL];
  double** pair_posterior[MAX_MODEL];
  printf("Single model num %d\n", single_model_num);
  printf("Pair model num %d\n", pair_model_num);
  
  for(k=0; k < single_model_num ;k++){
    single_posterior[k]= malloc(sizeof(double)*msa->alen);  
  }
  for(k=0; k < pair_model_num;k++){
    pair_posterior[k]= malloc(sizeof(double*)*msa->alen);  
    for(i=0; i < msa->alen; i++){
      pair_posterior[k][i] = malloc(sizeof(double)*msa->alen);  
    }
  }
  Cube_table* trans_posterior = init_Cube_table(0, msa->alen, g->ntransitions, 0);
  Cube_table* state_posterior = init_Cube_table(0, msa->alen, g->nstates, 0);
  posterior(0, msa->alen-1,g,outside_sc, inside_sc, trans_inside_sc, 
	    trans_posterior, state_posterior, single_posterior, pair_posterior);  

  if (!msa->ss_cons){
    msa->ss_cons = (char*) MallocOrDie(sizeof(char) * (msa->alen + 1));
    Cube_table* cyk_table = init_Cube_table(0, msa->alen, g->nstates, 1);  
    double cyk_return=cyk(0, msa->alen-1, g, cyk_table);  
    printf("cyk_return %f\n", cyk_return);
    cyk_traceback(0, msa->alen-1, g, cyk_table, NULL, msa->ss_cons);     
    free_Cube_table(cyk_table);
    for(i=0; i < msa->alen; i++){
      if (!pair_left(msa->ss_cons[i]) && !pair_right(msa->ss_cons[i])){
	msa->ss_cons[i] = '.';
      }
    }
  }
  int* pair_table = GetPairtable(msa->ss_cons);
  double** m = mxy(msa);
  
  printf("Emit:\n");
  double tot_posterior = 0;
  double posterior_loglik = - log( Cube_table_get(trans_posterior, 0, msa->alen -1 , 0));

  for(i=0; i < msa->alen; i++){    
    for(j=i+1; j < msa->alen; j++){
      if (pair_posterior[0][i][j] > 0.1 || pair_table[i]==j){
	double s1 = single_emission[0][i] + single_emission[0][j];
	double s2 = single_emission[1][i] + single_emission[1][j];
	double d = (s1 < s2) ? pair_emission[0][i][j] - s2 : pair_emission[0][i][j] - s1;	
	printf("%d (%.2f emit %.2f :  %.2f emit %.2f ) - %d (%.2f emit %.2f : %.2f emit %.2f) - %.2f emit %.2f diff %.2f", 
	       i, single_posterior[0][i], single_emission[0][i], single_posterior[1][i],single_emission[1][i],	     
	       j, single_posterior[0][j], single_emission[0][j], single_posterior[1][j],single_emission[1][j],	     
	       pair_posterior[0][i][j], pair_emission[0][i][j], d);
	if (use_partition){
	  printf(" pxy %f",  exp(logpxy[j+1][i+1]));	  
	}
	printf("\n");
	if (pair_table[i] == j ){
	  tot_posterior +=  - my_log(1-pair_posterior[0][i][j]);
	  printf("Predicted");
	}
	printf("\n");
      }
    }
  }


  rf = malloc(sizeof(char) * (msa->alen+1));
  for(i=0; i < msa->alen  ; i++){
    rf[i]='.';
  }
  for(i=0; i < msa->alen ; i++){
    j = pair_table[i];
    if (j <= i) {
      if (j == -1) rf[i] = '0' + (char) ((single_posterior[0][i] - 0.001) * 10);      
      continue;
    }
    rf[i]='0' + (char) ((pair_posterior[0][i][j]-0.001)*10);
    rf[j]=rf[i];
  }
  rf[msa->alen]='\0';
  msa->rf = rf;
  
  WriteStockholm(stdout, msa);
  printf("Total pair posterior %.2f\n", tot_posterior);
  printf("Total RNA posterior %.2f\n", posterior_loglik);


  for(k=0; k < single_model_num; k++){
    free(single_emission[k]);
    free(single_posterior[k]);
  }
  for(k=0; k < pair_model_num; k++){
    for(i=0; i < msa->alen; i++){
      free(pair_emission[k][i]);
      free(pair_posterior[k][i]);
    }
    free(pair_emission[k]);
    free(pair_posterior[k]);
  }
  
  free_grammar(g);
  free_Cube_table(inside_sc);
  free_Cube_table(trans_inside_sc);
  free_Cube_table(outside_sc);
  MSAFree(msa); 
}



      
