/* extend_motif.c
 * Zizhen Yao
 *
 * CVS $Id: extend_motif.c,v 3.1 2006/03/07 19:38:36 yzizhen Exp $.
 *
 * Adjust the motif boundary. The user can specify the amount of adjustment,
 * i.e. the number of columns by which to extend or shrink the motif. If no
 * parameters are specified, the program tries to guess the new motif boundary
 * by first extending, to see whether the neighboring regions are conserved.
 * If they are not, it tries to shrink the motif boundary instead:
 * it first removes the weak base pairs, then trims the unconserved unpaired
 * columns at the boundary.
 */  
  
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>


#include "squid.h"		/* general sequence analysis library    */
#include "msa.h"                /* squid's multiple alignment i/o       */
#include "structs.h"		/* data structures, macros, #define's   */
#include "funcs.h"		/* external functions                   */
#include "version.h"            /* versioning info for Infernal         */

#ifdef MEMDEBUG
#include "dbmalloc.h"
#endif

#define NEGINFINITY  -9999999
#define START 0
#define END   1
#define EXPAND 0
#define SHRINK 1


double prior[]={0.25, 0.25, 0.25,0.25};
extern void bppr(char **aseq, int nseq, int alen, float* weights, double ***ret_pr);


/* Location of one motif instance inside its source sequence.
 *   seq_id : index of the source sequence in the sqinfo/rseqs arrays
 *            (-1 while unresolved, see parse_motif_coor()).
 *   start  : first motif residue, 1-based position in the source sequence.
 *   end    : last motif residue, 1-based position in the source sequence.
 */
typedef struct motif_coor_s {
  int seq_id;
  int start;
  int end;
} MotifCoor;

/* True when c is an opening / closing bracket of a base pair in a
 * secondary-structure string. The argument is parenthesized so that any
 * expression (e.g. a ternary) can be passed safely; note that it is still
 * evaluated up to three times, so it must not have side effects. */
#define pair_left(c)  ((c) == '(' || (c) == '<' || (c) == '{')
#define pair_right(c) ((c) == ')' || (c) == '>' || (c) == '}')


/* Count the non-gap characters of seq in the half-open column range
 * [start, end). Returns 0 when the range is empty. */
int count_bases(char* seq, int start, int end)
{
  int n_bases = 0;
  int col     = start;

  while (col < end) {
    if (!isgap(seq[col])) {
      n_bases++;
    }
    col++;
  }
  return n_bases;
}


/* Build a pairing table for the bracket-notation structure string ss.
 *
 * ss  : secondary structure string; '(' '<' '{' open a pair and
 *       ')' '>' '}' close the most recently opened one.
 * len : number of columns of ss to examine.
 *
 * Returns a newly allocated array pt of len ints where pt[i] is the column
 * paired with column i, or -1 if column i is unpaired. The caller owns the
 * array and releases it with free().
 *
 * Dies on a closing bracket without a matching opening bracket, or on
 * allocation failure. Opening brackets that are never closed are left
 * marked as unpaired (-1).
 */
int* pair_table(char* ss, int len)
{
  int    i, j;
  int    sp = 0;
  /* Never ask malloc() for zero bytes: its result is implementation-defined. */
  size_t n_slots = len > 0 ? (size_t) len : 1;
  int*   stack = malloc(sizeof(int) * n_slots);
  int*   pt    = malloc(sizeof(int) * n_slots);

  if (stack == NULL || pt == NULL) {
    Die("Memory allocation failed while building the pair table");
  }

  for (i = 0; i < len; i++) {
    pt[i] = -1;
  }

  for (i = 0; i < len; i++) {
    if (pair_left(ss[i])) {
      stack[sp++] = i;
    }
    if (pair_right(ss[i])) {
      --sp;
      if (sp < 0) {
        Die("unbalanced base pair at pos %d", i);
      }
      j = stack[sp];
      pt[j] = i;
      pt[i] = j;
    }
  }

  /* The stack is scratch space only; it used to be leaked on every call. */
  free(stack);
  return pt;
}


