import os
import subprocess
import shlex
import pysam
import logging
from collections import Counter, defaultdict

from  .read import getReadAuxInfo

from ..toolflow import picard
from ..align.candread import CandAlignment
from ..common import util

from multiprocessing import Process


# module-level logger named after this module; used for progress messages
# ('filtering rids from ...', '  - created ...') by the functions below
logger = logging.getLogger(__name__)

#-------------------------------------------------------------------------
# filter specified rid set
#-------------------------------------------------------------------------
def filterRids(
  inputBam_path,
  outputBam_path,
  rid_set,
  complement=False,
  tag=None,
):
  """Copy the reads of a BAM file whose rid is (or is not) in rid_set.

  A read's rid is the tuple ``(read.qname, read.is_read1)``.  The output
  file uses the input file as its header template and is indexed with
  picard.createBAMIndex once it has been written.

  Args:
    inputBam_path: path of the BAM file to read.
    outputBam_path: path of the BAM file to create.
    rid_set: container of rids; only membership tests are performed on it.
    complement: when False, keep reads whose rid is in rid_set; when True,
      keep reads whose rid is NOT in rid_set.
    tag: optional string.  When given (and non-empty), an ('AF', tag) pair
      is prepended to the tags of every kept read.
  """
  assert tag is None or isinstance(tag, str)

  logger.info('filtering rids from {0}'.format(inputBam_path))
  in_fhandle = pysam.Samfile(inputBam_path, 'rb')
  # try/finally so both handles are closed even if reading or writing fails
  try:
    out_fhandle = pysam.Samfile(
      outputBam_path,
      'wb',
      template=in_fhandle,
    )
    try:
      for read in in_fhandle:
        inSet = (read.qname, read.is_read1) in rid_set
        keep = (not inSet) if complement else inSet
        if not keep:
          continue
        if tag:
          read.tags = [('AF', tag)] + list(read.tags)
        out_fhandle.write(read)
    finally:
      out_fhandle.close()
  finally:
    in_fhandle.close()

  logger.info('  - created {0}'.format(outputBam_path))

  picard.createBAMIndex(outputBam_path)

#-------------------------------------------------------------------------
# filter duplicate alignments
#-------------------------------------------------------------------------

def filterDuplicates(
  inputBam_path,
  outputBam_path,
):
  """Remove duplicate alignments of a rid at the same start position.

  The work is done by __filterDuplicates__ in a child
  multiprocessing.Process; this call blocks until the child finishes.

  Args:
    inputBam_path: BAM file to filter.
    outputBam_path: filtered BAM file to create.

  Raises:
    RuntimeError: if the child process exits with a non-zero exit code
      (e.g. because it raised), so a failed filtering step is not silently
      ignored by the caller.
  """
  # NOTE(review): presumably a separate process is used so the memory held
  # by the large per-rid maps is released when the child exits -- confirm.
  p = Process(
    target=__filterDuplicates__,
    args=[
      inputBam_path,
      outputBam_path,
    ],
  )
  p.start()
  p.join()
  if p.exitcode != 0:
    raise RuntimeError(
      'filtering duplicates from {0} failed (exit code {1})'.format(
        inputBam_path,
        p.exitcode,
      )
    )

