#include "phytree.h"
#include "grammar.h"
#include "funcs.h"
#include <ctype.h>
#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>

/* Section header lines recognised by read_grammar() and printed by
   write_grammar(). emission_tag is not used in this part of the file. */
const  char  *state_tag="States:";
const  char  *transition_tag="Transition:";
const  char  *emission_tag="Emission:";

/* Log-space accumulation helper defined in another translation unit; used by
   inside()/outside() to sum scores. NOTE(review): presumably returns
   log(exp(s1) + exp(s2)) -- confirm against its definition. */
extern double add_log(double s1, double s2);

/* One pending cell (span i..j, state v) on the explicit stack used by
   cyk_traceback(). */
typedef struct
{
  int i;
  int j;
  int v;
}Trace_s;

      
/*
 * Parse one angle-bracketed state name ("<name>") from *s.
 *
 * Leading whitespace is skipped. If the next character is '<', the cursor is
 * advanced past the closing '>' (or up to the terminating NUL when the name is
 * unterminated) and the name without brackets is returned.
 *
 * s     : in/out cursor into a NUL-terminated string.
 * state : destination buffer, or NULL to have one malloc'd (caller frees).
 *         A caller-supplied buffer must hold the name plus its NUL.
 * Returns the name, or NULL when no '<' follows the whitespace, the name is
 * empty ("<>"), or allocation fails.
 */
char* ParseState (char** s,char* state)
{
  const char* begin;
  size_t length;

  /* <ctype.h> functions require a value representable as unsigned char. */
  while (isspace((unsigned char)**s)) (*s)++;
  if (**s != '<') return NULL;
  (*s)++;

  begin = *s;
  while (**s != '>' && **s != '\0') (*s)++;
  length = (size_t)(*s - begin);
  if (**s == '>') (*s)++;
  if (length == 0) return NULL;

  if (state == NULL) {
    state = malloc(length + 1);
    if (state == NULL) return NULL;
  }
  memcpy(state, begin, length);
  state[length] = '\0';
  return state;
}

 
/*
 * Find state name s in g->stateId.
 * Returns the index of the first exact match, or -1 when s is NULL or the
 * name is not one of the grammar's states.
 */
int getStateIndex(Grammar* g, char* s)
{
  int idx;
  if (s == NULL)
    return -1;
  idx = 0;
  while (idx < g->nstates && strcmp(g->stateId[idx], s) != 0)
    idx++;
  return (idx < g->nstates) ? idx : -1;
}

/*
 * Parse an optional emission annotation "[type,model]" from *s into t.
 *
 * type  : 'l' (LEFT), 'r' (RIGHT) or 'p' (PAIR). Upper-case letters are
 *         accepted as well, so that files printed with 'P' can be read back.
 * model : read with ParseDouble() and truncated to int into t->emitprob,
 *         which the DP code uses as an index into the emission tables.
 *
 * When no '[' follows the leading whitespace, t->emittype stays -1 (no
 * emission) and t->emitprob 0. Dies on an unrecognised emission type.
 */
void ParseEmission(char** s, Transition* t)
{
  t->emittype= -1;
  t->emitprob= 0;
  /* Cast: isspace() is undefined for negative char values. */
  while(isspace((unsigned char)**s)) (*s)++;
  if (**s != '[') return;
  (*s)++;
  while(isspace((unsigned char)**s)) (*s)++;
  switch (**s) {
  case 'l': case 'L':
    t->emittype=LEFT;
    break;
  case 'r': case 'R':
    t->emittype=RIGHT;
    break;
  case 'p': case 'P':
    t->emittype=PAIR;
    break;
  default:
    Die("Invalid emission type \"%c\"", **s);
    break;
  }
  (*s)++;
  /* Skip anything else up to the ',' that precedes the model number. */
  while(**s != ',' && **s != '\0') (*s)++;
  if (**s != ',') return;
  (*s)++;
  t->emitprob = (int) ParseDouble(s);
  if (**s ==']') (*s)++;
}


/*
 * Release a Grammar built by read_grammar().
 * Frees the per-state name strings and in/out transition lists, the
 * per-state arrays, and the Grammar itself. g->transitions is not freed:
 * read_grammar() never allocates it separately. Passing NULL is a no-op.
 */
void free_grammar(Grammar* g)
{
  int i;
  if (g == NULL) return;
  for(i=0; i < g->nstates; i++){
    free(g->stateId[i]);
    free(g->in[i]);
    free(g->out[i]);
  }
  free(g->stateId);
  free(g->num_in);
  free(g->num_out);
  free(g->in);
  free(g->out);
  free(g);
}