/* Resolve, for every aligned sequence in msa, where its motif instance lies
 * in the corresponding full-length sequence.
 *
 * msa    : alignment; each sequence description must begin with
 *          "<start>..<end>" and each sequence name must match one of the
 *          names in sqinfo.
 * sqinfo : information records of the full-length sequences.
 * nseq   : number of entries in sqinfo.
 *
 * Returns a newly allocated array of msa->nseq MotifCoor records, owned by
 * the caller. Dies if a description does not carry both coordinates or if a
 * sequence name is not found in sqinfo.
 */
MotifCoor* parse_motif_coor(MSA* msa, SQINFO* sqinfo, int nseq)
{
  MotifCoor* coors = MallocOrDie(sizeof(MotifCoor) * msa->nseq);
  int i, j;

  for (i = 0; i < msa->nseq; i++) {
    /* Must be an assignment: the lookup below relies on -1 meaning "not found". */
    coors[i].seq_id = -1;

    /* Both numbers are required; sscanf() returns the number of fields
     * converted (or EOF), so anything other than 2 is a parse failure. */
    if (msa->sqdesc == NULL || msa->sqdesc[i] == NULL ||
        sscanf(msa->sqdesc[i], "%d..%d", &coors[i].start, &coors[i].end) != 2) {
      Die("No valid coordinates are found for motif %i %s", i, msa->sqname[i]);
    }

    for (j = 0; j < nseq; j++) {
      if (strcmp(msa->sqname[i], sqinfo[j].name) == 0) {
        coors[i].seq_id = j;
        break;
      }
    }
    if (coors[i].seq_id == -1) {
      Die("No valid sequence are found for motif  %i %s", i, msa->sqname[i]);
    }
  }
  return coors;
}

/* Fill dest with exactly dest_len characters followed by a NUL terminator:
 * the first src_len characters come from src, the remainder is pad_c.
 *
 * dest     : output buffer with room for at least dest_len + 1 bytes.
 * dest_len : number of characters to produce (excluding the terminator).
 * src      : source characters, or NULL to produce padding only.
 * src_len  : number of characters to take from src; values <= 0 (or a NULL
 *            src) mean "padding only".
 * pad_c    : padding character.
 *
 * Dies if src_len exceeds dest_len.
 */
void strncpy_pad(char* dest, int dest_len, char* src, int src_len, char pad_c)
{
  int i;
  int copied = 0;

  /* Report the lengths only: dest is typically uninitialized at this point
   * and src may be NULL, so neither can be printed with %s. */
  if (dest_len < src_len) {
    Die("Dest string (length %d) can not be shorter than the Src string (length %d)\n",
        dest_len, src_len);
  }

  if (src_len > 0 && src) {
    strncpy(dest, src, src_len);
    copied = src_len;
  }

  /* Start padding right after what was actually copied, so that no byte in
   * [0, dest_len) is left unwritten and no negative index is ever used. */
  for (i = copied; i < dest_len; i++) {
    dest[i] = pad_c;
  }
  dest[dest_len] = '\0';
}

/* Conservation score of one (virtual) alignment column taken from the
 * full-length sequences.
 *
 * nseq   : number of motif instances.
 * dsq    : digitized full-length sequences, addressed with positions
 *          1..sqinfo[].len.
 * pos    : pos[idx] is the position examined in the sequence of instance
 *          idx; positions outside 1..len are treated as absent.
 * coors  : motif records; coors[idx].seq_id selects the sequence in dsq and
 *          sqinfo.
 * sqinfo : per-sequence information (only len is read).
 * weight : per-instance weights.
 *
 * Returns the information content of the column in bits relative to the
 * uniform `prior`, scaled by the fraction of the total weight that has a
 * residue in the column. Returns -1 when no instance contributes a residue.
 */
