import os
import random
import numpy as np
from itertools import izip

from Bio.Seq import Seq
from ..common import util

# global counter which serves as a unique identifier embedded in read names;
# getReadsText advances it by 2 per read pair
readCount = 0
# running error-injection statistics, updated by injectSubErrors:
#   errorReadCount - reads processed by error injection
#   errorCount     - substitution errors injected
#   posErrorCount  - sum of the read positions of those errors
errorReadCount = 0
errorCount = 0
posErrorCount = 0

# running sums of each processed read's maximum and mean base error
# probability
maxErrorProbSum = 0
avgErrorProbSum = 0

# FASTQ record template: {0} read name, {1} sequence, {2} quality string
fastqTmplTxt = \
'''@{0}
{1}
+
{2}
'''

# substitution targets: for each lower-case base, the bases it may be
# replaced by when an error is injected ('n' is never substituted)
leftBases = {
  'a' : 'cgt',
  'c' : 'agt',
  'g' : 'cat',
  't' : 'cga',
  'n' : 'n',
}
# map every valid Phred score to its error probability 10^(-Q/10).
# Sanger/Phred+33 scores run from 0 to maxPhred INCLUSIVE (Q=93 is the
# quality character '~'), so the range must include maxPhred itself.
maxPhred = 93
errorProb_map = {Q : 10 ** (-1.*Q/10) for Q in range(maxPhred + 1)}

#-------------------------------------------------------------------------
# generate base qualities + substitution errors for a given read
#-------------------------------------------------------------------------
# configured by setup(): whether qualities are sampled / errors injected,
# and the empirical quality distribution used for sampling
injectErrors = False
empDist = None

def setup(_injectErrors):
  global injectErrors
  global empDist
  injectErrors = _injectErrors
  if injectErrors:
    root_path = os.path.dirname(os.path.abspath(__file__))
    empDist_path = os.path.realpath(os.path.join(
      root_path,
      'empDist.p'
    ))
    assert os.path.isfile(empDist_path)
    empDist = util.loadPickle(empDist_path)

    #for (baseidx, baseQuality_map) in sorted(empDist.items()):
    #  print 'cycle ', baseidx
    #  for (base, Qcounter) in sorted(baseQuality_map.items()):
    #    tot = sum(Qcounter.values())
    #    avgQ = 0
    #    avgError = 0
    #    for (Q, cnt) in Qcounter.items():
    #      avgQ += 1. * (Q - 33) * cnt / tot
    #      avgError += 1. * errorProb_map[Q-33] * cnt/tot
    #    print '  - {0}: {1} {2}'.format(base, avgQ, avgError)
    #die

def getBaseQualities(readSeq):
  """Return a quality string with one character per base of readSeq.

  With error injection disabled every base gets the constant quality
  character '2'. Otherwise each base's quality code is drawn from empDist,
  indexed by read cycle (the base's position in the read) and the
  lower-cased base, and converted to a character with chr().
  """
  if injectErrors:
    qualChars = []
    for (cycle, base) in enumerate(readSeq):
      (qualCode, _) = util.chooseWeightedDict(empDist[cycle][base.lower()])
      qualChars.append(chr(qualCode))
    return ''.join(qualChars)
  return '2' * len(readSeq)