Grammar*  read_grammar(char* filename)
{
  FILE* fin;
  char  buffer[MAXLINE+1];  
  char  *state;
  char  **stateId;
  int   l, t_count;
  char  *temp;
  if ( (fin = fopen(filename, "r")) == NULL) {
    printf("Fail to open filename %s", filename);    
    exit(1);    
  }
  Grammar *g=(Grammar*)MallocOrDie(sizeof(Grammar));
  while(fgets(buffer, MAXLINE, fin) > 0){  
    if (strncmp(buffer, state_tag, strlen(state_tag)) == 0) break;
  }
  if (fgets(buffer, MAXLINE, fin) > 0){
    temp = buffer;
    g->nstates= (int) ParseDouble(&temp);
    g->stateId= (char**) MallocOrDie(sizeof(char*) * g->nstates);		
    g->num_in= (int*) MallocOrDie(sizeof(int) * g->nstates);		
    g->num_out= (int*) MallocOrDie(sizeof(int) * g->nstates);	
    g->in= (int**) MallocOrDie(sizeof(int*) * g->nstates);		
    g->out= (int**) MallocOrDie(sizeof(int*) * g->nstates);	
    //Maximum 100 transitions total
    for(l=0; l < g->nstates; l++){      
      g->in[l] = (int*)MallocOrDie(sizeof(int)*100);
      g->out[l] = (int*)MallocOrDie(sizeof(int)*100);
      g->num_in[l]=0;
      g->num_out[l]=0;
      g->stateId[l]=NULL;
    }
    l=0;
    while(state=ParseState(&temp, NULL)){
      if ( l >= g->nstates ) Die("Too many states");
      g->stateId[l]= state;
      l++;
    }
    if (g->nstates > l) Die("Too few states");	
    g->start = 0;
    g->end = g->nstates -1;
  }  
  else{
    Die("No states defined");
  }
  
  while(fgets(buffer, MAXLINE, fin) > 0){  
    if (strncmp(buffer, transition_tag, strlen(transition_tag)) == 0)   break;      
  }
  
  t_count = 0;
  while(fgets(buffer, MAXLINE, fin) > 0){    
    char* rbuf;
    int first_rule;
    if (isspace(buffer[0])) break;
    temp=buffer;    
    if ((l=getStateIndex(g,ParseState(&temp, NULL))) >= 0){      
      while(isspace(*temp))temp++;
      if (*temp == '-' && *(temp+1) == '>') temp+=2;
      first_rule=1;
      while(1){
	int r1, r2;	
	if (first_rule)
	  rbuf = (strtok(temp, "|"));
	else
	  rbuf = (strtok(NULL, "|"));
	if (rbuf== NULL) break;
	if (first_rule) first_rule=0;
	Transition* t = &g->transitions[t_count];
	t->state = l;
	t->id = t_count;
	if ( (r1=getStateIndex(g,ParseState(&rbuf, NULL))) < 0)
	  Die("%s Invalid Transition", rbuf);
	if ( (r2=getStateIndex(g,ParseState(&rbuf, NULL))) >= 0){
	  t->children[0]=r1;
	  t->children[1]=r2;
	  t->nbranch=2;
	  g->in[r1][g->num_in[r1]]= t_count;	
	  g->num_in[r1]++;
	  g->in[r2][g->num_in[r2]]= t_count;	
	  g->num_in[r2]++;	  
	}
	else{
	  t->children[0]=r1;
	  t->nbranch=1;
	  g->in[r1][g->num_in[r1]]= t_count;	
	  g->num_in[r1]++;
	}	
	ParseEmission(&rbuf, t);
	double d = ParseDouble(&rbuf);	  	  
	t->prob=d;
	t->logprob=log(d);	
	g->out[l][g->num_out[l]]= t_count;	
	g->num_out[l]++;
	t_count++;
      }    
    }
  }
  g->ntransitions = t_count;
  return g;
}


/*
 * Print one transition to stdout as "<lhs>-><c1>[<c2>] [e,m] prob ", with no
 * newline. The "[e,m]" part appears only for emitting transitions; e is 'l',
 * 'r' or 'P' and m is t->emitprob. The probability has three decimals.
 */
void write_transition(Transition* t,Grammar* g)
{
  printf("<%s>->", g->stateId[t->state]);
  printf("<%s>", g->stateId[t->children[0]]);
  if (t->nbranch==2){
    printf("<%s>", g->stateId[t->children[1]]);
  }
  if (t->emittype!=-1){
    /* '?' marks an emission type outside LEFT/RIGHT/PAIR; previously an
       uninitialised char was printed in that case. */
    char c = '?';
    if (t->emittype== LEFT) c='l';
    else if (t->emittype== RIGHT) c='r';
    else if (t->emittype== PAIR) c='P';
    printf(" [%c,%d] ",c,t->emitprob);
  }
  printf("%.3f ", t->prob);
}