double conserve_score(int nseq, char **dsq, int* pos, MotifCoor* coors, SQINFO *sqinfo, float* weight)
{
  int    idx;
  int    sym;
  double e, tot_weight, col_weight;
  double f[Alphabet_size];        /* singlet frequency vector            */
  double max_freq = 0;

  tot_weight = 0;
  col_weight = 0;
  for (sym = 0; sym < Alphabet_size; sym++)
    f[sym] = 0;

  for (idx = 0; idx < nseq; idx++) {
    int seq_id = coors[idx].seq_id;
    tot_weight += weight[idx];
    if (pos[idx] <= 0 || sqinfo[seq_id].len < pos[idx]) continue;
    sym = dsq[seq_id][pos[idx]];
    if (sym == DIGITAL_GAP) continue;
    /* f has only Alphabet_size slots; any other code (presumably ambiguity
     * codes -- confirm against the digitizer) is skipped instead of being
     * written out of bounds. */
    if (sym < 0 || sym >= Alphabet_size) continue;
    f[sym] += weight[idx];
    col_weight += weight[idx];
  }

  /* Nothing observed in this column: bail out before dividing by zero. */
  if (col_weight <= 0) return -1;

  e = 0;
  for (sym = 0; sym < Alphabet_size; sym++) {
    f[sym] /= col_weight;
    e -= prior[sym] * log(prior[sym]);
    if (f[sym] > 0) {
      e += f[sym] * log(f[sym]);
      if (f[sym] > max_freq) max_freq = f[sym];
    }
  }
  e *= 1.44269504;                /* nats -> bits (1 / ln 2) */

  if (max_freq == 0) return -1;
  return e * col_weight / tot_weight;
}

/* Conservation score of column pos of an existing alignment.
 *
 * msa     : alignment (nseq and the per-sequence weights wgt are read).
 * msa_dsq : digitized alignment rows, indexed [sequence][column].
 * pos     : column to score, in the indexing used by msa_dsq.
 *
 * Returns the information content of the column in bits relative to the
 * uniform `prior`, scaled by the fraction of the total weight that has a
 * residue in the column. Returns -1 when the column holds no residue.
 */
double alignment_conserve_score(MSA* msa, char **msa_dsq, int pos)
{
  int    idx;
  int    sym;
  double e, tot_weight, col_weight;
  double f[Alphabet_size];        /* singlet frequency vector            */
  double max_freq = 0;

  tot_weight = 0;
  col_weight = 0;
  for (sym = 0; sym < Alphabet_size; sym++)
    f[sym] = 0;

  for (idx = 0; idx < msa->nseq; idx++) {
    tot_weight += msa->wgt[idx];
    sym = msa_dsq[idx][pos];
    if (sym == DIGITAL_GAP) continue;
    /* f has only Alphabet_size slots; any other code (presumably ambiguity
     * codes -- confirm against the digitizer) is skipped instead of being
     * written out of bounds. */
    if (sym < 0 || sym >= Alphabet_size) continue;
    f[sym] += msa->wgt[idx];
    col_weight += msa->wgt[idx];
  }

  /* All-gap column: bail out before dividing by zero. */
  if (col_weight <= 0) return -1;

  e = 0;
  for (sym = 0; sym < Alphabet_size; sym++) {
    f[sym] /= col_weight;
    e -= prior[sym] * log(prior[sym]);
    if (f[sym] > 0) {
      e += f[sym] * log(f[sym]);
      if (f[sym] > max_freq) max_freq = f[sym];
    }
  }
  e *= 1.44269504;                /* nats -> bits (1 / ln 2) */

  if (max_freq == 0) return -1;
  return e * col_weight / tot_weight;
}

/* Decide how many columns the motif should grow by on one side.
 *
 * Walks outward from the motif one position at a time (after the end when
 * direction == END, before the start otherwise), scores each position with
 * conserve_score() and accumulates (score - threshold). The extension that
 * maximizes this running sum is returned. The walk stops when no sequence
 * has a residue left at the position, or when more than max_gap positions
 * past the best extension have been visited and the current position scores
 * below the threshold.
 *
 * Returns the number of columns to extend by (0 if extending does not help).
 */
int extend_motif(MSA* msa, MotifCoor* coors, char** dsq,SQINFO* sqinfo, int direction)
{
  const int threshold = 1;       /* minimum per-column score worth keeping */
  const int max_gap   = 10;      /* tolerated run past the best extension  */
  int       best_ext  = 0;
  double    running   = 0;
  double    best      = 0;
  int      *col       = MallocOrDie(sizeof(int) * msa->nseq);
  int       step;
  int       s;

  for (step = 1; ; step++) {
    double col_score;

    for (s = 0; s < msa->nseq; s++) {
      col[s] = (direction == END) ? coors[s].end + step
                                  : coors[s].start - step;
    }

    col_score = conserve_score(msa->nseq, dsq, col, coors, sqinfo, msa->wgt);
    if (col_score < 0) break;

    running += col_score - threshold;
    if (running > best) {
      best     = running;
      best_ext = step;
    }
    if (step > best_ext + max_gap && col_score < threshold) break;
  }

  free(col);
  return best_ext;
}


