/////////////////////////////////////////////////////////////////////////////
//                   SOFTWARE COPYRIGHT NOTICE AGREEMENT                   //
//       This software and its documentation are copyright (2006) by the   //
//   Broad Institute/Massachusetts Institute of Technology.  All rights    //
//   are reserved.  This software is supplied without any warranty or      //
//   guaranteed support whatsoever. Neither the Broad Institute nor MIT    //
//   can be responsible for its use, misuse, or functionality.             //
/////////////////////////////////////////////////////////////////////////////

#include "454/qual/PhredTableWriter.h"
#include "454/qual/PhredTableReader.h"
#include "Map.h"

/// Destructor.  The body is intentionally empty: nothing is released
/// explicitly here.
PhredTableWriter::~PhredTableWriter() {}

/// Write out the text table and the compacted table to a file.
/// Both are written, in that order, to the same output stream.
/// \param fname: name of the output file.
/// \param h: handler that is passed through unchanged to PrintPhredTable().
void PhredTableWriter::WritePhredTable(const String & fname,
				       PredictorParameterHandler * h) const{
  // NOTE(review): Ofstream looks like a project macro that declares the
  // output stream `os` and opens it on fname -- confirm against its
  // definition.
  Ofstream(os, fname);
  PrintPhredTable(os, h);
  // Build a reader from this writer, compact it, and write the compacted
  // form to the same stream, after the text table.
  PhredTableReader reader(*this);
  reader.Compact();
  reader.WriteCompacted(os);
}
  
/** Run the Phred algorithm to make a table of useful threshold values.

\param predictors: 2-D matrix of predictor values, each column
 (represented as a pointer to a vec<float>) corresponding to one
predictor.  Number of predictors must equal size of thresholds
vector.
      
\param thresholds: For each predictor, a vector of the
thresholds chosen.  The number of thresholds does not need to be
equal for all predictors.  The thresholds within each vec will
get sorted and uniqued.
      
\param incorrect: For each sample, a vector holding 0 if
correct, something else otherwise.  Length must equal the length
of each predictor column.
*/

void PhredTableWriter::MakePhredTable(const vec<float> & incorrect) {
  const unsigned int npredictors = thresholds.size();
  const unsigned int nsamples = incorrect.size();
  AssertEq(npredictors, predictors.size());
  AssertEq(nsamples, predictors[0]->size());
  cout << "Processing " << npredictors << " predictors on "
       << nsamples << " samples" << endl;

  vec<Tuple > allCombos; 
  //create all the combinations of thresholds.
  CreateCombos(thresholds, allCombos);

  //Initialize the bin mapper, so ToBins will work.
  InitBinMapper();

  //For each possible threshold combination, load data into it.
  cout << "Loading data" << endl;
  LoadCorrectWrong(allCombos, incorrect);

  //Save the error information along each individual predictor
  //This can be useful for diagnostic purposes and for making graphs
  Project(allCombos);

  // Compute the sums, once
  cout << "Calculating sums" << endl;
  CalculateSums(allCombos);

  //Now all the data is in allCombos. We keep picking the one with
  //the best quality, saving its data, and then subtracting its
  //totals from the totals of all bins (or subtract the "lcd" totals
  //for the incomparable bins).  We stop when there is no more data
  //(as indicated by row < 0).
  int row;
  do {
    if (allCombos.size()<100)
      PrintAllCombos(cout, allCombos);
    row = FindBestRow(allCombos);
    if (row >= 0) {
      Tuple chosen(allCombos[row]); //save a copy before subtracting
      cout << "Chose tuple " << chosen << endl;
      vec<Tuple>::iterator it = allCombos.begin();
      for ( ; it!=allCombos.end(); ++it) {
	subtractIfNotLE(*it, chosen, allCombos);
      }
      for (it=allCombos.begin(); it!=allCombos.end(); ++it) {
	zeroIfLE(*it, chosen);
      }

      //save the thresholds and qual scores from the tuple into m_table.
      vec<float> rowvec;
      for (unsigned int i=0; i != chosen.size(); ++i) {
	rowvec.push_back(thresholds[i][chosen[i]]);
      }
      rowvec.push_back(chosen.SumBases());
      rowvec.push_back(chosen.QualScore());
      m_table.push_back(rowvec);
    }
  } while (row>=0);
}