/*
 * Print g to stdout in the layout read by read_grammar():
 *   the "States:" section (count and names), a blank line, then the
 *   "Transition:" section with one line per state that has outgoing
 *   transitions, alternatives separated by '|'.
 * Probabilities are printed with three decimals.
 */
void write_grammar(Grammar *g)
{
  int i,j;
  printf("%s\n", state_tag);
  printf("%d ", g->nstates);
  for(i=0; i < g->nstates; i++){
    if (i > 0) printf(" ");
    printf("<%s>", g->stateId[i]);
  }
  printf("\n\n");
  printf("%s\n", transition_tag);
  for(i=0; i < g->nstates; i++){
    if (g->num_out[i] <= 0) continue;
    printf("<%s>->", g->stateId[i]);
    for(j=0; j < g->num_out[i]; j++){
      Transition * t = &g->transitions[g->out[i][j]];
      if (j > 0) printf("|");
      printf("<%s>", g->stateId[t->children[0]]);
      if (t->nbranch==2){
        printf("<%s>", g->stateId[t->children[1]]);
      }
      if (t->emittype!=-1){
        /* Lower-case letters are what ParseEmission() accepts, so the output
           can be read back. '?' marks an unknown type instead of printing an
           uninitialised char. */
        char c = '?';
        if (t->emittype== LEFT) c='l';
        else if (t->emittype== RIGHT) c='r';
        else if (t->emittype== PAIR) c='p';
        printf(" [%c,%d] ",c,t->emitprob);
      }
      printf("%.3f", t->prob);
    }
    printf("\n");
  }
}

/*
 * Allocate a triangular DP table for positions starting at `start`, sized
 * from `length`, with `nstates` values per span. Cells are addressed through
 * diag_table_coor().
 *
 * storage is zero-filled. When inc_trace is non-zero, trace and bifur arrays
 * of the same size are allocated zero-filled as well, so a trace slot that
 * is never written reads back as NULL (cyk_traceback() relies on this).
 * Otherwise both pointers are NULL.
 * Dies if the table size does not fit in an int or an allocation fails.
 */
Cube_table* init_Cube_table(int start, int length, int nstates, int inc_trace)
{
  long long cells;
  size_t n;
  Cube_table* table = malloc(sizeof(Cube_table));
  if (table == NULL) Die("init_Cube_table: out of memory");

  /* One slot per span (i, j) with i <= j + 1, times nstates. (length+1)*(length+2)
     is a product of consecutive integers and therefore even, so halving it
     before multiplying by nstates is exact; a wide type avoids int overflow. */
  cells = (long long)(length + 1) * (length + 2) / 2 * nstates;
  if (cells < 0 || cells > INT_MAX)
    Die("init_Cube_table: table too large (length %d, nstates %d)", length, nstates);
  n = (size_t)cells;

  table->start = start;
  table->length = length;
  table->nstates = nstates;
  table->size = (int)cells;
  table->storage = calloc(n, sizeof(double));
  if (table->storage == NULL && n > 0) Die("init_Cube_table: out of memory");
  if (inc_trace){
    table->trace = calloc(n, sizeof(Transition*));
    table->bifur = calloc(n, sizeof(int));
    if (n > 0 && (table->trace == NULL || table->bifur == NULL))
      Die("init_Cube_table: out of memory");
  }
  else{
    table->trace=NULL;
    table->bifur=NULL;
  }
  return table;
}

/* Release a table created by init_Cube_table(). NULL is accepted and ignored;
   free(NULL) is a no-op, so the optional trace/bifur arrays need no guard. */
void free_Cube_table(Cube_table *t)
{
  if (t == NULL) return;
  free(t->storage);
  free(t->trace);
  free(t->bifur);
  free(t);
}

/*
 * Map cell (i, j, v) to its index in a table's storage/trace/bifur arrays.
 *
 * i and j are absolute positions, made relative to table->start. Spans are
 * numbered triangularly by end position, (j+1)*(j+2)/2 + i, and each span
 * holds table->nstates consecutive values indexed by v.
 * Dies when i > j + 1, when i precedes table->start, when v is outside
 * [0, nstates), or when the index falls outside the table. Without the
 * lower-bound checks such accesses silently aliased other cells or indexed
 * before the array.
 */
int diag_table_coor(Cube_table* table, int i, int j, int v)
{
  int coor;
  i -= table->start;
  j -= table->start;
  if ( i > j+1) Die("i %d should be smaller than j+2 %d",i,j+2);
  if ( i < 0) Die("i %d precedes table start %d", i + table->start, table->start);
  if ( v < 0 || v >= table->nstates)
    Die("state %d out of range [0,%d)", v, table->nstates);
  coor = (j+1) * (j+2) /2 + i ;
  coor = coor * table->nstates + v;
  if (coor < 0 || coor >= table->size){
    Die("access %d %d %d out of bound %d", i,j,v,table->size);
  }
  return coor;
}