/* Propose how far the motif should shrink on each side.
 *
 * Step 1: starting from both ends, base pairs of the consensus structure
 *         whose probability (from bppr()) is below min_bppr are erased from
 *         msa->ss_cons (and from the per-sequence structure lines when they
 *         are present), until the first sufficiently strong pair is met.
 * Step 2: unpaired boundary columns are scanned inward. A side is fixed when
 *         a consensus base pair is reached, or when more than block_size
 *         consecutive columns score above min_score; in the latter case the
 *         boundary is placed at the start of that conserved block.
 *
 * msa              : alignment; ss_cons (and ss) are modified in place.
 * ret_left_offset  : receives minus the number of columns to drop on the left.
 * ret_right_offset : receives minus the number of columns to drop on the right.
 */
void shrink_motif(MSA* msa, int* ret_left_offset, int* ret_right_offset)
{
  int i, j, k;
  int left_block_start, right_block_start;
  char   **msa_dsq;
  double **bp_pr = NULL;
  int min_score = 1;
  int block_size = 3;
  double min_bppr = 0.3;
  int* pt = pair_table(msa->ss_cons, msa->alen);
  int left_fixed = 0;
  int right_fixed = 0;
  int has_ss = (msa->ss != NULL);

  i = 0;
  j = msa->alen-1;

  bppr(msa->aseq, msa->nseq, msa->alen, msa->wgt, &bp_pr);

  /* Remove weak base pairs at the left boundary. bp_pr is addressed as
   * [right column + 1][left column + 1]. */
  while (i < j) {
    if (i > pt[i]) { i++; continue; }     /* unpaired, or the closing half */
    if (bp_pr[pt[i]+1][i+1] < min_bppr) {
      msa->ss_cons[i] = '.';
      msa->ss_cons[pt[i]] = '.';
      for (k = 0; has_ss && k < msa->nseq; k++) {
        if (msa->ss[k] == NULL) continue;
        msa->ss[k][i] = '.';
        msa->ss[k][pt[i]] = '.';
      }
      pt[pt[i]] = -1;
      pt[i] = -1;
      i++;
    }
    else break;
  }

  /* Same from the right boundary. */
  while (i < j) {
    if (pt[j] == -1 || j < pt[j]) { j--; continue; }
    if (bp_pr[j+1][pt[j]+1] < min_bppr) {
      msa->ss_cons[j] = '.';
      msa->ss_cons[pt[j]] = '.';
      for (k = 0; has_ss && k < msa->nseq; k++) {
        if (msa->ss[k] == NULL) continue;
        msa->ss[k][j] = '.';
        msa->ss[k][pt[j]] = '.';
      }
      j--;
    }
    else break;
  }

  i = 0;
  j = msa->alen-1;

  msa_dsq = DigitizeAlignment(msa->aseq, msa->nseq, msa->alen);
  left_block_start = -1;
  right_block_start = msa->alen;
  while (i < j && (!left_fixed || !right_fixed)) {
    if (!left_fixed) {
      if (pair_left(msa->ss_cons[i])) {
        left_fixed = 1;
        continue;
      }
      double score = alignment_conserve_score(msa, msa_dsq, i);
      if (score > min_score) {
        if (i - left_block_start > block_size) {
          left_fixed = 1;
          i = left_block_start + 1;
          continue;
        }
      }
      else {
        left_block_start = i;
      }
      i++;
    }

    if (!right_fixed) {
      if (pair_right(msa->ss_cons[j])) {
        right_fixed = 1;
        continue;
      }
      double score = alignment_conserve_score(msa, msa_dsq, j);
      if (score > min_score) {
        if (right_block_start - j > block_size) {
          right_fixed = 1;
          j = right_block_start - 1;
          /* Mirror of the left side: skip the j-- below, otherwise the first
           * column of the conserved block would be trimmed as well. */
          continue;
        }
      }
      else {
        right_block_start = j;
      }
      j--;
    }
  }
  *ret_left_offset = - i;
  *ret_right_offset = j+1 - msa->alen;

  /* pt comes from malloc() in pair_table() and is not needed any more.
   * NOTE(review): msa_dsq and bp_pr are not released here because their
   * allocators are not visible in this file -- confirm how to free them. */
  free(pt);
}