def __filterDuplicates__(
  inputBam_path,
  outputBam_path,
):
  """Write outputBam_path keeping one alignment per (rid, start position).

  An alignment key is ``(rid, (chrName, pos))`` with
  ``rid = (qname, is_read1)``.  Two passes are made over inputBam_path:

  1. collect, per key, the set of distinct ``(AS, isPrimary)`` values;
  2. copy reads to outputBam_path, writing for each key only the first
     alignment whose ``(AS, isPrimary)`` equals the best (largest) value
     collected for that key.

  Unmapped reads are copied unchanged.  The output is indexed with
  picard.createBAMIndex.  Normally invoked through filterDuplicates.

  Raises:
    AssertionError: if, at some location, a primary alignment exists but
      does not carry the best score there, or if some key never had its
      chosen alignment written.
  """

  def getAS(read):
    # AS value reported by getReadAuxInfo (presumably the aligner's AS tag)
    return getReadAuxInfo(read).AS

  def isPrimary(read):
    return not read.is_secondary

  def alignmentKey(fhandle, read):
    rid = (read.qname, read.is_read1)
    startPos = (fhandle.getrname(read.tid), read.pos)
    return (rid, startPos)

  # pass 1: distinct (AS, isPrimary) values seen at every alignment key
  logger.info('determining set of reads multimapped to same location in {0}'.format(
    inputBam_path
  ))
  ridMultiLocs_map = defaultdict(set)
  input_fhandle = pysam.Samfile(inputBam_path, 'rb')
  try:
    for read in input_fhandle.fetch(until_eof=True):
      if read.is_unmapped:
        continue
      ridMultiLocs_map[alignmentKey(input_fhandle, read)].add(
        (getAS(read), isPrimary(read))
      )
  finally:
    input_fhandle.close()

  logger.info('  - dumping filtered to {0}'.format(outputBam_path))

  # choose the best value for EVERY key (not only keys with several
  # distinct values): exact duplicates collapse into a single set entry and
  # must still be reduced to one written read in pass 2
  ridMultiChosen_map = {}
  for key, read_set in ridMultiLocs_map.items():
    hasPrimary = any(prim for (_, prim) in read_set)
    # highest AS wins; on equal AS the primary alignment (True > False) wins
    bestRead = max(read_set)
    # either there is no primary alignment at this location, or the
    # best-scoring alignment here is the primary one
    assert not hasPrimary or bestRead[1]
    ridMultiChosen_map[key] = bestRead
  # free the pass-1 map before the second pass
  ridMultiLocs_map.clear()

  # pass 2: write the selected alignment for each key
  handledKey_set = set()
  template_fhandle = pysam.Samfile(inputBam_path, 'rb')
  try:
    output_fhandle = pysam.Samfile(
      outputBam_path,
      'wb',
      template=template_fhandle,
    )
    try:
      for read in template_fhandle.fetch(until_eof=True):
        # unmapped reads are copied through unchanged
        if read.is_unmapped:
          output_fhandle.write(read)
          continue

        key = alignmentKey(template_fhandle, read)

        # key unseen in pass 1: nothing to choose between, write it
        if key not in ridMultiChosen_map:
          output_fhandle.write(read)
          continue
        # the chosen alignment for this key has already been written
        if key in handledKey_set:
          continue

        if (getAS(read), isPrimary(read)) == ridMultiChosen_map[key]:
          output_fhandle.write(read)
          handledKey_set.add(key)
    finally:
      output_fhandle.close()
  finally:
    template_fhandle.close()

  assert len(handledKey_set) == len(ridMultiChosen_map)

  picard.createBAMIndex(outputBam_path)