/* Value stored at cell (i, j, v); index validation happens in diag_table_coor(). */
double Cube_table_get(Cube_table* table, int i, int j, int v)
{
  return table->storage[diag_table_coor(table, i, j, v)];
}


/* Transition recorded for cell (i, j, v). The table must have been created
   with trace arrays (inc_trace != 0). */
Transition* Cube_table_get_trace(Cube_table* table, int i, int j, int v)
{
  return table->trace[diag_table_coor(table, i, j, v)];
}



/* Bifurcation point recorded for cell (i, j, v). The table must have been
   created with trace arrays (inc_trace != 0). */
int Cube_table_get_bifur(Cube_table* table, int i, int j, int v)
{
  return table->bifur[diag_table_coor(table, i, j, v)];
}


/* Overwrite every storage slot of the table with `value`; trace and bifur
   arrays are untouched. */
void Cube_table_set_all(Cube_table* table, double value)
{
  int idx = table->size;
  while (idx-- > 0)
    table->storage[idx] = value;
}

/* Store `value` at cell (i, j, v); index validation happens in diag_table_coor(). */
void Cube_table_set(Cube_table* table, int i, int j, int v, double value)
{
  table->storage[diag_table_coor(table, i, j, v)] = value;
}

/* Record the chosen transition and bifurcation point for cell (i, j, v).
   The table must have been created with trace arrays (inc_trace != 0). */
void Cube_table_set_trace(Cube_table* table, int i, int j, int v, Transition *t,int bifur)
{
  const int slot = diag_table_coor(table, i, j, v);
  table->bifur[slot] = bifur;
  table->trace[slot] = t;
}


/*
 * CYK (max-score) fill over positions start..end, in log space.
 *
 * Each cell (i, j, v), v != g->end, receives the maximum over transitions t
 * leaving v of
 *   t->logprob + emission score + child score,
 * where the child score is the child's cell for unary rules, or the best
 * split (l..n, n+1..r) for bifurcations. Empty spans (j+1, j) of the end
 * state are 0 and every other cell starts at -INF. Spans are filled by
 * increasing width, with states from g->nstates-1 down to 0. When the table
 * has trace arrays, the winning transition and split (-1 for unary rules)
 * are recorded for cyk_traceback().
 *
 * Returns cell (start, end, g->start). NOTE: the return type is int, so the
 * score is truncated; read the table for the exact value.
 */
int cyk(int start, int end, Grammar *g, Cube_table* cyk_table)
{
  int i,j,v,k,diff ;
  Cube_table_set_all(cyk_table, -INF);
  /* Empty spans, including the one just before `start` (as in inside()). */
  for (j = start - 1; j<= end; j++) {
    Cube_table_set(cyk_table, j+1, j, g->end, 0);
  }
  /* diff = j - i; diff == -1 is the empty span. i never goes below `start`:
     the table cannot address earlier positions. */
  for (diff=-1; diff <= end - start; diff++){
    for(i = start; i + diff <= end; i++){
      j = i + diff;
      for (v = g->nstates-1; v >= 0; v--) {
        Transition * best_transition = NULL;
        int    best_bifur=-1;
        double max=-INF;
        if (g->end == v)  continue;
        for(k=0; k < g->num_out[v]; k++){
          Transition * t = &(g->transitions[g->out[v][k]]);
          double s = -INF;
          double emit=0;
          double trans= t->logprob;
          int l=i;
          int r=j;
          int bifur=-1;
          /* Reject emissions on empty spans before reading the emission
             tables, so positions end+1 / start-1 are never read. */
          if(t->emittype ==LEFT){
            if (i > j) continue;
            emit = single_emission[t->emitprob][l];
            l ++;
          }
          else if (t->emittype == RIGHT){
            if (i > j) continue;
            emit = single_emission[t->emitprob][r];
            r --;
          }
          else if (t->emittype == PAIR){
            if (i >=j) continue; //Invalid pair
            emit = pair_emission[t->emitprob][l][r];
            l++;
            r--;
          }
          if ( r+1 < l ) continue;
          if (t->nbranch == 1){
            s = Cube_table_get(cyk_table, l,r,t->children[0]);
          }
          else{
            int n;
            /* Best split: left child on l..n, right child on n+1..r. */
            for(n = l; n < r; n++){
              double s1=Cube_table_get(cyk_table, l,n,t->children[0]);
              double s2=Cube_table_get(cyk_table, n+1,r,t->children[1]);
              if (s1 + s2 > s) {
                s = s1 + s2;
                bifur=n;
              }
            }
          }
          if (s + emit + trans > max){
            best_transition= t;
            best_bifur = (t->nbranch==2) ? bifur : -1;
            max = s + emit + trans;
          }
        }
        Cube_table_set(cyk_table, i, j, v, max);
        if (cyk_table->trace){
          Cube_table_set_trace(cyk_table, i, j, v, best_transition,best_bifur);
        }
      }
    }
  }
  return Cube_table_get(cyk_table, start, end, g->start);
}