/* Build the resized version of one alignment row.
 *
 * old_row   : current row (old_len columns expected), or NULL.
 * ext_start : columns added (> 0) or removed (< 0) on the left.
 * ext_end   : columns added (> 0) or removed (< 0) on the right.
 * left_src  / left_n  : characters placed at the very beginning of a left
 *                       extension; the rest of the extension is '.'.
 * right_src / right_n : characters placed at the beginning of a right
 *                       extension; the rest of the extension is '.'.
 *
 * Returns a freshly allocated, NUL-terminated row of exactly
 * old_len + ext_start + ext_end columns and frees old_row. A NULL old_row
 * (annotation not present) is returned unchanged as NULL.
 */
static char *
rebuild_row(char *old_row, int old_len, int ext_start, int ext_end,
            const char *left_src, int left_n,
            const char *right_src, int right_n)
{
  int   lead  = ext_start > 0 ?  ext_start : 0;   /* columns added on the left    */
  int   trail = ext_end   > 0 ?  ext_end   : 0;   /* columns added on the right   */
  int   skip  = ext_start < 0 ? -ext_start : 0;   /* columns removed on the left  */
  int   drop  = ext_end   < 0 ? -ext_end   : 0;   /* columns removed on the right */
  int   kept  = old_len - skip - drop;            /* columns carried over         */
  int   have;
  char *row;

  if (old_row == NULL) return NULL;

  if (left_src  == NULL || left_n  < 0) left_n  = 0;
  if (right_src == NULL || right_n < 0) right_n = 0;
  if (left_n  > lead)  left_n  = lead;
  if (right_n > trail) right_n = trail;

  row = MallocOrDie(sizeof(char) * (lead + kept + trail + 1));

  /* Left extension: fetched residues first, then gap padding. */
  if (left_n > 0) memcpy(row, left_src, left_n);
  memset(row + left_n, '.', lead - left_n);

  /* Carried-over columns; a row shorter than expected is padded with gaps
   * rather than read past its terminator. */
  have = (int) strlen(old_row) - skip;
  if (have < 0)    have = 0;
  if (have > kept) have = kept;
  if (have > 0) memcpy(row + lead, old_row + skip, have);
  memset(row + lead + have, '.', kept - have);

  /* Right extension: fetched residues first, then gap padding. */
  if (right_n > 0) memcpy(row + lead + kept, right_src, right_n);
  memset(row + lead + kept + right_n, '.', trail - right_n);

  row[lead + kept + trail] = '\0';
  free(old_row);
  return row;
}

/* Apply the boundary change to the alignment.
 *
 * Every aligned sequence is extended with residues fetched from its
 * full-length sequence (padded with '.' where the sequence runs out) or is
 * trimmed; the structure and reference annotation lines are extended with
 * '.' or trimmed accordingly. msa->alen, the motif coordinates in coors and
 * the "<start>..<end>" sequence descriptions are updated to match.
 *
 * msa             : alignment, modified in place.
 * coors           : motif coordinates, modified in place.
 * rseqs           : full-length sequences, indexed by coors[].seq_id.
 * sqinfo          : information on the full-length sequences (len is read).
 * nseq            : number of full-length sequences (unused).
 * motif_ext_start : columns to add (> 0) or remove (< 0) on the left.
 * motif_ext_end   : columns to add (> 0) or remove (< 0) on the right.
 *
 * Does nothing when both amounts are zero. Exits with an error if more
 * columns would be removed than the alignment has.
 */