// Remove from x's running sums the part that is also covered by o.
//   - o <= x:        all of o is contained in x, so o's sums are removed.
//   - incomparable:  only the bin that is the "LCD" of x and o is shared,
//                    so that bin's sums (looked up in v) are removed.
//   - x <= o:        nothing happens here; such an x is meant to be
//                    cancelled out entirely, but that is done later
//                    (see zeroIfLE).
void PhredTableWriter::subtractIfNotLE(Tuple &x, const Tuple &o,
                                       const vec<Tuple> &v) const {
  if (o <= x) {
    x.sumCorrect -= o.sumCorrect;
    x.sumWrong -= o.sumWrong;
    return;
  }
  if (x <= o) {
    return;
  }
  // x and o are incomparable.
  const unsigned int lcd = FindTuple(x.indices(), o.indices(), dimensions_);
  x.sumCorrect -= v[lcd].sumCorrect;
  x.sumWrong -= v[lcd].sumWrong;
}

// Reset both running sums of x to zero when x is at or below o (x <= o);
// any other x is left untouched.
void PhredTableWriter::zeroIfLE(Tuple &x, const Tuple &o) const {
  if (!(x <= o)) {
    return;
  }
  x.sumWrong = 0;
  x.sumCorrect = 0;
}

/// Distribute every sample into the Tuple of the bin it belongs to.
/// The bin is found with ToBins(), the same way CompactLookup() does it,
/// so that both code paths stay consistent.  A sample whose entry in
/// incorrect equals AlignErr::MATCH bumps the Tuple's correct counter;
/// any other value bumps its wrong counter.
void PhredTableWriter::LoadCorrectWrong(vec<Tuple> & allCombos,
                                        const vec<float> & incorrect) const {
  const unsigned int npredictors = thresholds.size();
  const unsigned int nsamples = incorrect.size();
  AssertEq(npredictors, predictors.size());
  AssertEq(nsamples, predictors[0]->size());

  for (unsigned int sample = 0; sample != incorrect.size(); ++sample) {
    // Gather this sample's value for each predictor.
    vec<float> sampleValues(npredictors);
    for (unsigned int p = 0; p != npredictors; ++p) {
      sampleValues[p] = (*predictors[p])[sample];
    }

    // Map the values to per-predictor bin indices, then to a flat slot.
    vec<unsigned int> binIndices(npredictors);
    ToBins(sampleValues, binIndices);
    const unsigned int slot = FindTuple(binIndices, dimensions_);

    if (AlignErr::MATCH == incorrect[sample]) {
      ++allCombos[slot].correct;
    } else {
      ++allCombos[slot].wrong;
    }
  }
}

/// Collapse the n-dimensional table onto each single predictor.
/// Afterwards predProjection_[p][b] holds, as (first, second), the total
/// correct and wrong counts of all Tuples whose index along predictor p
/// is b.
void PhredTableWriter::Project(vec<Tuple> & allCombos) {

  const unsigned int npredictors = thresholds.size();
  AssertEq(npredictors, predictors.size());

  // One row of (0,0) counters per predictor, one counter per threshold.
  const pair<int,int> noCounts(0,0);
  predProjection_.resize(npredictors);
  for (int p = 0; p != thresholds.isize(); ++p) {
    predProjection_[p].assign(thresholds[p].size(), noCounts);
  }

  // Add every Tuple's own counts to the matching bin of each dimension.
  for (ulonglong combo = 0; combo != allCombos.size(); ++combo) {
    vec<unsigned int> where = FindIndices(combo, dimensions_);
    for (int dim = 0; dim != where.isize(); ++dim) {
      predProjection_[dim][where[dim]].first += allCombos[combo].correct;
      predProjection_[dim][where[dim]].second += allCombos[combo].wrong;
    }
  }
}