/*
 * Determine which (span, state) cells can derive the fixed structure ss_cons.
 *
 * Positions are 0..strlen(ss_cons)-1. GetPairtable() gives each position's
 * partner; a negative entry is treated as unpaired. Cell (i, j, v) is set to
 * 1 when some transition of v is compatible:
 *   - LEFT/RIGHT may only emit an unpaired position;
 *   - PAIR needs i and j paired with each other;
 *   - every child cell must already be 1.
 * Empty spans of g->end are 1; all other cells start at 0. The first
 * compatible transition (and first split point) is recorded in the trace
 * arrays when present. Cells with no derivation get a NULL trace, so
 * cyk_traceback() can detect them.
 * NOTE(review): pair_table is never freed here; if GetPairtable() returns
 * heap memory this leaks -- confirm its ownership contract.
 */
void cyk_parse(char* ss_cons, Grammar *g, Cube_table* cyk_table)
{
  int i,j,v,k,diff;
  int* pair_table = GetPairtable(ss_cons);
  int end = (int)strlen(ss_cons)-1;

  Cube_table_set_all(cyk_table, 0);
  /* Empty spans (j+1, j); j == -1 is the empty span before position 0. */
  for (j = -1 ; j<= end; j++) {
    Cube_table_set(cyk_table, j+1, j, g->end, 1);
  }

  for (diff=-1; diff <= end ; diff++){
    for(i = 0; i + diff <= end; i++){
      j = i + diff;
      for (v = g->nstates-1; v >= 0; v--) {
        int found = 0;
        if (g->end == v)  continue;
        for(k=0; k < g->num_out[v]; k++){
          Transition * t = &(g->transitions[g->out[v][k]]);
          int l=i;
          int r=j;
          int bifur=-1;
          int find = 0;
          /* Span checks come first so pair_table is only indexed at 0..end. */
          if(t->emittype ==LEFT){
            if (i > j) continue;
            if (pair_table[l] >= 0)  continue;  // i is paired
            l ++;
          }
          else if (t->emittype == RIGHT){
            if (i > j) continue;
            if (pair_table[r] >= 0) continue;  // j is paired
            r --;
          }
          else if (t->emittype == PAIR){
            if (i >=j) continue; //Invalid pair
            if (pair_table[l] < 0 || pair_table[r] < 0 || pair_table[l] != r) continue;
            l++;
            r--;
          }
          if ( r+1 < l ) continue;
          if (t->nbranch == 1){
            if ( Cube_table_get(cyk_table, l,r,t->children[0]) < 0.5) continue;
            find = 1;
          }
          else{
            int n;
            for(n = l; n < r; n++){
              if ( Cube_table_get(cyk_table, l,n,t->children[0]) < 0.5) continue;
              if ( Cube_table_get(cyk_table, n+1,r,t->children[1]) < 0.5) continue;
              find = 1;
              bifur = n;
              break;
            }
          }
          if (!find) continue;

          Cube_table_set(cyk_table, i, j, v, 1 );
          if (cyk_table->trace){
            Cube_table_set_trace(cyk_table, i, j, v, t, (t->nbranch==2) ? bifur : -1);
          }
          found = 1;
          break;
        }
        /* Clear stale trace data so an underivable cell reads back as NULL. */
        if (!found && cyk_table->trace){
          Cube_table_set_trace(cyk_table, i, j, v, NULL, -1);
        }
      }
    }
  }
}