def injectSubErrors(readSeq, readQual):
  """Randomly substitute bases in a read according to its base qualities.

  Each base (except 'n') is replaced by one of the other bases listed in
  leftBases with probability errorProb_map[Q], where Q is the Phred score
  decoded from the matching quality character (ord(char) - 33).

  Args:
    readSeq: read sequence.
    readQual: quality string, one character per base of readSeq. Extra
      characters in the longer of the two are ignored.

  Returns:
    (newReadSeq, flips): the read after substitution, and a list of
    (position, original base, new base) tuples, one per substitution.
    With injection disabled the read is returned unchanged and flips is [].

  Side effects:
    Updates the module-level errorCount, posErrorCount, maxErrorProbSum,
    avgErrorProbSum and errorReadCount. Empty reads leave them untouched.

  NOTE(review): with injection enabled the returned sequence is lower-case
  (bases are lower-cased before lookup and emitted that way). The disabled
  path preserves the input case. Kept as-is; confirm this is intended.
  """
  if not injectErrors:
    return (readSeq, [])

  global errorCount
  global posErrorCount
  global maxErrorProbSum
  global avgErrorProbSum
  global errorReadCount

  newBases = []
  flips = []
  errorProbs = []
  for (baseidx, (origBase, qualChar)) in enumerate(zip(readSeq, readQual)):
    base = origBase.lower()
    errorProb = errorProb_map[ord(qualChar) - 33]
    errorProbs.append(errorProb)
    if base != 'n' and random.random() <= errorProb:
      newBase = random.choice(leftBases[base])
      flips.append((baseidx, origBase, newBase))
      posErrorCount += baseidx
      errorCount += 1
    else:
      newBase = base
    newBases.append(newBase)

  # Nothing was processed (empty read or quality string). Bail out before
  # max() raises ValueError on an empty list, and before np.mean() adds NaN
  # to the running sums.
  if not errorProbs:
    return ('', [])

  maxErrorProbSum += max(errorProbs)
  avgErrorProbSum += np.mean(errorProbs)
  errorReadCount += 1

  return (''.join(newBases), flips)

#-------------------------------------------------------------------------
# generate reads text for this cloud position and stencil
#-------------------------------------------------------------------------
def getReadsText(
  laneID,
  wellID,
  barcode,
  refSeq,
  chrName,
  cloudStart,
  readStencils,
  readLength=100,
):
  """Build FASTQ text for every read pair of one cloud.

  Args:
    laneID, wellID, barcode, chrName: copied into each read name.
    refSeq: sequence the reads are sliced from. Reverse reads call its
      reverse_complement() method, so it is presumably a Bio.Seq.Seq
      (TODO confirm).
    cloudStart: offset added to stencil positions so the read name encodes
      the read's true location.
    readStencils: iterable of (r1pos, r1isRev, r2pos, r2isRev, isBarcode)
      tuples; positions are offsets into refSeq.
    readLength: bases taken from refSeq per read. This was previously
      hard-coded to 100, which remains the default.

  Returns:
    (reads1Text, reads2Text): concatenated read-1 and read-2 FASTQ records
    in stencil order. The i-th records of both strings share a read name.

  Side effects:
    Advances the module-level readCount by 2 per stencil. Quality sampling
    and error injection may consume random numbers and update the
    module-level error statistics.
  """
  global readCount

  def getQname(r1pos, r2pos, isBarcode):
    # bump to genome wide position coordinates so can encode true location
    # of read
    r1pos += cloudStart
    # set filler read 2 position to be 0 if unpaired
    r2pos = 0 if r2pos is None else r2pos + cloudStart

    return '{0}${1}-{2}-{3}_{4}-{5}-{6}_{7}'.format(
      barcode,
      readCount,
      laneID,
      wellID,
      chrName,
      r1pos,
      r2pos,
      'bar' if isBarcode else 'r',
    )

  def simulateRead(pos, isRev):
    # Slice and orient the reference, then assign qualities and inject
    # substitution errors. The list of flipped bases is not needed here.
    preSeq = refSeq[pos:pos + readLength]
    seq = str(preSeq.reverse_complement()) if isRev else str(preSeq)
    qual = getBaseQualities(seq)
    (seq, _) = injectSubErrors(seq, qual)
    return (seq, qual)

  reads1 = []
  reads2 = []
  for (r1pos, r1isRev, r2pos, r2isRev, isBarcode) in readStencils:
    qname = getQname(r1pos, r2pos, isBarcode)
    # NOTE(review): getQname accepts r2pos=None ("unpaired"), but
    # simulateRead computes `pos + readLength`, which raises TypeError for
    # None. Unpaired stencils are therefore not actually supported here;
    # confirm the intended filler read 2.
    (read1Seq, read1Qual) = simulateRead(r1pos, r1isRev)
    (read2Seq, read2Qual) = simulateRead(r2pos, r2isRev)
    reads1.append(fastqTmplTxt.format(qname, read1Seq, read1Qual))
    reads2.append(fastqTmplTxt.format(qname, read2Seq, read2Qual))
    readCount += 2

  return (''.join(reads1), ''.join(reads2))


