import os
import sys
import pysam
import logging

from ..common import util
from ..toolflow import samtools
from .cloud import TenXCloud, MoleculoCloud

logger = logging.getLogger(__name__)

#=========================================================================
# cloud finder
#=========================================================================
class CloudFinder(object):
  """Base class that groups aligned reads from a *bam into read clouds.

  Subclasses are expected to set `cloud_cls` (the cloud type to
  instantiate) and to implement `__getReadBarcode__`.

  Two side lists are filled in while clouds are being iterated:
    unmappedReads_list      - reads flagged as unmapped
    invalidMappedReads_list - mapped reads whose cloud was classified
                              as invalid
  """

  def __init__(
    self,
    bam_path,
    multiMapMode=False,
    saveReads=True,
  ):
    """Record the options and look up the total read count of the *bam.

    bam_path     - path of the *bam to read
    multiMapMode - when False, secondary alignments are skipped and each
                   (qname, is_read1) pair may only be seen once
    saveReads    - when False, dropReads() is called on every cloud once
                   it has been closed
    """
    self.bam_path = bam_path
    self.saveReads = saveReads
    self.multiMapMode = multiMapMode

    # cloud type to instantiate, filled in by subclasses
    self.cloud_cls = None

    # reads which could not even be mapped
    self.unmappedReads_list = []
    # reads which were mapped but did not end up in a valid cloud
    self.invalidMappedReads_list = []

    # obtain index stats so we have total number of reads
    info_map = samtools.getIndexStats(self.bam_path)
    if info_map is not None:
      self.numReads = info_map['sum'][1] + info_map['sum'][2]
    # failed due to lack of bam index
    else:
      self.numReads = samtools.getTotalCounts(self.bam_path)

    # file iteration state
    # NOTE floor division keeps the progress step integral so that the
    # modulo test in __stepCounts__ behaves the same on python 2 and 3
    self.step = max(1, self.numReads // 10)
    self.readsIter = None
    self.readIdx = 0
    self.finished = False
    self.headRead = None
    # keep track if clouds were actually iterated over for a particular
    # barcode space
    # if not, we will have to phantom step to the next barcode space
    self.stepped = False

  def __getReadBarcode__(self, read):
    """Return the barcode of `read`; must be provided by subclasses."""
    raise NotImplementedError()

  def __reset__(self):
    """(Re)open the *bam and clear all iteration state."""
    self.readsIter = pysam.Samfile(self.bam_path, 'rb')
    self.readIdx = 0
    self.finished = False
    self.stepped = False
    self.headRead = None

  def __stepCounts__(self):
    """Count one consumed read and log progress every `step` reads."""
    self.stepped = True
    if self.readIdx % self.step == 0:
      logger.info('  - {0}/{1}'.format(self.readIdx, self.numReads))
    self.readIdx += 1

  def getCloudsByBarcode(self):
    """Yield (barcode, cloud_iter) for each run of same-barcode reads.

    Reads are grouped in file order: a group ends at the first read
    whose barcode differs.  `cloud_iter` lazily yields the valid clouds
    of that group.
    """
    self.__reset__()
    # an empty *bam has no head read; finish cleanly rather than letting
    # StopIteration escape from inside this generator
    self.headRead = next(self.readsIter, None)
    if self.headRead is None:
      self.finished = True
      self.readsIter.close()
      return

    while not self.finished:
      barcode = self.__getReadBarcode__(self.headRead)
      self.stepped = False
      # NOTE the reads of this barcode space are collected eagerly here,
      # which advances headRead/finished and calls __stepCounts__ for
      # every read pulled from the file (so stepped may already be True
      # before the consumer touches cloud_iter)
      cloud_iter = self.__getCloudIter__(
        self.__getBarcodeReadsIter__(barcode)
      )
      yield (barcode, cloud_iter)
      # step through clouds in this barcode space if not iterated thru
      if not self.stepped:
        for _ in cloud_iter:
          pass

    self.readsIter.close()

  def getClouds(self):
    """Yield every valid cloud, treating the whole *bam as one barcode."""
    self.__reset__()
    for cloud in self.__getCloudIter__(self.__getReadsIter__()):
      yield cloud

    self.readsIter.close()

  def __getReadsIter__(self):
    """Yield all reads in file order, asserting they share one barcode."""
    # ensure all reads have the same barcode
    currbarcode = None
    for read in self.readsIter:
      self.__stepCounts__()
      barcode = self.__getReadBarcode__(read)
      assert currbarcode is None or barcode == currbarcode
      currbarcode = barcode
      yield read

  def __getBarcodeReadsIter__(self, barcode):
    """Return the reads of the current barcode space sorted by (tid, pos).

    Starts from `headRead` (if any) and consumes reads until one with a
    different barcode shows up; that read becomes the new `headRead`.
    If the file runs out first, `finished` is set instead.
    """
    # collect all reads in this barcode space, sort, and yield
    pushbackRead = None
    reads_list = []
    if self.headRead is not None:
      assert self.__getReadBarcode__(self.headRead) == barcode
      reads_list.append(self.headRead)
    for read in self.readsIter:
      self.__stepCounts__()

      if self.__getReadBarcode__(read) != barcode:
        pushbackRead = read
        break

      reads_list.append(read)

    # check if we ended a barcode space and save the head read for the
    # next group
    if pushbackRead is not None:
      self.headRead = pushbackRead
    # otherwise we finished processing the whole *bam
    else:
      self.finished = True

    return sorted(reads_list, key=lambda r: (r.tid, r.pos))

  def __getCloudIter__(self, readsIter):
    """Yield the valid clouds built from consecutive reads of `readsIter`.

    Unmapped reads are appended to `unmappedReads_list` and skipped.
    Secondary alignments are skipped unless `multiMapMode` is set.  A
    read that the current cloud rejects (isMemberRead is False) closes
    that cloud and starts a new one.  Reads of clouds classified as
    invalid are appended to `invalidMappedReads_list`.
    """

    # helper to "close" a cloud once the last read has been added
    def closeCloud(cloud):

      # trim straggler reads
      # NOTE(review): whatever trimReads() returns is currently discarded
      cloud.trimReads()

      cloud.classify()

      # save all _mapped_ reads not used in any valid cloud as invalid
      # mappings
      # NOTE these are to be used in future for local realignment
      if cloud.invalid:
        self.invalidMappedReads_list.extend(
          cloud.getReads()
        )
      if not self.saveReads:
        cloud.dropReads()

    cloud = self.cloud_cls()

    seenRid_set = set()
    for read in readsIter:

      # skip unmapped
      if read.is_unmapped:
        self.unmappedReads_list.append(read)
        continue

      # skip secondary
      if not self.multiMapMode and read.is_secondary:
        continue

      # ensure only one primary alignment per read if not in multi map mode
      rid = (read.qname, read.is_read1)
      assert self.multiMapMode or rid not in seenRid_set
      seenRid_set.add(rid)

      # NOTE pysam is kind of dumb and doesn't keep the rname around for
      # each read so need to use samfile handle to obtain
      readRname = self.readsIter.getrname(read.tid)

      # FIXME possibly check for "truncated" reads if this turns out to be
      # a problem

      # add to current cloud if it passes as a member
      if cloud.isMemberRead(read):
        cloud.addAlnRead(read, readRname)
      # otherwise complete this cloud and create a new cloud to place the
      # next read
      else:
        closeCloud(cloud)

        # yield only valid clouds
        if not cloud.invalid:
          yield cloud

        cloud = self.cloud_cls()
        cloud.addAlnRead(
          read,
          readRname,
        )

    # check if tail cloud valid
    closeCloud(cloud)
    if not cloud.invalid:
      yield cloud

#=========================================================================
# Moleculo
#=========================================================================
class CloudFinderMoleculo(CloudFinder):
  """Cloud finder that builds `MoleculoCloud` objects."""

  def __init__(self, bam_path, **kwargs):
    """Set up the base finder, then select the Moleculo cloud type."""
    super(CloudFinderMoleculo, self).__init__(bam_path, **kwargs)
    # must come after the base initializer, which resets cloud_cls
    self.cloud_cls = MoleculoCloud

  def __getReadBarcode__(self, read):
    """Return None for every read.

    For now all reads in a *bam partition belong to the same well, so
    there is a single (unnamed) barcode space.
    """
    return None

  # FIXME this should really be renamed to 'end-marker' or something like
  # this, 'barcode' is overloaded now
  def isBarcode(self, read):
    """Delegate the check to the cloud class's own isBarcode."""
    cloud_type = self.cloud_cls
    return cloud_type.isBarcode(read)

#=========================================================================
# TenX
#=========================================================================
class CloudFinderTenX(CloudFinder):
  """Cloud finder that builds `TenXCloud` objects."""

  def __init__(
    self,
    bam_path,
    **kwargs
  ):
    """Set up the base finder, then select the TenX cloud type."""
    super(CloudFinderTenX, self).__init__(bam_path, **kwargs)

    self.cloud_cls = TenXCloud

  def __getReadBarcode__(self, read):
    """Return the barcode of `read`, or None if it carries none.

    The value of the first 'BX' tag wins; failing that, the text before
    the first '$' in the qname is used when the qname contains a '$'.
    """
    # check attached as a tag
    # NOTE a plain loop replaces the tuple-unpacking lambda and the
    # filter()-to-list comparison, neither of which work on python 3
    for key, value in read.tags:
      if key == 'BX':
        return value
    # otherwise check if encoded in the qname
    words = read.qname.split('$')
    return None if len(words) == 1 else words[0]

  # 10x does not have barcodes
  def isBarcode(self, read):
    """Always False: no read is treated as a barcode here."""
    return False