void cyk_traceback(int start, int end, Grammar *g, Cube_table* cyk_table, int* count, char* mark)
{
  //There should be < 100 unresolving trace in the stack
  Trace_s* trace_stack=(Trace_s*) MallocOrDie(sizeof(Trace_s) * 100);
  Trace_s* sp = trace_stack;
  if (mark)  memset(mark, 0, sizeof(char) * (end+1)); 

  sp->i = start;
  sp->j = end;
  sp->v = g->start;
  sp ++;
  while(sp > trace_stack){
    int i,j,n,v;
    Transition* t;
    double s;
    sp--;
    i = sp->i;
    j = sp->j;
    v = sp->v;
    if (v == g->end) continue;
    t = Cube_table_get_trace(cyk_table,i,j,v);
    if (t==NULL) Die("No trace from i %d j %d state %s", i, j, g->stateId[v]);
    s = Cube_table_get(cyk_table,i,j,v);
    /*
      printf("%d %d %s : ", i, j, g->stateId[v]);
      printf(" %.2f ", s);
      write_transition(t,g); 
    */
    if (count) count[t->id] ++;   
    if (t->nbranch==2){
      n = Cube_table_get_bifur(cyk_table,i,j,v);      
    }
    if (t->emittype== PAIR){
      if (mark){
	mark[i] = '<';
	mark[j] = '>';
      }
      i++;
      j--;
    }
    else if (t->emittype== LEFT){
      if (mark) mark[i] = g->stateId[t->state][0];
      i++;
    }
    if (t->emittype== RIGHT){
      if (mark) mark[j] = g->stateId[t->state][0];
      j--;
    }
    if (t->nbranch==2){     
      v = t->children[0];      
      sp->i = i;
      sp->j = n;
      sp->v = v;
      sp++;
      v = t->children[1];
      sp->i = n+1;
      sp->j = j;
      sp->v = v;
      sp++;
    }
    else{
      v = t->children[0];
      if (g->end != v){
	sp->i = i;
	sp->j = j;
	sp->v = v;
	sp++;
      }
    }
  }
  free(trace_stack);
  if (mark) mark[end+1]= '\0';
}

/*
 * Debug sanity check of the emission tables for positions 0..n.
 * single_emission models 0 and 1 must lie strictly inside (-100, 0), and
 * pair_emission model 0 (for a < b) strictly inside (-1000, 0). Dies with
 * `label` on the first offending entry.
 */
void check_emission(int n, char* label)
{
  int a, b;
  for (a = 0; a <= n; a++) {
    if (!(single_emission[0][a] < 0 && single_emission[0][a] > - 100 &&
          single_emission[1][a] < 0 && single_emission[1][a] > - 100))
      Die("%s: single_emission %d went wrong \n",label,a);
  }
  for (a = 0; a <= n; a++) {
    for (b = a + 1; b <= n; b++) {
      double p = pair_emission[0][a][b];
      if (!(p < 0 && p > - 1000))
        Die("%s: pair_emission %d %d went wrong %f\n", label, a, b, p);
    }
  }
}


/*
 * Inside pass over positions start..end, in log space.
 *
 * inside_table(i, j, v) receives the add_log() sum over transitions t
 * leaving v of
 *   t->logprob + emission score + child inside score(s),
 * where bifurcations sum over every split l..n / n+1..r. Empty spans
 * (j+1, j) of g->end are seeded with 0 and all other cells start at -INF.
 * Spans are filled by increasing width, with states from g->nstates-1 down
 * to 0. If trans_inside_table is non-NULL, each per-transition term is also
 * stored at (i, j, transition index).
 *
 * Returns inside(start, end, g->start). NOTE: the return type is int, so the
 * log score is truncated; read inside_table for the exact value.
 */
int inside(int start, int end, Grammar *g, Cube_table* inside_table, Cube_table* trans_inside_table)
{
  int i,j,v,k,diff ;

  Cube_table_set_all(inside_table, -INF);
  for (j = start -1; j<= end; j++) {
    Cube_table_set(inside_table, j+1, j, g->end, 0);
  }
  /* trans_inside_table is optional (NULL-checked below), so guard here too. */
  if (trans_inside_table){
    Cube_table_set_all(trans_inside_table, -INF);
  }
  /* diff = j - i; diff == -1 is the empty span. */
  for (diff=-1; diff <= end - start; diff++){
    for(i = start; i + diff <= end; i++){
      j = i + diff;
      for (v = g->nstates-1; v >= 0; v--) {
        double inside_score = - INF;
        if (g->end == v)  continue;
        for(k=0; k < g->num_out[v]; k++){
          Transition * t = &(g->transitions[g->out[v][k]]);
          double s = -INF;
          double emit = 0;
          double trans = t->logprob;
          double sum;
          int l = i;
          int r = j;
          /* Reject emissions on empty spans before reading the emission
             tables, so positions end+1 / start-1 are never read. */
          if(t->emittype ==LEFT){
            if (i > j) continue;
            emit = single_emission[t->emitprob][l];
            l ++;
          }
          else if (t->emittype == RIGHT){
            if (i > j) continue;
            emit = single_emission[t->emitprob][r];
            r --;
          }
          else if (t->emittype == PAIR){
            if (l + 1 > r-1) continue; //Invalid pair: needs j >= i + 2
            emit = pair_emission[t->emitprob][l][r];
            l ++;
            r --;
          }
          if (t->nbranch == 1){
            s = Cube_table_get(inside_table, l,r,t->children[0]);
          }
          else{
            int n;
            for(n = l; n < r; n++){
              double s1=Cube_table_get(inside_table, l, n,t->children[0]);
              double s2=Cube_table_get(inside_table,n+1,r,t->children[1]);
              s = add_log(s, s1 + s2);
            }
          }
          if (s < - INF + 1) continue;
          sum = s + emit + trans;
          if (trans_inside_table){
            Cube_table_set(trans_inside_table, i,j, g->out[v][k], sum);
          }
          inside_score = add_log(inside_score, sum);
        }
        Cube_table_set(inside_table, i, j, v, inside_score);
      }
    }
  }
  return Cube_table_get(inside_table, start, end, g->start);
}
  

