// Copyright (c) 2005 Broad Institute of MIT and Harvard
//

/// This class generates a basevector from a flow.
///
/// \class BaseCaller
///
/// It uses the values from the key, does a few passes of 
/// braindead EM-like adjustment of the means, and happily calls away.
///    
/// Note that this was designed to call bases on even completely 
/// useless flows, so that we can then align them to the sequences 
/// and see if they are good or not. Therefore, it has no quality
/// control built in.

#include "CoreTools.h"
#include "String.h"
#include "ParseSet.h"
#include "Basevector.h"
#include "STLExtensions.h"
#include "PackAlign.h"
#include "pairwise_aligners/SmithWatFree.h"
#include "math/Functions.h"

#include "454/flowdata/FlowOrder.h"
#include "454/flowdata/FlowVector.h"
#include "454/flowdata/FlowDataBroker.h"
#include "454/flowdata/FlowKey.h"
#include "454/BaseCaller.h"

//#include <algorithm>
#include <numeric>
#include <cmath>

// Definition of the class-wide BAD_BASES pointer. It is initialized during
// static initialization with the vector built by SetBadBases() (MAX_BASES
// bases, all set to 0).
const basevector * BaseCaller::BAD_BASES = BaseCaller::SetBadBases();

/// Builds the vector used to initialize BAD_BASES: MAX_BASES bases, every
/// one of them set to 0.
///
/// \return pointer to a heap-allocated basevector. Nothing in this function
///         releases it.
basevector * BaseCaller::SetBadBases() {
  basevector * badBases = new basevector();
  badBases->resize(MAX_BASES);
  int pos = 0;
  while (pos != MAX_BASES) {
    badBases->Set(pos, 0);
    ++pos;
  }
  return badBases;
}


/// Constructs a base caller.
///
/// \param order      flow order, used to initialize mOrder.
/// \param maxPasses  stored in mMaxPasses; FindDistributions() runs at most
///                   this many mean-refinement cycles.
BaseCaller::BaseCaller(const FlowOrder & order, int maxPasses)
  : mOrder(order),
    mMaxPasses(maxPasses)
{}

void BaseCaller::GetBases(const FlowVector & fvec, basevector & bases,
			  vec<int> * correspondence) {
  vec<double> cutoffs;
  FlowKey key(mOrder,fvec);
  if (key.Bad()) {
    bases.Setsize(0);
    return;
  }
  vec<normal_distribution> dist = FindDistributions(fvec, key);
  GetBasesWithDistributionsAndCutoffs(fvec, bases, dist, cutoffs, 
				      correspondence);
}

/// Calls bases with half-integer cutoffs, choosing the number of cutoffs
/// from the largest value in the flow.
///
/// The largest fvec(i) is found (starting from 0), and int(max + 1) is
/// passed as the cutoff count to the overload that takes N.
///
/// \param fvec            flow values to call.
/// \param bases           output bases.
/// \param correspondence  optional (may be null); forwarded unchanged.
void BaseCaller::GetBasesWithTrivialCutoffs(const FlowVector & fvec,
                                            basevector & bases,
                                            vec<int> * correspondence) {
  double peak = 0.f;
  const int numFlows = fvec.isize();
  for (int flow = 0; flow != numFlows; ++flow) {
    const double value = fvec(flow);
    if (value > peak) {
      peak = value;
    }
  }
  const int numCutoffs = static_cast<int>(peak + 1);
  GetBasesWithTrivialCutoffs(fvec, bases, numCutoffs, correspondence);
}

void BaseCaller::GetBasesWithTrivialCutoffs(const FlowVector & fvec, 
					    basevector & bases, int N,
					    vec<int> * correspondence) {
  vec<double> cutoffs(N);
  for (int i=0; i != N; ++i) {
    cutoffs[i] = i + 0.5;
  }
  GetBasesWithCutoffs(fvec, bases, cutoffs, true, correspondence);
}

void BaseCaller::GetBasesWithDistributionsAndCutoffs
(const FlowVector & fvec, basevector & bases,
 const vec<normal_distribution> dist,
 const vec<double> & cutoffs,
 vec<int> * correspondence) {
  
  //If no cutoffs are provided, base them on a small set of clean data,
  // as in palvarez's file FlowGraphs.xls.
  vec<double> tempcut(cutoffs);
  if (tempcut.empty()) {
    tempcut.push_back( (dist[0].mu_ + dist[1].mu_) /(2.0 * dist[1].mu_),
                       1.47, 2.41, 3.41, 4.41, 5.41, 6.41, 7.41);
  }
  tempcut *= dist[1].mu_;
  
  //cout << "cutoffs are " << tempcut << endl;

  //Call the bases.
  GetBasesWithCutoffs(fvec, bases, tempcut, false, correspondence);
}

void BaseCaller::GetBasesWithNoiseDependentCutoffs
(const FlowVector & fvec, basevector & bases,
 const QuadraticFunction & a, const QuadraticFunction & b,
 float mean0, vec<int> * correspondence) {
  double maxf= (*max_element(fvec.begin(), fvec.end())) * fvec.Multiplier();
  //PRINT(maxf);
  vec<double> cutoffs;
  
  // stop the loop when the last cutoff created is greater than the greatest
  // value in the flowvector.
  for (int i=0; i==0 || cutoffs[i-1] < maxf; ++i) {
    cutoffs.push_back( a(i) * mean0 + b(i) );
    //PRINT4(i,cutoffs[i],mean0, maxf);
  }
  GetBasesWithCutoffs(fvec, bases, cutoffs, true, correspondence);
}
  

