import os
import subprocess
import shlex
import logging
import pysam
from collections import defaultdict

from ..common import util

import picard

logger = logging.getLogger(__name__)

# Absolute directory holding this module; the wrapper script is looked up here.
toolDir_path = os.path.split(os.path.abspath(__file__))[0]
# Shell wrapper that genAlignments() executes for a single alignment run.
b2AlignScript_path = os.path.join(toolDir_path, 'bowtie2.align.sh')

#-------------------------------------------------------------------------
# generate alignments
#-------------------------------------------------------------------------
def genAlignments(
  index_path,
  fq1_path,
  fq2_path,
  outBam_path,
  alignFlags='',
):
  """Run one alignment of a FASTQ pair through the ``bowtie2.align.sh`` script.

  Input and output paths are resolved against the current working directory,
  then the wrapper script is executed from inside a fresh directory obtained
  from util.mktmpdir().

  Args:
    index_path: index argument forwarded to the wrapper script.
    fq1_path: first FASTQ file of the pair.
    fq2_path: second FASTQ file of the pair.
    outBam_path: output BAM path; a trailing '.bam' is stripped and the
      remaining prefix is what the script receives.
    alignFlags: extra flags, forwarded to the script as ONE argument.

  Raises:
    AssertionError: if the script exits with a non-zero status.
  """
  index_path = os.path.abspath(index_path)
  fq1_path = os.path.abspath(fq1_path)
  fq2_path = os.path.abspath(fq2_path)
  outBam_path = os.path.abspath(outBam_path)
  # strip .bam suffix: the script is handed an output prefix
  if outBam_path.endswith('.bam'):
    outBamPrefix = outBam_path[:-4]
  else:
    outBamPrefix = outBam_path

  # perform alignment inside temp directory
  tmpdir = util.mktmpdir()
  logger.info('launching bowtie2 alignment in {0}'.format(tmpdir))
  with util.cd(tmpdir):
    # Pass an argument list instead of shlex-splitting a formatted string, so
    # paths with spaces and flags containing quotes reach the script intact.
    cmd = [
      b2AlignScript_path,
      index_path,
      fq1_path,
      fq2_path,
      outBamPrefix,
      alignFlags,
    ]
    logger.info('  - cmd: {0} "{1}"'.format(' '.join(cmd[:-1]), alignFlags))
    pp = subprocess.Popen(
      cmd,
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE,
    )
    # communicate() drains both pipes while waiting; wait() followed by read()
    # deadlocks as soon as the child fills a pipe buffer.
    out, err = pp.communicate()
    logger.debug(out)
    logger.debug(err)
    if pp.returncode != 0:
      # Explicit raise (not `assert`) so the check survives `python -O`,
      # while keeping the exception type callers already see.
      raise AssertionError(
        'bowtie2 alignment failed with exit status {0}'.format(pp.returncode))

#-------------------------------------------------------------------------
# generate alignments over multiple passes and merge them into one BAM
#-------------------------------------------------------------------------
# SAM FLAG bits that are rewritten when the final BAM is assembled.
_FLAG_PROPER_PAIR = 0x2
_FLAG_MATE_UNMAPPED = 0x8


def _pairFlags(flag, mapped, mateMapped):
  """Return *flag* with its proper-pair and mate-unmapped bits made consistent.

  The proper-pair bit (0x2) is cleared unless both the read and its mate are
  mapped. The mate-unmapped bit (0x8) is cleared when the mate mapped and set
  otherwise. All other bits are left untouched.
  """
  if not (mapped and mateMapped):
    flag &= ~_FLAG_PROPER_PAIR
  if mateMapped:
    flag &= ~_FLAG_MATE_UNMAPPED
  else:
    flag |= _FLAG_MATE_UNMAPPED
  return flag


def _readId(read):
  """Return the key identifying one mate of a pair: (query name, is mate 1)."""
  return (read.qname, read.is_read1)


def _classifyReads(reads):
  """Split the read ids seen in *reads* into (mapped, never-mapped) sets.

  An id is "mapped" if any non-secondary record of it is mapped, and
  "never mapped" if every record of it is unmapped.
  """
  mapped = set()
  unmapped = set()
  for read in reads:
    rid = _readId(read)
    if read.is_unmapped:
      unmapped.add(rid)
    elif not read.is_secondary:
      mapped.add(rid)
  unmapped -= mapped
  return mapped, unmapped


def _selectReads(reads, ridMapped_set, ridUnmapped_set):
  """Yield the records to keep in the final BAM, with pair flags adjusted.

  - never-mapped reads: only the first unmapped record is kept;
  - secondary alignments: dropped;
  - mapped reads: one record per distinct (reference id, position).
  The flag of every yielded record is updated in place before it is yielded.
  """
  seenUnmapped_set = set()
  ridLoci_map = defaultdict(set)
  for read in reads:
    rid = _readId(read)
    mateMapped = (rid[0], not rid[1]) in ridMapped_set
    if read.is_unmapped:
      if rid in ridUnmapped_set and rid not in seenUnmapped_set:
        seenUnmapped_set.add(rid)
        read.flag = _pairFlags(read.flag, False, mateMapped)
        yield read
    elif read.is_secondary:
      continue
    elif rid in ridMapped_set:
      # Key on the reference as well as the position: the same coordinate on
      # two different references is two distinct alignments.
      locus = (read.tid, read.pos)
      if locus not in ridLoci_map[rid]:
        ridLoci_map[rid].add(locus)
        read.flag = _pairFlags(read.flag, True, mateMapped)
        yield read