/*
 * Outside pass over positions start..end, in log space; the counterpart of
 * inside().
 *
 * outside_table(i, j, v) accumulates, via add_log(), the score of everything
 * outside span i..j given that state v derives it. It is seeded with 0 at
 * (start, end, g->start) and filled from the widest span (diff = end - start)
 * down to single positions (diff = 0). Empty spans (diff = -1) are not
 * computed. inside_table must already hold inside() results for the same
 * range, because bifurcations need the sibling's inside score.
 *
 * For every transition t: u -> ... that has v as a child, the parent span
 * is recovered by widening i..j by t's emitted positions (l, r):
 *  - unary rule: outside(l, r, u) + emission + transition;
 *  - v is the left child: sum over n of outside(l, n, u) + inside(j+1, n, right child);
 *  - v is the right child: sum over n of outside(n, r, u) + inside(n, i-1, left child).
 *
 * NOTE(review): when both children of t are v, only the left-child branch is
 * taken (else-if below). read_grammar() appends such a t to g->in[v] twice,
 * so the left-child term is counted twice and the right-child term never.
 * Confirm the grammar has no such rules.
 * NOTE(review): rules that both bifurcate and emit on the right take the
 * emission at j+1 rather than at the parent's right end; confirm the grammar
 * has none.
 */
void outside(int start, int end, Grammar *g, Cube_table* outside_table, Cube_table* inside_table)
{
  int i,j,diff,v,u,k,n,l,r ;
  Cube_table_set_all(outside_table, -INF);
  /* The whole sequence derived from the start state has an empty outside. */
  Cube_table_set(outside_table, start, end, g->start, 0);

  // recursion: widest spans first, since a span's outside depends on its parents

  for (diff = end - start ; diff >= 0; diff--){  
    for(i=start; i + diff<= end; i++){    
      j = i +  diff;      
      for (v =0; v < g->nstates; v++){
	if (g->end == v)  continue;	
	if (i==start && j== end && v== g->start) continue;
	double outside_sc = - INF;
	/* Every transition that has v as a child contributes a term. */
	for (k=0; k < g->num_in[v]; k++){
	  Transition* t = &(g->transitions[g->in[v][k]]);	  	  	  
	  double s=0;	    
	  double trans = 0;
	  double emit = 0;
	  double recur = 0;
	  int  u = t->state;
	  l = i;
	  r = j;
	  trans = t->logprob;
	  /* Widen (l, r) to the parent's span by the positions t emits. */
	  if (t->emittype == LEFT){
	    l --;
	    if ( l < start) continue;
	    emit = single_emission[t->emitprob][l];
	  }
	  else if (t->emittype == RIGHT){
	    r++;
	    if ( r > end) continue;
	    emit = single_emission[t->emitprob][r];
	  }
	  else if (t->emittype == PAIR){
	    l--;
	    r++;
	    if ( l < start || r > end) continue;
	    emit = pair_emission[t->emitprob][l][r];		  
	  }
	  //add outside log odds of the parent
	  if (t->nbranch==1){
	    recur=Cube_table_get(outside_table, l,r, u);		
	    if (recur < -INF + 1) continue;
	  }
	  else{
	    double tmp_sum= -INF;
	    double recur_out, recur_in;
	    if (t->children[0] == v){
	      /* v is the left child: the right sibling spans j+1..n. */
	      for ( n= r+1; n <= end; n++){		    
		recur_out= Cube_table_get(outside_table, l,n,u);
		recur_in = Cube_table_get(inside_table, j+1,n,t->children[1]);
		if (recur_out < -INF + 1 || recur_in < -INF +1) continue;
		tmp_sum = add_log(tmp_sum, recur_out + recur_in);
	      }
	    }
	    else if (t->children[1] == v){
	      /* v is the right child: the left sibling spans n..i-1. */
	      for ( n= start; n < i; n++){		    
		recur_out = Cube_table_get(outside_table, n,r,u) ;
		recur_in = Cube_table_get(inside_table, n,i-1,t->children[0]);
		if (recur_out < -INF + 1 || recur_in < -INF +1) continue;
		tmp_sum = add_log(tmp_sum, recur_out + recur_in);		
	      }
	    }
	    recur = tmp_sum;
	  }
	  s = emit + trans + recur;
	  if (s < -INF + 1) continue;
	  outside_sc = add_log(outside_sc, s);
	}
	Cube_table_set(outside_table, i, j, v, outside_sc);
      }
    }
  }
}

    