void BaseCaller::GetBasesWithCutoffs(const FlowVector & fvec, 
				     basevector & bases, 
				     const vec<double> & cutoffs, 
				     bool useDoubleValue,
				     vec<int> * correspondence) {
  bases.resize(MAX_BASES);
  if (correspondence) (correspondence->resize(MAX_BASES));
  //Call the bases.
  int base=0;
  int height;
  const int S=fvec.size();
  for (int i = 0; i != S; ++i) {
    if (mOrder[i] == 'P') continue;
    for (height = 0; height != cutoffs.isize(); ++height) {
      //if we want to use doubles, we must use operator(),
      // whereas if we are using unsigned short, we must use operator[].
      if (useDoubleValue && fvec(i) > cutoffs[height]
          || !useDoubleValue && fvec[i] > cutoffs[height]) {
	if (correspondence) (*correspondence)[base]=i;
	bases.Set(base++, as_char(mOrder[i]));
	if (base >= MAX_BASES) {
	  //Something went very wrong. Don't crash, but notify the user.
	  cout << "Problem: read with more than " << MAX_BASES << " bases" << endl;
	  bases.Setsize(0);
	  if (correspondence) correspondence->resize(0);
	  return;
	}
      } else {
	break; //done with this flow
      }
    }
  }
  ForceAssertLe(base, MAX_BASES);
  bases.resize(base);
  if (correspondence) (correspondence->resize(base));
}

/// One refinement step for the means: bins the flow values around the
/// current means and returns the average of each bin.
///
/// Bin k collects values below edge k that were not already claimed by a
/// lower bin. Edges are the midpoints between neighbouring means; the last
/// edge lies as far above the last mean as the previous edge lies below it.
/// Values at or above the last edge, and flows whose order character is
/// 'P', are ignored.
///
/// \param fvec   flow values (read through operator[]).
/// \param means  current means; at least two are required (asserted).
/// \return one new mean per input mean; 0 for a bin that received no values.
vec<double> BaseCaller::FindNewMeans(const FlowVector & fvec,
    const vec<double> & means) {
  const int S = means.size();
  AssertGt(S,1);

  // Upper edge of every bin.
  vec<double> upperEdge(S);
  for (int k = 0; k + 1 < S; ++k) {
    upperEdge[k] = (means[k] + means[k+1]) /2;
  }
  upperEdge[S-1] = means[S-1] + means[S-1] - upperEdge[S-2];

  // Per-bin accumulators.
  vec<double> total(S);
  vec<int> hits(S);

  const int numFlows = fvec.isize();
  for (int flow = 0; flow != numFlows; ++flow) {
    if (mOrder[flow] == 'P') continue; //ignore Ps

    // Find the first bin whose upper edge is above this value.
    int bin = 0;
    while (bin < S && !(fvec[flow] < upperEdge[bin])) {
      ++bin;
    }
    if (bin < S) {
      total[bin] += fvec[flow];
      ++hits[bin];
    }
  }

  // Average each bin. An empty bin yields 0; the original author's note says
  // this is meant to make the caller end up with BAD_BASES.
  vec<double> ret(S);
  for (int k = 0; k != S; ++k) {
    ret[k] = hits[k] ? total[k] / hits[k] : 0;
  }
  return ret;
}


/// Estimates two distributions for a flow: index 0 from the key's negative
/// mean (noise) and index 1 from the key's positive mean (first peak).
///
/// The initial means come from \p key and are refined by up to mMaxPasses
/// calls to FindNewMeans(), stopping early once the change between passes
/// drops below a tenth of the initial positive mean.
///
/// \param fvec  flow values.
/// \param key   key for \p fvec. It must not be bad and must equal the key
///              rebuilt here from mOrder and \p fvec; otherwise FatalErr is
///              invoked.
/// \return two distributions with only mu_ filled in, or an EMPTY vector if
///         the means are not strictly increasing before or after
///         refinement. Callers must check the size before indexing.
vec<normal_distribution> 
BaseCaller::FindDistributions(const FlowVector & fvec, const FlowKey & key) {
  const int S=fvec.size();
  vec<double> means(2);
  vec<normal_distribution> ret;

  // Deal with keys: if one was passed in, it must match
  FlowKey mykey(mOrder, fvec);
  if (key.Bad() || mykey != key) {
    // A space is needed before "not", or the key runs into the message.
    FatalErr("Key " << key.ToString() << " not found for flowvector "
             << fvec << endl);
  }

  //Set up initial values for noise and peak1 based on the key
  means[0] = key.NegMean(fvec);
  means[1] = key.PosMean(fvec);

  // Convergence threshold is relative to the initial positive mean.
  const double start = means[1];

  vec<double> oldmeans;
  SquaredError<double> ss;
  if (!is_sorted_strict(means.begin(), means.end())) { //Failure
    return ret;
  }
  //Do at most mMaxPasses cycles of means optimization.
  for (int cycle = 0; cycle != mMaxPasses; ++cycle) {
    // cout << cycle << " passes, means are: " << means << endl;
    oldmeans = means;
    means = FindNewMeans(fvec, oldmeans);
    // Sum of ss(new, old) over the means, divided by S, then square-rooted.
    // NOTE(review): S is the number of flows, not the number of means, so
    // this is not a plain RMS over the means -- confirm this is intended.
    const double error = sqrt(std::inner_product( means.begin(), means.end(), 
	  oldmeans.begin(), 0.0,
	  plus<double>(), ss) / S);
    if  (error < start/10) break;
  }
  //cout << "Final means are: " << means << endl;
  if (!is_sorted_strict(means.begin(), means.end())) { //Failure
    return ret;
  }
  ret.resize(2);
  ret[0].mu_ = means[0];
  ret[1].mu_ = means[1];
  return ret;
}
