#include "score.h"

/*
 * Store the scoring configuration on the object; no other work is done here.
 *
 * Within this file the stored values are used as follows:
 *   cutoffb, nu, diagsize, thres, region, verbose
 *                      - forwarded to LocalProb in calcRegion()
 *   region             - 1 selects candidate enumeration + non-overlap
 *                        selection in calcRegion(); other values slide windows
 *   step               - window step passed to LocalProb::slideWindow()
 *   align              - 0: dot-product scores, 1: global-alignment scores
 *   match, mismatch, gapd
 *                      - forwarded to Alignment in calcScore()
 *   cutoff             - minimum score reported by calcScore(); also forwarded
 *                        to Alignment and dotProduct()
 *   entire             - stored only; not read in this file
 */
Score::Score(float cutoffb, float nu, int diagsize, double thres, int align, int match, int mismatch, int gapd, int cutoff, int region, int step, int entire, int verbose)
  : cutoffb(cutoffb), nu(nu), diagsize(diagsize), thres(thres),
    align(align), match(match), mismatch(mismatch), gapd(gapd),
    cutoff(cutoff), region(region), step(step), entire(entire),
    verbose(verbose) {
}

/*
 * Build the interval list for a single sequence (run once per sequence).
 *
 * Local base-pairing probabilities (LBPPs) are computed with LocalProb, then:
 *   region == 1 : candidate intervals are inferred from the LBPPs and a set
 *                 satisfying the non-overlapping condition is written to
 *                 `interval`;
 *   otherwise   : windows are taken from the sequence with step `step` and a
 *                 binary vector is made for each interval from the LBPPs.
 *
 * interval : output interval list, filled by the LocalProb helpers.
 * id, seq  : sequence identifier and sequence.
 * winsize, pairsize : forwarded to LocalProb.
 */
void Score::calcRegion(map<Positions, vector<int> > &interval, string id, string seq, int winsize, int pairsize) { // Run per sequence
  LocalProb local(id, seq, winsize, pairsize, cutoffb, nu, diagsize, thres, region, verbose);

  map<Positions, double> lbpp; // LBPP list
  local.calcLocalProb(lbpp);

  if (region != 1) {
    // Take windows from the original sequence, then vectorize each of them
    local.slideWindow(interval, step);
    local.makeVector(interval, lbpp);
    lbpp.clear();
    return;
  }

  // Step 1: potential intervals inferred from the LBPP list
  map<Positions, double> candidates;
  local.enumCandidate(candidates, lbpp);
  lbpp.clear(); // LBPPs are no longer needed once candidates exist

  // Step 2: keep intervals satisfying the non-overlapping condition
  local.nonOverlap(interval, candidates);
  candidates.clear();
}

/*
 * Score every interval of sequence 1 against every interval of sequence 2 and
 * print each pair whose score reaches `cutoff`.
 *
 * score     : in/out score matrix. One row is appended per interval in
 *             interval1 (map order); each row has one entry per interval in
 *             interval2 (map order). Rows already present are left untouched
 *             and are NOT reported by this call.
 * interval1 : intervals (with their vectors) of the first sequence.
 * interval2 : intervals (with their vectors) of the second sequence.
 * id1, id2  : sequence identifiers, used only in the printed report.
 *
 * align == 0 scores a pair with dotProduct(); align == 1 scores it with
 * Alignment::alignGlobally(). For any other value no rows are appended and
 * nothing is printed.
 */
void Score::calcScore(vector<vector<int> > &score, map<Positions, vector<int> > &interval1, map<Positions, vector<int> > &interval2, string id1, string id2) {
  map<Positions, vector<int> >::iterator p, q;

  // Rows produced by this call start here. Earlier rows (if the caller reuses
  // `score`) belong to other interval lists and must not be reported against
  // interval1/interval2 below.
  const vector<vector<int> >::size_type first = score.size();

  vector<int> row;
  row.reserve(interval2.size());
  score.reserve(first + interval1.size());

  // Make a score matrix by calculating dot product between intervals
  if (align == 0) {
    for (p = interval1.begin(); p != interval1.end(); ++p) {
      for (q = interval2.begin(); q != interval2.end(); ++q) {
        row.push_back(dotProduct(p->second, q->second, cutoff));
      } // q

      score.push_back(row);
      row.clear(); // keeps capacity for the next row
    } // p
  }

  // Make a score matrix by calculating global alignment between intervals
  else if (align == 1) {
    Alignment al(match, mismatch, gapd, cutoff);

    for (p = interval1.begin(); p != interval1.end(); ++p) {
      for (q = interval2.begin(); q != interval2.end(); ++q) {
        row.push_back(al.alignGlobally(p->second, q->second));
      } // q

      score.push_back(row);
      row.clear();
    } // p
  }

  // Display the scores computed above that reach the cutoff. Rows and columns
  // are walked in lockstep with the interval maps they were built from; the
  // iterator guards also prevent dereferencing end() on empty maps.
  p = interval1.begin();

  for (vector<vector<int> >::size_type i = first; i < score.size() && p != interval1.end(); ++i, ++p) {
    const vector<int> &scores = score[i];
    q = interval2.begin();

    for (vector<int>::size_type j = 0; j < scores.size() && q != interval2.end(); ++j, ++q) {
      if (scores[j] >= cutoff) {
        // '\n' rather than endl: avoid a flush per reported pair
        cout << id1 << "[" << p->first.left << ", " << p->first.right << "] "
             << id2 << "[" << q->first.left << ", " << q->first.right << "] "
             << scores[j] << '\n';
      }
    } // j
  } // i
}

/*
 * Dot product of two integer vectors of equal length.
 *
 * Returns -100 when the lengths differ ("-100" means illegal dot product).
 * Two empty vectors give 0. `cutoff` is part of the signature but is not
 * read here.
 */
int dotProduct(vector<int> x, vector<int> y, int cutoff) {
  if (x.size() == y.size()) {
    int total = 0;
    for (vector<int>::size_type k = 0; k != x.size(); ++k) {
      total += x[k] * y[k];
    }
    return total;
  }

  return -100; // "-100" means illegal dot product
}