/// Print the per-predictor projection built by Project().
/// For every predictor a "Predictor <n>" header is written, followed by
/// one tab-separated line per bin: bin number, threshold, the two
/// projected counts, and the QualScore() computed from those counts.
/// Requires that the projection is non-empty and matches thresholds in
/// size.
void PhredTableWriter::PrintPredErrorRates(ostream & os) const {

  ForceAssert(!predProjection_.empty());
  ForceAssertEq(predProjection_.size(), thresholds.size());

  for (int pred = 0; pred != predProjection_.isize(); ++pred) {
    ForceAssertEq(predProjection_[pred].size(), thresholds[pred].size());
    os << "Predictor " << pred << endl;

    for (int b = 0; b != thresholds[pred].isize(); ++b) {
      os << b << "\t" << thresholds[pred][b] << "\t"
         << predProjection_[pred][b].first << "\t"
         << predProjection_[pred][b].second << "\t";
      os << QualScore(predProjection_[pred][b].first,
                      predProjection_[pred][b].second)
         << endl;
    }
  }
}


void PhredTableWriter::InitBinMapper() {
  ForceAssert(!thresholds.empty());
  binMapper_.clear();
  const int NPREDICTORS = thresholds.size();
  for (int i=0; i!= NPREDICTORS; ++i) {
    //set up bin mapper
    binMapper_.push_back(FloatMap(thresholds[i]));
  }
}


///Create all Tuples of thresholds, in the order implied by 
/// FindIndices and FindTuple.
///
/// \param thresholds: per-predictor threshold vectors.  Each one is
/// sorted and de-duplicated in place, and must be non-empty afterwards.
/// \param allCombos: resized to the product of the threshold counts;
/// every entry gets its indices set from its position via FindIndices().
///
/// The dimensions_ member is rebuilt from scratch on every call.
void PhredTableWriter::CreateCombos(vec<vec<float> > & thresholds,
    vec<Tuple> & allCombos) {

  // Drop whatever an earlier call left behind.  Without this, a second
  // call would append to the old dimensions and FindIndices/FindTuple
  // would work with too many dimensions.
  dimensions_.clear();

  // NOTE(review): N is an unsigned int, so the product wraps silently if
  // the number of combinations exceeds its range.
  unsigned int N=1;
  for (int i=0; i != thresholds.isize(); ++i) {
    sort(thresholds[i].begin(), thresholds[i].end());
    thresholds[i].erase(unique(thresholds[i].begin(), thresholds[i].end()),
                        thresholds[i].end());
    ForceAssertGt(thresholds[i].size(), 0);
    N *= thresholds[i].size();
    // The size of dimension i is its number of distinct thresholds.
    dimensions_.push_back(thresholds[i].size());
  }
  cout << "Computing on " << N << " table entries " << endl;

  allCombos.resize(N, Tuple(thresholds.size()));
  for (unsigned int i=0; i != N; ++i) {
    allCombos[i].SetIndices(FindIndices(i, dimensions_));
  }
}


/// Return the position of the best tuple in allCombos, or -1 if no
/// bases are left.
/// Tuples with fewer than MIN_BASES bases are never considered.  Among
/// the rest, the winner has the highest qual score; ties are broken by
/// the highest number of bases, and remaining ties by the highest sum
/// of indices.
/// NOTE(review): the running maxima for qual and bases are stored as
/// int; if QualScore() or SumBases() return a floating or wider type the
/// stored value is converted -- confirm this is intended.
int PhredTableWriter::FindBestRow(vec<Tuple> & allCombos) const {
  int bestRow = -1;
  int maxqual = 0;
  int maxbases = 0;
  unsigned int maxSumIndices = 0;

  for (int row = 0; row != allCombos.isize(); ++row) {
    Tuple & cand = allCombos[row];
    if (cand.SumBases() < MIN_BASES) {
      continue;
    }

    // Lexicographic comparison on (qual, bases, sum of indices).
    bool better = false;
    if (cand.QualScore() > maxqual) {
      better = true;
    } else if (cand.QualScore() == maxqual) {
      if (cand.SumBases() > maxbases) {
        better = true;
      } else if (cand.SumBases() == maxbases
                 && cand.sumIndices() > maxSumIndices) {
        better = true;
      }
    }

    if (better) {
      maxqual = cand.QualScore();
      maxbases = cand.SumBases();
      maxSumIndices = cand.sumIndices();
      bestRow = row;
    }
  }

  // A best base count of zero means there is nothing left to pick.
  if (0 == maxbases) {
    return -1;
  }
  return bestRow;
}


