# python libraries
import os
import sys
import getopt
import shutil
import numpy as np
import random
import logging
from collections import Counter, defaultdict
from itertools import izip
from Bio.Seq import Seq
from pyfaidx import Fasta
import logging

from batch import Worker
from ..common import util
from ..common import plotter
from ..sim import cookiesim

logger = logging.getLogger(__name__)

class WellSim(Worker):

  def run(self):
    """Simulate all clouds of one well and write reads plus ground truth.

    Reads ``laneID``, ``wellID`` and ``barcodeStart`` from
    ``self.config.inputs`` and the simulation parameters from
    ``self.config.simParams``.  For every integer barcode in
    ``[barcodeStart, barcodeStart + numBarcodesPerWell)`` clouds are drawn
    from ``self.sampleStencils``, each cloud is placed on a contig picked by
    ``util.chooseWeightedDict`` from the contig-length map, and its reads are
    rendered by ``cookiesim.getReadsText``.

    Files written under ``self.config.simWells_path``:
      - ``<laneID>_<wellID>_1.fq`` and ``<laneID>_<wellID>_2.fq``: read text
      - ``<laneID>_<wellID>_true.bed``: one line per cloud with contig, start,
        end and ``barcode;lpart(y/n);rpart(y/n);2*len(reads)``

    Raises:
      ValueError: no ``genomeCovPerWell`` override is configured and no
        ``*rawstats.p`` file exists in ``self.config.wellStats_path``.
      AssertionError: a chosen contig is too short to hold the cloud.
    """
    inputs_map = self.config.inputs
    laneID = inputs_map['laneID']
    wellID = inputs_map['wellID']
    barcodeStart = inputs_map['barcodeStart']

    # sim params
    simParams = self.config.simParams
    # not used below; the lookup is kept so a config without this key still
    # fails here exactly as before
    numWells = simParams['numWells']
    numBarcodesPerWell = simParams['numBarcodesPerWell']

    # optional overrides, None when absent
    def optionalParam(key):
      return simParams[key] if key in simParams else None

    genomeCovPerWell = optionalParam('genomeCovPerWell')
    fragDistCSV_path = optionalParam('fragDistCSV_path')
    fragSRCovScale = optionalParam('fragSRCovScale')

    # setup to inject errors
    cookiesim.setup(True)

    # load reference
    #----------------------
    refFasta_path = self.config.get('referenceFasta_path')
    ref_map = Fasta(refFasta_path)

    # filter random contigs (names containing '_') and chrM
    contigs = [
      k for k in ref_map.keys()
      if '_' not in k and k != 'chrM'
    ]

    # build length map
    len_map = dict((k, len(ref_map[k])) for k in contigs)
    genomeSize = sum(len_map.values())

    # determine the number of bp to simulate if override isn't specified
    #----------------------
    if genomeCovPerWell is None:
      wellStats_path = self.config.wellStats_path
      statsFnames_list = [
        f for f in os.listdir(wellStats_path)
        if f.endswith('rawstats.p')
      ]
      if not statsFnames_list:
        raise ValueError(
          'no *rawstats.p files found in {0}'.format(wellStats_path))
      # use up to 3 randomly chosen stats files
      statsFnames_list = random.sample(
        statsFnames_list,
        min(3, len(statsFnames_list)),
      )
      lrBases_list = []
      for fname in statsFnames_list:
        path = os.path.join(wellStats_path, fname)
        # one 'lr-bases' entry per value of the pickled mapping
        for stats_map in util.loadPickle(path).values():
          lrBases_list.append(stats_map['lr-bases'])

      minNumBp = np.mean(lrBases_list)
    else:
      # NOTE(review): minNumBp is handed to sampleStencils once per barcode
      # below, while this value is derived from a per-well coverage -- confirm
      # whether it should be divided by numBarcodesPerWell
      minNumBp = genomeCovPerWell * genomeSize

    simWells_path = self.config.simWells_path
    prefix = '{0}_{1:03d}'.format(laneID, wellID)
    fq1_path = os.path.join(simWells_path, prefix + '_1.fq')
    fq2_path = os.path.join(simWells_path, prefix + '_2.fq')
    bed_path = os.path.join(simWells_path, prefix + '_true.bed')

    def pickCloud(
      chrName,
      size,
    ):
      # legal start region needs to be able to fit entire cloud
      validStartMax = len_map[chrName] - size - 20
      assert validStartMax > 0
      # NOTE for now sample cloud starts uniformly at random from the region;
      # redraw while the window holds more than 100 'n' bases
      refSeq = 'n' * 1000
      while refSeq.lower().count('n') > 100:
        cloudStart = random.randrange(validStartMax)
        cloudEnd = cloudStart + size
        refSeq = Seq(str(ref_map[chrName][cloudStart:cloudEnd]))
      return (cloudStart, refSeq)

    logger.info('generating reads')
    logger.info(' - num barcodes per well: {0}'.format(numBarcodesPerWell))
    # progress is logged roughly 20 times; floor division keeps step integral
    step = max(1, numBarcodesPerWell // 20)
    numBp = 0
    numClouds = 0

    # get distribution of long fragments and stencils binned according to
    # the distribution
    (
      fragDist,
      binStencils_map,
      empSRScale,
    ) = self.getCloudDistribution(
      fragDistCSV_path,
      fragSRCovScale,
    )

    with open(fq1_path, 'w') as fq1_f, \
      open(fq2_path, 'w') as fq2_f, \
      open(bed_path, 'w') as bed_f:

      # barcode is just an integer index in simulation as opposed to a
      # sequence
      for i in range(numBarcodesPerWell):
        barcode = barcodeStart + i
        if i % step == 0:
          logger.info('  - {0}/{1} barcodes'.format(i, numBarcodesPerWell))
        for cinfo_map in self.sampleStencils(
          minNumBp,
          fragDist,
          binStencils_map,
          empSRScale,
        ):

          # draw chromosome
          (chrName, _) = util.chooseWeightedDict(len_map)

          lpart = cinfo_map['lpart']
          rpart = cinfo_map['rpart']
          size = cinfo_map['size']
          readStencils_list = cinfo_map['reads']

          # pick start of cloud in chromosome
          (cloudStart, refSeq) = pickCloud(chrName, size)

          numClouds += 1
          numBp += size

          # generate reads
          (fragFq1_txt, fragFq2_txt) = cookiesim.getReadsText(
            laneID,
            wellID,
            barcode,
            refSeq,
            chrName,
            cloudStart,
            readStencils_list,
          )
          fq1_f.write(fragFq1_txt)
          fq2_f.write(fragFq2_txt)

          # ground truth record for this cloud
          info = '{0};{1};{2};{3}'.format(
            barcode,
            'y' if lpart else 'n',
            'y' if rpart else 'n',
            2 * len(readStencils_list),
          )
          bed_f.write('{0}\t{1}\t{2}\t{3}\n'.format(
            chrName,
            cloudStart,
            cloudStart + size,
            info,
          ))

    logger.info('simulated {0} clouds with covering {1} bp'.format(
      numClouds,
      numBp,
    ))

    return

  def getCloudDistribution(
    self,
    fragDistCSV_path,
    fragSRCovScale,
  ):
    """Build the target fragment-size distribution and bin stencils by size.

    Args:
      fragDistCSV_path: path of a CSV with ``bin,freq`` lines describing the
        desired fragment-size distribution, or None to use a single bin at 0.
      fragSRCovScale: oversampling factor (>= 1, truncated to int), or None
        for no oversampling.

    Returns:
      Tuple ``(fragDist, binStencils_map, empSRScale)``:
        - fragDist: dict mapping bin -> normalized frequency
        - binStencils_map: defaultdict(list) mapping bin -> stencils loaded
          from ``*.stencil.p`` files in ``self.config.wellStats_path``
        - empSRScale: ``int(fragSRCovScale)``, or ``1.`` when it is None

    Raises:
      ValueError: no ``*.stats.p`` file exists in the well stats directory.
      AssertionError: malformed CSV line, ``fragSRCovScale < 1``, or an
        internal consistency check failed.
    """
    def normalize(x):
      # scale the values of x so that they sum to 1
      total = sum(x.values(), 0.0)
      return dict((key, x[key] / total) for key in x)

    useCSV = fragDistCSV_path is not None
    wellStats_path = self.config.wellStats_path

    # load distribution csv if specified
    #----------------------
    if useCSV:
      logger.info('using empirical dist specified by: {0}'.format(
        fragDistCSV_path))
      # bin 0 is always present so every size falls into some bin
      fragDist = [(0,0.0)]
      with open(fragDistCSV_path) as f:
        for line in f:
          words = line.split(',')
          assert len(words) == 2
          fragDist.append((int(words[0]), float(words[1])))
      fragDist = normalize(dict(sorted(fragDist)))

    # otherwise just create default bin
    else:
      fragDist = {
        0: 1.,
      }

    bins = sorted(fragDist.keys())

    # determine number of stencil files to load
    #----------------------
    logger.info('determine empirical distributions and the number of stencil files to load')
    # load some stats files to build empirical distributions
    statsFnames_list = [
      f for f in os.listdir(wellStats_path)
      if f.endswith('.stats.p')
    ]
    if not statsFnames_list:
      raise ValueError(
        'no *.stats.p files found in {0}'.format(wellStats_path))
    # use up to 10 randomly chosen stats files
    statsFnames_list = random.sample(
      statsFnames_list,
      min(10, len(statsFnames_list)),
    )
    sizes = []
    empSRCovs = []
    for fname in statsFnames_list:
      path = os.path.join(wellStats_path, fname)
      cloudFeats_list = util.loadPickle(path)
      sizes.extend(cf.size for cf in cloudFeats_list)
      # short read coverage is only estimated from clouds larger than 10kb
      # FIXME hardcoded short read length (88)
      empSRCovs.extend(
        1. * cf.numReads * 88 / cf.size
        for cf in cloudFeats_list
        if cf.size > 10000
      )

    # FIXME hardcoded max fragment size of 10mbp to make np.histogram work
    # it needs an end bin
    hbins = list(bins) + [10000000]
    freq, _ = np.histogram(sizes, hbins)
    # zip stops at the shorter sequence, so keys are exactly the entries of bins
    empFragDist = normalize(dict(zip(hbins, freq)))
    assert len(empFragDist) == len(fragDist)
    empSRCov = np.mean(empSRCovs)

    # determine how much need to scale bins to properly populate
    if useCSV:
      binScale_map = {}
      for sizeBin in bins:
        empEmpty = (empFragDist[sizeBin] == 0)
        if empEmpty:
          logger.warning('empty size bin empirically: {0}'.format(sizeBin))
        if fragDist[sizeBin] == 0:
          logger.warning('empty size bin specified by fragdist csv: {0}'.format(sizeBin))
        if empEmpty:
          # the ratio below is undefined (division by zero); use a neutral
          # scale so this bin neither inflates the number of files to load
          # nor causes stencils landing in it to be dropped
          binScale_map[sizeBin] = 1.
          continue

        scale = 1. * fragDist[sizeBin] / empFragDist[sizeBin]
        binScale_map[sizeBin] = scale
        if scale < 0.5:
          logger.info('size bin {0} is overpopulated empirically'.format(sizeBin))
          logger.info('  - will dump {0} of the stencils in this bin'.format(
            1 - 2 * scale))
    else:
      binScale_map = {0: 1.}

    # compute minimum number of stencil files to open
    # NOTE always load 2x so more than enough
    if fragSRCovScale is not None:
      assert fragSRCovScale >= 1
      empSRScale = int(fragSRCovScale)
      logger.info('oversampling {0} stencils to yield {1} short read coverage of long fragment'.format(
        empSRScale,
        empSRScale * empSRCov,
      ))
    else:
      empSRScale = 1.
    numStencilFiles = int(
      2 *
      max(binScale_map.values()) *
      empSRScale
    )
    logger.info('need to load {0} stencil files to populate'.format(numStencilFiles))

    # load stencil files and place into bins
    #----------------------
    stencilsFnames_list = [
      f for f in os.listdir(wellStats_path)
      if f.endswith('.stencil.p')
    ]
    if numStencilFiles > len(stencilsFnames_list):
      # random.sample cannot draw more files than exist; fall back to all
      logger.warning('only {0} stencil files available, wanted {1}'.format(
        len(stencilsFnames_list),
        numStencilFiles,
      ))
      numStencilFiles = len(stencilsFnames_list)
    stencilsFnames_list = random.sample(stencilsFnames_list, numStencilFiles)
    binStencils_map = defaultdict(list)
    for fname in stencilsFnames_list:
      logger.info('  - loading {0}'.format(fname))
      path = os.path.join(wellStats_path, fname)
      cloudStencils_list = util.loadPickle(path)
      # skip empty stencil files
      if len(cloudStencils_list) == 0:
        continue
      # digitize all stencils into bins
      sizes_list = [cinfo_map['size'] for cinfo_map in cloudStencils_list]
      placements_list = np.digitize(sizes_list, bins)
      for (binidx, cinfo_map) in zip(placements_list, cloudStencils_list):
        # np.digitize returns the index of the right edge, hence the -1
        cbin = bins[binidx - 1]
        scale = binScale_map[cbin]

        # throw away stencils at random if the bin is very abundant
        if scale < 0.5 and random.random() < 1 - scale * 2:
          continue

        binStencils_map[cbin].append(cinfo_map)

    # assert all placements are in known bins
    assert set(binStencils_map.keys()).issubset(bins)

    logger.info('  - done')

    return (
      fragDist,
      binStencils_map,
      empSRScale,
    )
  
  def sampleStencils(
    self,
    minNumBp,
    fragDist,
    binStencils_map,
    empSRScale,
  ):
    """Yield cloud stencils until their sizes add up to at least minNumBp.

    Args:
      minNumBp: total cloud size (sum of ``'size'``) to reach before stopping.
      fragDist: dict of size bin -> weight, passed to
        ``util.chooseWeightedDict`` to draw a bin for each cloud.
      binStencils_map: mapping of size bin -> list of stencil dicts, each with
        at least the keys ``'size'`` and ``'reads'``.
      empSRScale: when > 1, that many stencils are drawn from the bin and
        their reads are stacked into a single cloud.

    Yields:
      One stencil dict per cloud.  For stacked clouds this is a shallow copy
      of the largest drawn stencil whose ``'reads'`` is the concatenation of
      the reads of all drawn stencils; the originals are not modified.

    Raises:
      ValueError: the drawn bin holds fewer than ``empSRScale`` stencils
        (from ``random.sample``).
      IndexError: the drawn bin holds no stencils (from ``random.choice``).
    """
    # draw clouds until satisfy total long fragment coverage constraint
    numBp = 0
    while numBp < minNumBp:
      # draw a size bin
      (sizeBin, _) = util.chooseWeightedDict(fragDist)
      candidates = binStencils_map[sizeBin]

      # oversample if needed
      if empSRScale > 1:
        # sample a number of stencils and slap on top of each other
        stacked = random.sample(candidates, empSRScale)
        reads = [r for cinfo_map in stacked for r in cinfo_map['reads']]
        # choose the largest one to get edge effects of barcode
        largest = max(stacked, key=lambda cinfo_map: cinfo_map['size'])
        numBp += largest['size']
        # careful to copy values and not clobber original
        merged = dict(largest)
        merged['reads'] = reads
        yield merged

      else:
        cinfo_map = random.choice(candidates)
        numBp += cinfo_map['size']
        yield cinfo_map

    # falling off the end finishes the generator; an explicit
    # `raise StopIteration` here turns into RuntimeError under PEP 479
  