#-------------------------------------------------------------------------
# generate filtered bam from input multi candidate alignment bam
#-------------------------------------------------------------------------
def genFilteredBam(
  inputBam_path,
  outputBam_path,
  ridMappings_map,
  ridMapq_map,
  excludeRid_set=None,
  tagRid_map=None,
):
  """Write one selected alignment per read from a multi-candidate BAM.

  Reads are identified by ``rid = (qname, is_read1)``.  For each read in
  inputBam_path:

  * unmapped reads are skipped;
  * rids in excludeRid_set are skipped;
  * rids absent from ridMappings_map: the first non-secondary occurrence is
    written.  A mapq of 255 is rewritten to 60, and for such reads whose
    rid is in tagRid_map the tags ('AF', 'fixed') and
    ('AB', tagRid_map[rid]) are prepended;
  * rids present in ridMappings_map: only alignments at
    ``(loc.chrName, loc.pos)`` are considered.  The mapq is taken from
    ridMapq_map, tlen is recomputed from the chosen mate locations, the
    tags get ('AF', 'realigned'), an optional ('AB', ...) and a YT tag
    replaced by 'CP'/'DP'/'UP', and the secondary flag is cleared.  Only
    the first such alignment of a rid is written; repeats are logged at
    debug level.

  The output is indexed with picard.createBAMIndex.

  Args:
    inputBam_path: BAM file with candidate alignments.
    outputBam_path: filtered BAM file to create.
    ridMappings_map: rid -> chosen location object exposing .chrName,
      .pos, .aend and .cid.
    ridMapq_map: rid -> mapq for rids in ridMappings_map.
    excludeRid_set: optional container of rids to drop.
    tagRid_map: optional rid -> value written as the 'AB' tag.
  """

  def getPairStr(loc, mateLoc):
    # YT pair type: 'UP' when the mate has no chosen location, 'CP' when
    # both share the same cid, 'DP' otherwise
    if mateLoc is None:
      return 'UP'
    if mateLoc.cid == loc.cid:
      return 'CP'
    return 'DP'

  def getInsertSize(loc, mateLoc):
    # 0 unless both mates have chosen locations with the same cid
    if mateLoc is None or loc.cid != mateLoc.cid:
      return 0
    # FIXME review this computation: mates starting at the same position
    # both get a negative sign, and it assumes .aend is the alignment end.
    # Integers are used because tlen is an integer SAM field.
    sign = 1 if loc.pos < mateLoc.pos else -1
    span = max(loc.aend, mateLoc.aend) - min(mateLoc.pos, loc.pos)
    return sign * span

  singleRid_set = set()
  multiRidCounter = Counter()

  template_fhandle = pysam.Samfile(inputBam_path, 'rb')
  try:
    output_fhandle = pysam.Samfile(
      outputBam_path,
      'wb',
      template=template_fhandle,
    )
    try:
      for read in template_fhandle.fetch(until_eof=True):
        # skip all reads which were unmapped in pass2 as these will be
        # included in the bwa realignment step
        if read.is_unmapped:
          continue

        rid = (read.qname, read.is_read1)

        # skip all reads which are to be excluded from the *bam
        if excludeRid_set and rid in excludeRid_set:
          continue

        loc = ridMappings_map.get(rid)
        mateLoc = ridMappings_map.get((read.qname, not read.is_read1))
        chrName = template_fhandle.getrname(read.tid)

        if loc is None:
          # rids without a chosen mapping: pass the first primary
          # occurrence through unchanged (apart from the mapq fix below)
          if rid not in singleRid_set and not read.is_secondary:
            singleRid_set.add(rid)
            # set mapq255 reads from bowtie2 to be a valid 60
            if read.mapq == 255:
              read.mapq = 60
              if tagRid_map and rid in tagRid_map:
                read.tags = [
                  ('AF', 'fixed'),
                  ('AB', tagRid_map[rid]),
                ] + list(read.tags)
            output_fhandle.write(read)
        elif (chrName, read.pos) == (loc.chrName, loc.pos):
          # this alignment is the chosen mapping for the rid
          read.mapq = ridMapq_map[rid]
          read.tlen = getInsertSize(loc, mateLoc)
          newtags = [('AF', 'realigned')]
          if tagRid_map and rid in tagRid_map:
            newtags.append(('AB', tagRid_map[rid]))
          # drop the existing YT tag; it is re-added with the new pair type
          newtags.extend(t for t in read.tags if t[0] != 'YT')
          newtags.append(('YT', getPairStr(loc, mateLoc)))
          read.tags = newtags
          read.is_secondary = False
          if rid not in multiRidCounter:
            output_fhandle.write(read)
          elif multiRidCounter[rid] == 1:
            logger.debug('WARNING more than one read included for rid: {0}'.format(rid))
          multiRidCounter[rid] += 1
    finally:
      output_fhandle.close()
  finally:
    template_fhandle.close()

  picard.createBAMIndex(outputBam_path)

