import os
import sys
import pysam
import logging

from ..common import util


class Cloud(object):
  """Base class for a group ("cloud") of aligned reads on a single contig.

  Reads are appended one at a time with addAlnRead(), which expects them in
  coordinate-sorted order.  Subclasses supply the membership rule
  (isMemberRead) and the numeric thresholds declared below.
  """

  # cloud rules (abstract; subclasses are expected to override)
  @property
  def MIN_NUM_READS(self): raise NotImplementedError()
  @property
  def MAX_READ_GAP(self): raise NotImplementedError()
  @property
  def MIN_DENSITY(self): raise NotImplementedError()

  @property
  def numreads(self): return self._numreads
  @property
  def readgap(self): return self._readgap

  def __init__(self):
    # cloud position
    self.chrName = None
    self.minpos = None
    self.maxpos = None

    # end marker positioning
    self.mixed = False
    self.lpart = False
    self.rpart = False

    # list of all member reads
    # NOTE obtained by iterating over coordinate sorted *bam
    self.reads_list = []

    # maintain set of unique rids, i.e. (qname, is_read1) pairs, and qnames
    self.rid_set = set()
    self.qname_set = set()

    # whether or not member reads constitute a valid cloud
    # (None until classify() has been called)
    self.invalid = None

    self._numreads = 0
    self._readgap = 0

  def classify(self):
    """Decide whether the member reads constitute a valid cloud.

    Sets self.invalid to True when fewer than MIN_NUM_READS unique
    (qname, is_read1) ids are present.  Otherwise sets it to False and
    recomputes the read gap as the largest `r2.pos - r1.aend` over the
    pairs yielded by util.pairwise(self.reads_list).
    """
    self._numreads = len(self.rid_set)
    # require minimum number of reads
    if self.numreads < self.MIN_NUM_READS:
      self.invalid = True
      return

    self.invalid = False

    # determine max read gap
    self._readgap = 0
    for (r1, r2) in util.pairwise(self.reads_list):
      assert r1.pos <= r2.pos
      gap = r2.pos - r1.aend
      if gap > self._readgap:
        self._readgap = gap

  def isMemberRead(
    self,
    read,
  ):
    """Return whether *read* may join this cloud (abstract)."""
    raise NotImplementedError()

  # FIXME this is needed only for moleculo, but it shouldn't be abstract
  def trimReads(self):
    """Trim nothing; return empty (left, right) lists of trimmed reads."""
    return ([], [])

  def addAlnRead(
    self,
    read,
    readRname,
  ):
    """Append an aligned read to the cloud and update running statistics.

    read      -- read object exposing qname, is_read1, pos and aend
    readRname -- contig name recorded as self.chrName

    Reads must arrive in coordinate-sorted order; this is asserted against
    the cloud's minimum position.
    """
    # update running read gap
    # NOTE this running value is start-to-start (pos - previous pos), whereas
    # classify() later recomputes it as pos - previous aend
    if self.reads_list:
      lastRead = self.reads_list[-1]
      self._readgap = max(
        self._readgap,
        read.pos - lastRead.pos,
      )

    rid = (read.qname, read.is_read1)
    self.reads_list.append(read)
    self.rid_set.add(rid)
    self.qname_set.add(read.qname)
    # NOTE counts every added read; classify() resets this to the number of
    # unique rids
    self._numreads += 1

    # NOTE all reads must belong to same chromosome and if the read is
    # being added we can assume this is the case
    self.chrName = readRname

    # first read
    if self.minpos is None:
      self.minpos = read.pos
      self.maxpos = read.aend

    # update read
    else:
      # check going in coordinate sorted order
      assert read.pos >= self.minpos
      # a later-starting read may still end before an earlier one, so only
      # ever extend the right edge (matches how filtQnames recomputes it)
      self.maxpos = max(self.maxpos, read.aend)

  def dropReads(self):
    """Delete the read list and id sets to release the member reads."""
    del self.reads_list
    del self.rid_set
    del self.qname_set

  def getReads(self):
    """Return an iterator over member reads; classify() must have run."""
    # require classify() to be called
    assert self.invalid is not None
    return iter(self.reads_list)

  def getSize(self):
    """Return maxpos - minpos; at least one read must have been added."""
    assert self.maxpos is not None and self.minpos is not None
    return self.maxpos - self.minpos

  def filtQnames(self, qnames):
    """Keep only reads whose qname is in *qnames*, then reclassify.

    The id sets and the cloud boundaries are rebuilt from the surviving
    reads.  If no read survives, the boundaries are reset to None and
    classify() is left to mark the cloud accordingly.
    """
    keep_set = set(qnames)
    # list comprehension keeps reads_list a real list (indexable, reusable)
    self.reads_list = [r for r in self.reads_list if r.qname in keep_set]
    self.rid_set = set((r.qname, r.is_read1) for r in self.reads_list)
    self.qname_set = set(r.qname for r in self.reads_list)
    if self.reads_list:
      self.minpos = min(r.pos for r in self.reads_list)
      self.maxpos = max(r.aend for r in self.reads_list)
    else:
      self.minpos = None
      self.maxpos = None
    self.classify()

