import pysam
import numpy as np
from itertools import izip, ifilter
from collections import Counter, defaultdict, namedtuple
from bx.intervals.intersection import IntervalTree

from .candread import CandAlignment

#=========================================================================
# cloud features tuple
#=========================================================================
# Feature summary of a single candidate cloud (built by CandCloud.getFeats):
#   size         end - begin, or None when the cloud has no active reads
#   begin, end   leftmost read pos / rightmost read aend
#   numReads     numFixed + numCand
#   numFixed     number of fixed reads
#   numCand      number of active multimapped (candidate) reads
#   lbarValid    a barcode read is present at the left end
#   rbarValid    a barcode read is present at the right end
#   numMixedBar  number of barcode read mappings assigned to neither end
#   isActive     whether the cloud is considered active
#   edits_map    <pos> : Counter of 'ref' and alt-base (acgt) read counts
CloudFeat_t = namedtuple(
  'CloudFeat_t',
  'size begin end '
  'numReads numFixed numCand '
  'lbarValid rbarValid numMixedBar '
  'isActive edits_map'
)

def getFeatsStr(cloudFeats):
  """Return a compact one-line description of a cloud feature tuple.

  The barcode tag is 'bothbar' when barcode reads sit at both cloud ends,
  'pbar' when exactly one end has them, and 'nobar' otherwise.
  """
  hasLeft = cloudFeats.lbarValid
  hasRight = cloudFeats.rbarValid
  if not (hasLeft and hasRight):
    btype = 'pbar' if hasLeft ^ hasRight else 'nobar'
  else:
    btype = 'bothbar'

  fields = (cloudFeats.size, cloudFeats.numReads, btype, cloudFeats.isActive)
  return '(s:{0}, n:{1}, {2}, act:{3})'.format(*fields)