/*
 * Turn inside/outside scores into posterior probabilities for start..end.
 *
 * For every non-empty span i..j and transition k from state v, with total =
 * inside(start, end, g->start):
 *   p  = exp(outside(i,j,v) + inside(i,j,v)          - total)  (state)
 *   tp = exp(outside(i,j,v) + trans_inside(i,j,k)    - total)  (transition)
 * tp is stored in transition_posterior and p in state_posterior (when those
 * tables are given). tp is also accumulated into single_posterior
 * (position i for LEFT, j for RIGHT) or pair_posterior[i][j] (PAIR) for the
 * transition's emission model. NULL model rows are skipped throughout.
 * Finally, positions whose single + pair posteriors (models 0/1 and 0) do
 * not sum to 1 within 0.001 are reported on stdout.
 */
void posterior(int start, int end, Grammar *g, 
	       Cube_table* outside_table, 
	       Cube_table* inside_table, 
	       Cube_table* trans_inside_table, 
	       Cube_table* transition_posterior, 
	       Cube_table* state_posterior, 
	       double* single_posterior[],
	       double** pair_posterior[])
{
  int i,j,v,k,diff;
  double total_prob = Cube_table_get(inside_table, start, end, g->start);

  if (transition_posterior){
    Cube_table_set_all(transition_posterior, 0);
  }
  if (state_posterior){
    Cube_table_set_all(state_posterior, 0);
  }
  /* Rows may be NULL (the accumulation below checks), so skip them here too. */
  for(k=0; k < single_model_num; k++){
    if (single_posterior[k] == NULL) continue;
    for(i=start; i <= end; i++){
      single_posterior[k][i]= 0;
    }
  }
  for(k=0; k < pair_model_num; k++){
    if (pair_posterior[k] == NULL) continue;
    for(i=start; i <= end; i++){
      for(j=start; j <= end; j++){
        pair_posterior[k][i][j]= 0;
      }
    }
  }

  for (diff = end - start ; diff >= 0; diff--){
    for(i=start; i + diff<= end; i++){
      j = i +  diff;
      for (k =0; k < g->ntransitions; k++){
        Transition* t = &(g->transitions[k]);
        double trans_inside_sc, inside_sc, out_sc, p, tp;
        v = t->state;
        trans_inside_sc = Cube_table_get(trans_inside_table, i,j, k);
        if (trans_inside_sc < - INF + 1) continue;
        inside_sc = Cube_table_get(inside_table, i,j, v);
        out_sc = Cube_table_get(outside_table, i,j, v);
        if (out_sc < - INF + 1) continue;
        p = exp(out_sc + inside_sc  - total_prob);
        tp = exp(out_sc + trans_inside_sc - total_prob);

        if (transition_posterior){
          Cube_table_set(transition_posterior, i,j,k, tp);
        }
        /* p depends only on (i, j, v), so write it once per cell. */
        if (state_posterior && Cube_table_get(state_posterior,i,j,v) == 0){
          Cube_table_set(state_posterior, i,j,v, p);
        }
        if (t->emittype != -1){
          if (t->emittype == LEFT && single_posterior[t->emitprob]){
            single_posterior[t->emitprob][i] += tp;
          }
          else if (t->emittype == RIGHT && single_posterior[t->emitprob]){
            single_posterior[t->emitprob][j] += tp;
          }
          else if (t->emittype == PAIR && pair_posterior[t->emitprob]){
            pair_posterior[t->emitprob][i][j] += tp;
          }
        }
      }
    }
  }

  /* Consistency check: each position's single-emission posteriors plus its
     pair posteriors (as either partner) should sum to 1. Only single models
     0 and 1 and pair model 0 are read. */
  if (single_model_num >= 2 && pair_model_num >= 1 &&
      single_posterior[0] && single_posterior[1] && pair_posterior[0]){
    for(i=start; i <= end; i++){
      double sum=single_posterior[0][i] + single_posterior[1][i];
      for(j=i; j <=end; j++){
        sum+= pair_posterior[0][i][j];
      }
      for(j=start; j <=i-1; j++){
        sum+= pair_posterior[0][j][i];
      }
      /* fabs, not abs: abs() takes an int and truncated every deviation
         below 1 to 0, so the check never fired. */
      if (fabs(1-sum) > 0.001){
        printf("Error %d sum %.2f\n", i, sum);
      }
    }
  }
}