///Accumulate, for every Tuple, the correct and wrong bases "under" it.
///A bin is "under" a Tuple when each of its thresholds is equal to or
///lower than the corresponding threshold of that Tuple, so a Tuple's
///own bases are part of its sum.  The computation is sped up with the
///principle of inclusion and exclusion: once a Tuple's sum is complete
///it is pushed, with alternating sign, to the neighbours one step above
///it in every non-empty subset of dimensions.
void PhredTableWriter::CalculateSums(vec<Tuple> & allCombos) const {
  const unsigned int nOffsets = (1 << dimensions_.size());
  vec<unsigned int> target(dimensions_.size());

  for (unsigned int i = 0; i != allCombos.size(); ++i) {
    Tuple & t = allCombos[i];
    // A Tuple's own counts belong to its sum.
    t.sumCorrect += t.correct;
    t.sumWrong += t.wrong;

    // The numbers 1 .. 2^d - 1, read in binary, enumerate every offset
    // of d bits that is not all zero.
    for (unsigned int mask = 1; mask < nOffsets; ++mask) {
      // Bit k of mask is the step taken along dimension k; count the
      // set bits on the way to know the sign later.
      unsigned int stepsTaken = 0;
      unsigned int k = 0;
      while (k < target.size()) {
        const unsigned int step = (mask >> k) & 1;
        target[k] = t[k] + step;
        if (target[k] >= dimensions_[k]) {
          break;
        }
        stepsTaken += step;
        ++k;
      }
      if (k != target.size()) {
        // Stepped outside the table along some dimension: no such
        // neighbour.
        continue;
      }

      // Push t's sum to the neighbour, with the sign dictated by the
      // principle of inclusion & exclusion: odd number of steps adds,
      // even number subtracts.
      const unsigned int updateIndex = FindTuple(target, dimensions_);
      if (stepsTaken % 2) {
        allCombos[updateIndex].sumCorrect += t.sumCorrect;
        allCombos[updateIndex].sumWrong += t.sumWrong;
      } else {
        allCombos[updateIndex].sumCorrect -= t.sumCorrect;
        allCombos[updateIndex].sumWrong -= t.sumWrong;
      }
    }
  }
}

void PhredTableWriter::findThresholds(const vec<float> &newpredictor, 
				      int nThresholds)
{
  vec<float> p(newpredictor); // COPY newpredictor
  int N = p.size(), binSize = N/nThresholds;
  thresholds.push_back(vec<float>());
  // Determine (nThresholds-1) values so that there are binSize
  // samples in each bin
  vec<float>::iterator it=p.begin();
  for (int i=0; i<nThresholds-1; ++i) {
    it+=binSize;
    nth_element(p.begin(), it, p.end());
    thresholds.back().push_back(*it);
  }
  // Now it points to a position about 1 bin back from the end, and
  // the top binSize items are between it and end.  Put the greatest
  // value in our data into the last bin, because the last element is
  // effectively infinity (values greater than the last get put into
  // the last bin).
  thresholds.back().push_back(*max_element(it,p.end()));
}

void PhredTableWriter::addDiscretePredictor(const vec<float> &newpredictor, 
					    int targetBins,
					    const String & name)
{
  int N = newpredictor.size(), binSize = N/targetBins;
  map<float, int> values;
  for (unsigned int i=0; i<newpredictor.size(); ++i)
    ++values[(newpredictor[i])];
  //PRINT(values);
  thresholds.push_back(vec<float>());
  map<float, int>::iterator it=values.begin();
  float currValue = it->first;
  int binTotal = it->second;
  while (it!=values.end() && binSize>0) {
    while (binTotal < binSize) {
      ++it;
      if (it==values.end()) 
	break;
      currValue = it->first;
      binTotal += it->second;
    }
    //PRINT4(currValue, binSize, binTotal, N)
    // Recalculate binSize target on the fly 
    N -= binTotal;
    --targetBins;
    binSize = 0 == N ? 0 : N/targetBins;
    binTotal=0;
    thresholds.back().push_back(currValue);
    //PRINT(thresholds.back());
  }
  cout << "Using " << thresholds.back().size() << " bins ";
  predictors.push_back(&newpredictor);
  predNames.push_back(name);
}