#=========================================================================
# 10x cloud
#=========================================================================
class TenXCloud(Cloud):
  """Cloud whose membership rule is a maximum start-to-start read gap.

  MIN_DENSITY is left unimplemented; the membership test does not use it.
  """

  @property
  def MIN_NUM_READS(self): return 6
  @property
  def MAX_READ_GAP(self): return 12000
  # NOTE a looser gap of 30000 was previously tried here
  @property
  def MIN_DENSITY(self): raise NotImplementedError()

  def __init__(self):
    super(TenXCloud, self).__init__()

  def isMemberRead(
    self,
    read,
  ):
    """Return True if *read* may be appended to this cloud.

    The first read is always accepted.  Later reads must share the tid of
    the first member read and start less than MAX_READ_GAP after the start
    of the most recently added read.
    """
    # singleton reads are by default members
    if not self.reads_list:
      return True

    # require same contig
    if read.tid != self.reads_list[0].tid:
      return False

    lastRead = self.reads_list[-1]
    propReadgap = read.pos - lastRead.pos
    return propReadgap < self.MAX_READ_GAP

#=========================================================================
# moleculo cloud 
#=========================================================================
class MoleculoCloud(Cloud):
  """Cloud with gap and density membership rules plus barcode-read handling.

  Barcode reads are recognised by their qname (see isBarcode) and are used
  to split adjacent clouds, to record where barcodes sit relative to the
  cloud ends, and to trim reads lying beyond a barcode.
  """

  # cloud rules
  @property
  def MIN_NUM_READS(self): return 7
  @property
  def MAX_READ_GAP(self):
    # FIXME remove only for stencil generation
    #return 2500
    return 3500
  @property
  def MIN_DENSITY(self):
    # FIXME remove only for stencil generation
    #return 1. / 1000
    return 1. / 1500
  @property
  def MAX_BARCODE_GAP(self): return 50

  @staticmethod
  def isBarcode(read):
    """Return True if the read's qname marks it as a barcode read."""
    return (
      'STRIPPED' in read.qname or
      read.qname.endswith('_bar')
    )

  def __init__(self):

    super(MoleculoCloud, self).__init__()

    # cloud barcode information (populated when classify() is called)
    self.mixed = None
    self.lpart = None
    self.rpart = None

    # number of member reads trimmed using barcode information
    self.numTrimmed = None

  def isMemberRead(
    self,
    read,
  ):
    """Return True if *read* may be appended to this cloud.

    The first read is always accepted.  Later reads must share the tid of
    the first member read, keep the read density (reads per base since
    minpos) at or above MIN_DENSITY, and start less than MAX_READ_GAP after
    the previous read's start.  A read is still rejected when it and the
    previous read are differently named barcode reads separated by the
    largest gap seen so far (see twoCloudSplit below).
    """
    # singleton reads are by default members
    propNumReads = 1 + len(self.reads_list)
    if propNumReads == 1:
      return True

    # require same contig
    if read.tid != self.reads_list[0].tid:
      return False

    lastRead = self.reads_list[-1]
    propReadgap = read.pos - lastRead.pos
    # clamp the span to 1 to avoid dividing by zero
    propT = max(1, read.pos - self.minpos)
    propDensity = 1. * propNumReads / propT

    # close if largest gap seen is between two different barcodes
    twoCloudSplit = (
      read.qname != lastRead.qname and
      MoleculoCloud.isBarcode(read) and
      MoleculoCloud.isBarcode(lastRead) and
      propReadgap > self.readgap and
      read.qname not in self.qname_set and
      propReadgap > 300
    )
    isMember = (
      propDensity >= self.MIN_DENSITY and
      propReadgap < self.MAX_READ_GAP
    )

    return isMember and not twoCloudSplit

  def classify(self):
    """Run the base classification, then record barcode placement flags.

    lpart / rpart are True when some barcode read lies within
    MAX_BARCODE_GAP of the left / right cloud boundary; mixed is True when
    some barcode qname has no read near either boundary.
    """
    super(MoleculoCloud, self).classify()

    (
      lbQname_set,
      rbQname_set,
      mixedQname_set,
    ) = self.__getBarcodePlacement__()

    self.mixed = len(mixedQname_set) > 0
    self.lpart = len(lbQname_set) > 0
    self.rpart = len(rbQname_set) > 0

  def __getBarcodePlacement__(self):
    """Partition barcode qnames by where their reads sit in the cloud.

    Returns a tuple (left, right, mixed) of qname sets: qnames with a
    barcode read near the left boundary, near the right boundary, and
    barcode qnames appearing in neither of those two sets.
    """
    # define metric for read being close enough to the end of a cloud
    def atLeftEnd(r):
      return abs(r.pos - self.minpos) <= self.MAX_BARCODE_GAP
    def atRightEnd(r):
      return abs(r.aend - self.maxpos) <= self.MAX_BARCODE_GAP

    # barcode reads, in the same order as reads_list
    barcodeReads_list = [
      r for r in self.reads_list if MoleculoCloud.isBarcode(r)
    ]
    allbQname_set = set(r.qname for r in barcodeReads_list)

    # separate into {left, right, mixed} barcode reads
    lbQname_set = set(r.qname for r in barcodeReads_list if atLeftEnd(r))
    rbQname_set = set(r.qname for r in barcodeReads_list if atRightEnd(r))
    mixedQname_set = allbQname_set - (lbQname_set | rbQname_set)

    return (
      lbQname_set,
      rbQname_set,
      mixedQname_set,
    )

  def trimReads(self):
    """Trim reads at either end of the cloud using barcode reads.

    Trimming is attempted only when some barcode qname is "mixed" (see
    __getBarcodePlacement__).  Removed reads are dropped from reads_list
    and their ids from rid_set, numTrimmed is set to the number removed,
    and minpos / maxpos are recomputed when anything was removed.

    Returns a tuple (left-trimmed reads, right-trimmed reads).
    """
    self.numTrimmed = 0
    (
      lbQname_set,
      rbQname_set,
      mixedQname_set,
    ) = self.__getBarcodePlacement__()

    # only trim if some barcodes are questionably mixed
    if len(mixedQname_set) == 0:
      return ([], [])

    def trim(s, e):
      # remove reads_list[s:e] along with their rids; return removed reads
      trimReads_list = self.reads_list[s:e]
      self.rid_set -= set(
        (r.qname, r.is_read1) for r in trimReads_list
      )
      del self.reads_list[s:e]
      return trimReads_list

    # trim left side: walk inward until a barcode read (or a third distinct
    # qname) is reached; cut just before a barcode read that follows a gap
    # of more than 1000
    seenQname_set = set()
    ltrimIdx = None
    for (i, (r1, r2)) in enumerate(util.pairwise(self.reads_list)):
      gap = abs(r1.pos - r2.pos)
      seenQname_set.add(r1.qname)
      if MoleculoCloud.isBarcode(r1) or len(seenQname_set) > 2:
        break
      elif MoleculoCloud.isBarcode(r2) and gap > 1000:
        ltrimIdx = i + 1

    ltrimReads_list = []
    # ltrimIdx is always >= 1 when set, so the truthiness test is safe
    if ltrimIdx:
      ltrimReads_list = trim(0, ltrimIdx)
      self.numTrimmed += len(ltrimReads_list)

    # trim right side: same walk from the right end, measured on aend
    seenQname_set = set()
    rtrimIdx = None
    for (i, (r1, r2)) in enumerate(util.pairwise(reversed(self.reads_list))):
      gap = abs(r1.aend - r2.aend)
      seenQname_set.add(r1.qname)
      if MoleculoCloud.isBarcode(r1) or len(seenQname_set) > 2:
        break
      elif MoleculoCloud.isBarcode(r2) and gap > 1000:
        # NOTE(review): this index also removes the barcode read r2, whereas
        # the left-side trim keeps it -- confirm whether
        # len(self.reads_list) - 1 - i was intended
        rtrimIdx = len(self.reads_list) - 2 - i

    rtrimReads_list = []
    # an index of 0 would remove every read, so it is treated as "no trim"
    if rtrimIdx:
      rtrimReads_list = trim(rtrimIdx, len(self.reads_list))
      # count the reads removed on the right (previously re-added the left
      # count by mistake)
      self.numTrimmed += len(rtrimReads_list)

    if ltrimIdx or rtrimIdx:
      self.minpos = min(r.pos for r in self.reads_list)
      self.maxpos = max(r.aend for r in self.reads_list)

    return (ltrimReads_list, rtrimReads_list)