#=========================================================================
# candidate cloud
#=========================================================================
class CandCloud(object):
  """A candidate cloud of read alignments on a single chromosome.

  The cloud holds *fixed* mappings, which are always part of it, and
  *candidate* mappings of multimapped reads, which count only while the
  read's current assignment selects them. Most methods take that
  assignment as ``mappings``, a ``rid -> mapping`` lookup.

  Mapping objects are accessed through ``rid``, ``pos``, ``aend``,
  ``qname``, ``isBarcode`` and ``edits_map`` and must be hashable
  (presumably CandAlignment instances -- confirm against callers).
  """

  # Max distance (bp) between a barcode read and a cloud end for the read
  # to count as sitting at that end. Instances may override this value.
  MAX_BARCODE_GAP = 50

  def __init__(
    self,
    cid,
    chrName,
    fixedMappings,
    candMappings,
  ):
    """Build the cloud and precompute per-edit-site supporting reads.

    Args:
      cid: cloud identifier (stored as-is).
      chrName: chromosome name (stored as-is).
      fixedMappings: iterable of mappings that always belong to the cloud.
      candMappings: iterable of candidate mappings of multimapped reads.
    """
    self.cid = cid
    self.chrName = chrName

    self.fixedMappings = set(fixedMappings)
    self.candMappings = set(candMappings)
    self.candRid_set = set(loc.rid for loc in self.candMappings)

    # obtain static sets of all reads potentially supporting ref, alt
    # alleles at all candidate edit sites
    allMappings = self.fixedMappings | self.candMappings
    self.siteEditReads_map = defaultdict(set)
    self.siteRefReads_map = defaultdict(set)
    self.candEditSites_set = set()
    readIntervals = IntervalTree()
    for loc in allMappings:
      # collect all possible alignments supporting an edit
      for editSite in loc.edits_map:
        self.siteEditReads_map[editSite].add(loc)
        self.candEditSites_set.add(editSite)
      # collect locations of all reads
      # NOTE add 1bp slop to read so (pos, pos) hits it on last bases
      # NOTE aend is position the last aligned base + 1
      readIntervals.insert(loc.pos - 1, loc.aend, loc)

    for editSite in self.candEditSites_set:
      # an overlapping read that does not show this edit supports the ref
      for loc in readIntervals.find(editSite, editSite):
        if editSite not in loc.edits_map:
          self.siteRefReads_map[editSite].add(loc)

    # the tree is only needed during construction
    del readIntervals

    # Cached cloud bounds used to speed up feature extraction. Fixed
    # mappings are active in every configuration, so when they extend
    # strictly beyond every candidate mapping the bound cannot depend on
    # which candidates are active. None means "compute from active reads".
    self.minpos = None
    self.maxpos = None
    if self.fixedMappings and self.candMappings:
      fixedMinpos = min(loc.pos for loc in self.fixedMappings)
      candMinpos = min(loc.pos for loc in self.candMappings)
      if fixedMinpos < candMinpos:
        self.minpos = fixedMinpos

      fixedMaxpos = max(loc.aend for loc in self.fixedMappings)
      candMaxpos = max(loc.aend for loc in self.candMappings)
      if fixedMaxpos > candMaxpos:
        self.maxpos = fixedMaxpos

  def getActiveCands(self, mappings):
    """Return the set of this cloud's candidate mappings that are active.

    Args:
      mappings: rid -> currently assigned mapping. A candidate is active
        when its read's assigned mapping is one of this cloud's candidates.
    """
    activeMappings = set()
    for rid in self.candRid_set:
      if rid not in mappings:
        continue
      loc = mappings[rid]
      if loc in self.candMappings:
        activeMappings.add(loc)
    return activeMappings

  def getActiveMappings(self, mappings):
    """Return the active candidate mappings plus all fixed mappings."""
    return self.getActiveCands(mappings) | self.fixedMappings

  def getFeats(self, mappings=None):
    """Summarize the cloud as a CloudFeat_t.

    Args:
      mappings: rid -> currently assigned mapping. When None, every
        candidate mapping is treated as active; an empty lookup means
        that no candidate is active.

    Returns:
      CloudFeat_t for the active mappings. begin, end and size are None
      when the cloud has no active mappings.
    """
    # Compare against None: an empty assignment must not be mistaken for
    # "no assignment given", which would activate every candidate.
    if mappings is not None:
      activeCands = self.getActiveCands(mappings)
    else:
      activeCands = self.candMappings
    activeMappings = activeCands | self.fixedMappings

    if activeMappings:
      # cached bounds (see __init__) hold for any set of active candidates
      if self.minpos is not None:
        minpos = self.minpos
      else:
        minpos = min(loc.pos for loc in activeMappings)
      if self.maxpos is not None:
        maxpos = self.maxpos
      else:
        maxpos = max(loc.aend for loc in activeMappings)
      assert maxpos >= minpos
      size = maxpos - minpos
    else:
      minpos = maxpos = size = None

    # determine positions of barcode reads and barcode type
    (lbarMappings, rbarMappings, mbarMappings) = \
      self.__getBarcodePlacement__(activeMappings, minpos, maxpos)

    numCand = len(activeCands)
    numFixed = len(self.fixedMappings)
    return CloudFeat_t(
      size=size,
      begin=minpos,
      end=maxpos,
      numReads=numCand + numFixed,
      numFixed=numFixed,
      numCand=numCand,
      lbarValid=len(lbarMappings) > 0,
      rbarValid=len(rbarMappings) > 0,
      numMixedBar=len(mbarMappings),
      # inactive only when both fixed and candidate support are small
      isActive=not (numFixed <= 2 and numCand <= 2),
      edits_map=self.getEditsCounterMap(activeMappings),
    )

  def getCandRids(self):
    """Return an iterator over read ids that have candidate mappings here."""
    return iter(self.candRid_set)

  def getCandMappings(self):
    """Return an iterator over all candidate mappings (active or not)."""
    return iter(self.candMappings)

  def getFixedMappings(self):
    """Return an iterator over the fixed mappings."""
    return iter(self.fixedMappings)

  def getAllMappings(self):
    """Return an iterator over candidate and fixed mappings together."""
    return iter(self.candMappings | self.fixedMappings)

  def getEditsCounterMap(self, activeMappings):
    """Count ref/alt support among active reads at every candidate edit site.

    Args:
      activeMappings: set of mappings considered active.

    Returns:
      defaultdict(Counter) mapping each candidate edit site to a Counter
      holding the number of active reference reads under 'ref' and, per
      alt base, the number of active reads showing that base.
    """
    edits_map = defaultdict(Counter)
    for editSite in self.candEditSites_set:
      siteCounts = edits_map[editSite]
      siteCounts['ref'] = len(activeMappings & self.siteRefReads_map[editSite])
      for loc in activeMappings & self.siteEditReads_map[editSite]:
        siteCounts[loc.edits_map[editSite]] += 1
    return edits_map

  def getEditsReadsMap(self, activeMappings):
    """Map each candidate edit site to the active reads covering it.

    Both reads showing the edit and reference reads are included.
    """
    edits_map = defaultdict(set)
    for editSite in self.candEditSites_set:
      edits_map[editSite] |= activeMappings & self.siteEditReads_map[editSite]
      edits_map[editSite] |= activeMappings & self.siteRefReads_map[editSite]
    return edits_map

  def getMixedBarcodes(self, mappings):
    """Return active candidate barcode mappings at neither end of the cloud.

    The cloud ends are taken from all active mappings (fixed included), but
    only active candidate mappings are classified. Returns an empty list
    when the cloud has no active mappings.
    """
    candMappings = self.getActiveCands(mappings)
    activeMappings = candMappings | self.fixedMappings
    if not activeMappings:
      # min()/max() below would raise on an empty cloud
      return []

    minpos = min(loc.pos for loc in activeMappings)
    maxpos = max(loc.aend for loc in activeMappings)

    (_, _, mbarMappings) = \
      self.__getBarcodePlacement__(candMappings, minpos, maxpos)
    return mbarMappings

  def __getBarcodePlacement__(
    self,
    activeMappings,
    minpos,
    maxpos,
  ):
    """Split barcode mappings into left-end, right-end and mixed groups.

    Placement is decided per qname: if any barcode mapping with a given
    qname lies within MAX_BARCODE_GAP of the left (right) cloud end, every
    barcode mapping with that qname goes into the left (right) group. A
    qname may be in both groups. Qnames at neither end form the mixed
    group.

    Returns:
      (lbMappings, rbMappings, mbMappings), each a list.
    """
    # read the (possibly overridden) gap instead of resetting it
    gap = self.MAX_BARCODE_GAP

    barcodeCands_list = [loc for loc in activeMappings if loc.isBarcode]

    # separate into {left, right, mixed} barcode reads
    lbQname_set = set(
      loc.qname for loc in barcodeCands_list
      if abs(loc.pos - minpos) <= gap
    )
    rbQname_set = set(
      loc.qname for loc in barcodeCands_list
      if abs(loc.aend - maxpos) <= gap
    )
    mixedQname_set = (
      set(loc.qname for loc in barcodeCands_list)
      - lbQname_set
      - rbQname_set
    )

    lbMappings = [loc for loc in barcodeCands_list if loc.qname in lbQname_set]
    rbMappings = [loc for loc in barcodeCands_list if loc.qname in rbQname_set]
    mbMappings = [
      loc for loc in barcodeCands_list if loc.qname in mixedQname_set
    ]
    return (lbMappings, rbMappings, mbMappings)