#-------------------------------------------------------------------------
# translate alignment coordinates of a bam onto another reference
#-------------------------------------------------------------------------
def convertBamCoords(
  inputBam_path,
  seqCoord_map,
  templateBam_path,
  outputBam_path,
):
  """Translate the alignment coordinates of a BAM onto another reference.

  pysam is not used to rewrite the reads (see the comment below); instead
  the input is dumped to SAM text with ``samtools view``, and every
  record's RNAME/POS and RNEXT/PNEXT are rewritten via seqCoord_map.  The
  result is written as SAM, with the header text of templateBam_path, and
  then sorted to BAM and indexed with picard.

  Args:
    inputBam_path: BAM whose coordinates are translated; samtools must be
      on PATH.
    seqCoord_map: input reference name -> 3-tuple
      ``(newRname, posBeginOffset, _)``; the offset is added to positions
      on that reference and the third element is ignored.
    templateBam_path: BAM of the target genome whose header text is used.
    outputBam_path: BAM file to create.  The intermediate SAM file is
      written next to it with the extension replaced by '.sam'.

  Raises:
    RuntimeError: if ``samtools view`` exits with a non-zero status.
    KeyError: if a reference name is missing from seqCoord_map.
  """

  # pysam is not capable of manipulating reads and writing the reassigned
  # positions to a file so use a template bam file to obtain the header
  # for the genome of interest and dump out a *sam txt file to be
  # sorted and converted to *bam

  # obtain header file text from template *bam of genome to translate to
  template_fhandle = pysam.Samfile(templateBam_path, 'rb')
  try:
    headerText = template_fhandle.text
  finally:
    template_fhandle.close()

  # temporary SAM dump of the input, removed once it has been translated
  tmp_path = os.path.join(
    os.path.dirname(inputBam_path),
    'tmp' + os.path.basename(inputBam_path),
  )
  outputSam_fname = os.path.splitext(outputBam_path)[0] + '.sam'
  # every output field goes through '%s' formatting, as before
  samLine_fmt = '\t'.join(['%s'] * 9)

  try:
    # must invoke samtools directly to get raw read strings; the path is
    # passed as its own argument so paths containing spaces are preserved
    with open(tmp_path, 'w') as f:
      retcode = subprocess.call(
        ['samtools', 'view', inputBam_path],
        stdout=f,
      )
    if retcode != 0:
      raise RuntimeError(
        'samtools view failed on {0} (exit code {1})'.format(
          inputBam_path,
          retcode,
        )
      )

    with open(tmp_path) as fin, \
         open(outputSam_fname, 'w') as fout:
      fout.write(headerText)
      for line in fin:
        # remains_str keeps the rest of the record, including the newline
        (
          qname_str,
          flags_str,
          rname_str,
          pos_str,
          mapq_str,
          cigar_str,
          pair_rname_str,
          pair_pos_str,
          remains_str,
        ) = line.split('\t', 8)

        if rname_str == '*':
          (newRname, posBeginOffset) = ('*', 0)
        else:
          (newRname, posBeginOffset, _) = seqCoord_map[rname_str]

        if pair_rname_str == '=':
          # mate on the same reference: same offset, keep '='
          (pair__newRname, pair__posBeginOffset, _) = seqCoord_map[rname_str]
          pair__newRname = '='
        elif pair_rname_str == '*':
          (pair__newRname, pair__posBeginOffset) = ('*', 0)
        else:
          (pair__newRname, pair__posBeginOffset, _) = seqCoord_map[pair_rname_str]
          if pair__newRname == newRname:
            pair__newRname = '='

        fout.write(samLine_fmt % (
          qname_str,
          flags_str,
          newRname,
          posBeginOffset + int(pos_str),
          mapq_str,
          cigar_str,
          pair__newRname,
          pair__posBeginOffset + int(pair_pos_str),
          remains_str,
        ))
  finally:
    if os.path.exists(tmp_path):
      os.remove(tmp_path)

  # sort and convert to *bam file
  picard.sortSam(outputSam_fname, outputBam_path)
  picard.createBAMIndex(outputBam_path)

#-------------------------------------------------------------------------
# color simulation bam
#-------------------------------------------------------------------------
def colorSimBam(
  inputBam_path,
  tagRid_set=None,
):
  """Write a copy of a BAM with each read's RG tag set to a category.

  The copy is written next to the input as 'color__' + basename.  Any
  existing RG tag is replaced by one of:

  * 'correct' / 'incorrect' -- mapped read, judged by
    CandAlignment.__isTrue__, whose rid is in tagRid_set;
  * 'correct_rest' / 'incorrect_rest' -- same, rid not in tagRid_set;
  * 'rest' -- unmapped read;

  each prefixed with 'b_' when CandAlignment.__isBarcode__(read) is true.
  A read's rid is ``(read.qname, read.is_read1)``.

  Args:
    inputBam_path: BAM to colour; iterated with fetch(), which presumably
      requires an index and covers only reads placed on a reference.
    tagRid_set: optional container of rids to highlight.
  """
  if tagRid_set is None:
    tagRid_set = set()

  newbam_path = os.path.join(
    os.path.dirname(inputBam_path),
    'color__' + os.path.basename(inputBam_path),
  )

  logger.info('coloring {0}'.format(inputBam_path))

  orig_fhandle = pysam.Samfile(inputBam_path, 'rb')
  try:
    color_fhandle = pysam.Samfile(
      newbam_path,
      'wb',
      template=orig_fhandle,
    )
    try:
      for read in orig_fhandle.fetch():
        rid = (read.qname, read.is_read1)

        if read.is_unmapped:
          newtag = 'rest'
        else:
          chrName = orig_fhandle.getrname(read.tid)
          if CandAlignment.__isTrue__(read, chrName):
            newtag = 'correct'
          else:
            newtag = 'incorrect'
          if rid not in tagRid_set:
            newtag += '_rest'

        # tag barcodes separately
        if CandAlignment.__isBarcode__(read):
          newtag = 'b_' + newtag

        # replace any existing RG tag (a list, so it works on Py2 and Py3)
        tags = [t for t in read.tags if t[0] != 'RG']
        tags.append(('RG', newtag))
        read.tags = tags

        color_fhandle.write(read)
    finally:
      color_fhandle.close()
  finally:
    orig_fhandle.close()

  # FIXME the output is deliberately not indexed: RG groups must first be
  # added to the header in order for picard to go through; once they are,
  # call picard.createBAMIndex(newbam_path) here.
  return