def genAlignmentsMP(
  index_path,
  fq1_path,
  fq2_path,
  outBam_path,
  alignFlags='',
  numPrePasses=3,
):
  """Align a FASTQ pair several times and merge the passes into one BAM.

  Runs genAlignments() numPrePasses times with ' --non-deterministic'
  appended to alignFlags, so that multimapped reads can land on different
  best locations in different passes. The per-pass BAMs are merged. Then
  outBam_path is written, containing one record for each read that never
  mapped and one non-secondary record per distinct (reference, position) for
  each read that mapped in any pass. The result is indexed with
  picard.createBAMIndex.

  Args:
    index_path: index argument forwarded to genAlignments().
    fq1_path: first FASTQ file of the pair.
    fq2_path: second FASTQ file of the pair.
    outBam_path: path of the final BAM.
    alignFlags: extra flags forwarded to genAlignments().
    numPrePasses: number of alignment passes to run and merge.

  Raises:
    AssertionError: propagated from genAlignments() if a pass fails.
  """
  # Resolve every path against the caller's working directory *before*
  # changing into the temp directory; otherwise relative inputs would be
  # looked up (and the output written) relative to the temp directory.
  index_path = os.path.abspath(index_path)
  fq1_path = os.path.abspath(fq1_path)
  fq2_path = os.path.abspath(fq2_path)
  outBam_path = os.path.abspath(outBam_path)

  tmpdir = util.mktmpdir()

  alignFlags += ' --non-deterministic'
  logger.info('launching {0}-pass bowtie2 alignment in {1}'.format(
    numPrePasses,
    tmpdir))

  with util.cd(tmpdir):

    multiBam_path = 'multi.bam'

    # align multiple times to distribute multimapped reads across possible
    # best locations
    logger.info(' - {0} pre passes'.format(numPrePasses))
    preBam_list = []
    for i in range(numPrePasses):
      preBam_path = 'pre_{0}.bam'.format(i)
      preBam_list.append(preBam_path)
      genAlignments(
        index_path,
        fq1_path,
        fq2_path,
        preBam_path,
        alignFlags=alignFlags,
      )

    # merge pre trials
    picard.mergeBams(
      preBam_list,
      multiBam_path,
    )
    picard.createBAMIndex(multiBam_path)

    # eliminate duplicate alignments for each read
    logger.info(' - eliminiate duplicate alignments')

    logger.info('   - determine reads never mapped')
    multi_fhandle = pysam.Samfile(multiBam_path, 'rb')
    try:
      ridMapped_set, ridUnmapped_set = _classifyReads(
        multi_fhandle.fetch(until_eof=True))
    finally:
      multi_fhandle.close()

    logger.info('   - create single mapped bam')
    # Reopen the merged BAM first so the output header is copied from an
    # open handle.
    multi_fhandle = pysam.Samfile(multiBam_path, 'rb')
    try:
      out_fhandle = pysam.Samfile(outBam_path, 'wb', template=multi_fhandle)
      try:
        for read in _selectReads(
          multi_fhandle.fetch(until_eof=True),
          ridMapped_set,
          ridUnmapped_set,
        ):
          out_fhandle.write(read)
      finally:
        out_fhandle.close()
    finally:
      multi_fhandle.close()

    # reindex new *bam
    picard.createBAMIndex(outBam_path)

    logger.info(' - done')

#-------------------------------------------------------------------------
# create index from *fasta
#-------------------------------------------------------------------------
def createFMIndex(fasta_path):
  """Build a bowtie2 index for *fasta_path* by running ``bowtie2-build``.

  The FASTA path is passed both as the reference input and as the index
  base name. bowtie2-build's stdout is captured and logged at debug level.

  Raises:
    AssertionError: if bowtie2-build exits with a non-zero status.
  """
  # Argument list rather than shlex-splitting a formatted string, so paths
  # containing spaces are passed through intact.
  cmd = ['bowtie2-build', fasta_path, fasta_path]
  logger.info('launching bowtie2 index generation')
  logger.info('  - cmd: {0}'.format(' '.join(cmd)))
  pp = subprocess.Popen(
    cmd,
    stdout=subprocess.PIPE,
  )
  # communicate() drains stdout while waiting; wait() alone deadlocks once
  # the tool's output fills the pipe buffer.
  out, _ = pp.communicate()
  logger.debug(out)
  if pp.returncode != 0:
    # Explicit raise (not `assert`) so the check survives `python -O`,
    # while keeping the exception type callers already see.
    raise AssertionError(
      'bowtie2-build failed with exit status {0}'.format(pp.returncode))