void fetch_sequence(MSA* msa, MotifCoor* coors, char** rseqs,SQINFO* sqinfo, int nseq, int motif_ext_start, int motif_ext_end)
{
  int i;
  int old_len = msa->alen;
  int skip = motif_ext_start < 0 ? -motif_ext_start : 0;
  int drop = motif_ext_end   < 0 ? -motif_ext_end   : 0;

  if (motif_ext_start == 0 && motif_ext_end == 0) return;

  if (skip + drop > old_len) {
    fprintf(stderr, "Invalid length for shrinkage\n");
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "Old alignment length %d\n", msa->alen);
  msa->alen += motif_ext_start + motif_ext_end;
  fprintf(stderr, "New alignment length %d\n", msa->alen);

  for (i = 0; i < msa->nseq; i++) {
    int seq_id = coors[i].seq_id;
    /* When shrinking, the coordinates move by the number of residues (not
     * columns) that are cut away; these are counted on the old row. */
    int left_offset = motif_ext_start > 0 ? motif_ext_start
      : - count_bases(msa->aseq[i], 0, -motif_ext_start);
    int right_offset = motif_ext_end > 0 ? motif_ext_end
      : - count_bases(msa->aseq[i], old_len + motif_ext_end, old_len);
    /* Clamp the new coordinates to the extent of the source sequence. */
    int new_motif_start = coors[i].start - left_offset > 0 ?
      coors[i].start - left_offset : 1;
    int new_motif_end = coors[i].end + right_offset <= sqinfo[seq_id].len ?
      coors[i].end + right_offset : sqinfo[seq_id].len;
    char *left_src  = NULL;
    char *right_src = NULL;
    int   left_n    = 0;
    int   right_n   = 0;

    if (motif_ext_start > 0) {
      left_src = rseqs[seq_id] + new_motif_start - 1;
      left_n   = coors[i].start - new_motif_start;
    }
    if (motif_ext_end > 0) {
      right_src = rseqs[seq_id] + coors[i].end;
      right_n   = new_motif_end - coors[i].end;
    }

    msa->aseq[i] = rebuild_row(msa->aseq[i], old_len, motif_ext_start, motif_ext_end,
                               left_src, left_n, right_src, right_n);
    if (msa->ss != NULL) {
      msa->ss[i] = rebuild_row(msa->ss[i], old_len, motif_ext_start, motif_ext_end,
                               NULL, 0, NULL, 0);
    }

    coors[i].start = new_motif_start;
    coors[i].end = new_motif_end;

    /* The new coordinates may print longer than the old description, so the
     * description is re-allocated instead of being overwritten in place. */
    if (msa->sqdesc != NULL) {
      char   coor_text[32];
      size_t need;
      char  *desc;

      snprintf(coor_text, sizeof coor_text, "%d..%d", coors[i].start, coors[i].end);
      need = strlen(coor_text) + 1;
      desc = MallocOrDie(need);
      memcpy(desc, coor_text, need);
      free(msa->sqdesc[i]);
      msa->sqdesc[i] = desc;
    }
  }

  msa->rf      = rebuild_row(msa->rf,      old_len, motif_ext_start, motif_ext_end,
                             NULL, 0, NULL, 0);
  msa->ss_cons = rebuild_row(msa->ss_cons, old_len, motif_ext_start, motif_ext_end,
                             NULL, 0, NULL, 0);
}

/* Command-line options accepted by main(), in the table form passed to
 * Getopt(): -l and -r take an integer argument, -h takes none.
 * NOTE(review): the second field is presumably squid's "single-character
 * option" flag -- confirm against the definition of struct opt_s. */
static struct opt_s OPTIONS[] = {
  { "-l", TRUE, sqdARG_INT}, 
  { "-r", TRUE, sqdARG_INT},  
  { "-h", TRUE, sqdARG_NONE},
};


/* Help text printed for -h and when too few arguments are given. */
static char usage[]  = "\
Usage: extend_motif [-options] <motif_file> <sequence_file> \n\
where options are:\n\
     -l <num>: the number of bases to extend (positive values) or shrink (negative values) to the left \n\
     -r <num>: the number of bases to extend (positive values) or shrink (negative values) to the right\n\
               If neither -l nor -r is provided, the program will predict the new motif boundary automatically \n\
     -h      : print short help and version info\n\
";

/* Number of entries in the OPTIONS table. */
#define NOPTIONS (sizeof(OPTIONS) / sizeof(struct opt_s))

/* Program entry point.
 *
 * Reads a Stockholm motif alignment and the FASTA file holding the
 * full-length sequences, determines how far to move each motif boundary
 * (from -l / -r, or automatically: first try to extend, and if neither side
 * can be extended try to shrink), applies the change and writes the
 * resulting alignment to stdout in Stockholm format.
 */
int
main(int argc, char **argv)
{
  int        format;              /* alifile format                            */
  char      *seqfile=NULL;        /* training sequence file                    */
  char      *alifile=NULL;        /* file contain the initial alignment of selected cand */

  char      **rseqs;              /* training sequences                        */
  char      **dsq;                /* Digitized training sequences              */
  SQINFO    *sqinfo;              /* array of sqinfo structures for rseqs      */
  int       nseq;                 /* number of seqs */
  MSA       *msa;
  MSAFILE   *afp = NULL;          /* file handle of initial alignment          */
  MotifCoor *coors;               /* motif coordinates relative to the sequence */

  char  *optname;                /* name of option found by Getopt()        */
  char  *optarg;                 /* argument found by Getopt()              */
  int    optind;                 /* index in argv[]                         */

  int    i;
  int    motif_ext_start=0, motif_ext_end=0;


  /*Parse command line */
  while (Getopt(argc, argv, OPTIONS, NOPTIONS, usage,
                &optind, &optname, &optarg))  {
    if      (strcmp(optname, "-l") == 0)        motif_ext_start = atoi(optarg);
    else if (strcmp(optname, "-r") == 0)        motif_ext_end   = atoi(optarg);
    else if (strcmp(optname, "-h") == 0) {
      puts(usage);
      exit(EXIT_SUCCESS);
    }
  }

  format               = MSAFILE_STOCKHOLM;

  if (argc < optind + 2){
    puts(usage);
    exit(1);
  }
  alifile = argv[optind++];
  seqfile = argv[optind++];


  /***********************************************
   * Get sequence data
   ***********************************************/
  /* read the training seqs from file */
  if (! ReadMultipleRseqs(seqfile, SQFILE_FASTA, &rseqs, &sqinfo, &nseq))
    Die("Failed to read any sequences from file %s", seqfile);

  /* Preprocess */
  for (i = 0; i < nseq; i++){
    PrepareSequence(rseqs[i]);
  }

  /* Read the motif alignment; without one there is nothing to adjust. */
  if ((afp = MSAFileOpen(alifile, format, NULL)) == NULL)
    Die("Alignment file %s could not be opened for reading", alifile);
  if ((msa = MSAFileRead(afp)) == NULL)
    Die("Failed to read an alignment from file %s", alifile);
  for (i = 0; i < msa->nseq; i++){
    PrepareSequence(msa->aseq[i]);
  }
  MSAFileClose(afp);


  if (motif_ext_start < - msa->alen || motif_ext_end < - msa->alen || motif_ext_start + motif_ext_end < - msa->alen){
    Die("Invalid length for shrinkage");
  }

  coors = parse_motif_coor(msa, sqinfo, nseq);

  dsq = (char **) malloc(sizeof(char *) * nseq);
  if (dsq == NULL)
    Die("Memory allocation failed for the digitized sequences");
  for (i = 0; i < nseq; i++) {
    dsq[i] = DigitizeSequence(rseqs[i], sqinfo[i].len);
  }

  /* No explicit amounts: try to extend first ... */
  if (motif_ext_start == 0 && motif_ext_end==0){
    motif_ext_start = extend_motif(msa,coors,dsq,sqinfo, START);
    motif_ext_end =   extend_motif(msa,coors,dsq,sqinfo, END);
  }
  /* ... and fall back to shrinking when neither side could be extended. */
  if (motif_ext_start == 0 && motif_ext_end==0){
    shrink_motif(msa, &motif_ext_start, &motif_ext_end);
  }

  fetch_sequence(msa, coors, rseqs, sqinfo, nseq, motif_ext_start, motif_ext_end);
  WriteStockholm(stdout, msa);

  MSAFree(msa);
  free(coors);
  Free2DArray((void **)dsq, nseq);
  for (i = 0; i < nseq; i++)
    FreeSequence(rseqs[i], &(sqinfo[i]));
  free(sqinfo);
  return 0;
}
