import logging
import pysam
from collections import namedtuple, defaultdict
from itertools import izip

from ..common import util

from ..align.candread import CandAlignment
from ..toolflow import samtools

logger = logging.getLogger(__name__)

#=========================================================================
# pass through all reads and build mapq
#=========================================================================
def getReadQualityInfo(bam_path):
  """Scan every record of a bam and collect per-read quality information.

  Args:
    bam_path: path of the bam file; handed to samtools.getIndexStats for
      the read count and opened with pysam for the scan itself.

  Returns:
    defaultdict(set) keyed by (read.qname, read.is_read1).  The value is
    None when the record for that key is unmapped, otherwise a set of
    (ED, mapq, correct) tuples gathered from the key's mapped records.
  """

  # obtain index stats so we have total number of reads (only used to pace
  # the progress messages below)
  info_map = samtools.getIndexStats(bam_path)
  key = 'sum'
  numReads = info_map[key][1] + info_map[key][2]
  # floor division keeps step an integer regardless of interpreter version
  step = max(1, numReads // 10)

  read_map = defaultdict(set)
  logger.info('getting read qualities from: {0}'.format(bam_path))
  readsFile = pysam.Samfile(bam_path, 'rb')
  try:
    for (i, read) in enumerate(readsFile.fetch(until_eof=True)):
      if i % step == 0:
        logger.info('  - passed {0},{1}'.format(i, numReads))

      rid = (read.qname, read.is_read1)
      if read.is_unmapped:
        read_map[rid] = None
        continue

      # NOTE(review): if an unmapped record for this rid was seen earlier
      # the entry is None and the add() below raises -- confirm that inputs
      # never mix mapped and unmapped records for the same read.
      aux = getReadAuxInfo(read)
      chrName = readsFile.getrname(read.tid)
      correct = CandAlignment.__isTrue__(read, chrName)
      read_map[rid].add(
        (aux.ED, read.mapq, correct)
      )
  finally:
    # release the bam handle even when the scan fails part way through
    readsFile.close()

  return read_map

#=========================================================================
# obtain read edits
#=========================================================================
# cigar operation character -> numeric op code.  The code of each character
# is its position in the string below:
#   M=BAM_CMATCH, I=BAM_CINS, D=BAM_CDEL, N=BAM_CREF_SKIP, S=BAM_CSOFT_CLIP,
#   H=BAM_CHARD_CLIP, P=BAM_CPAD, '='=BAM_CEQUAL, X=BAM_CDIFF
ctype_map = dict(
  (op_char, op_code) for (op_code, op_char) in enumerate('MIDNSHP=X')
)

def getEdits(read, chrName, ref_map):
  """Map reference positions to the edits a read makes against the reference.

  Args:
    read: aligned read exposing pos, aend, qlen, seq, cigar,
      get_reference_positions() and get_blocks().
    chrName: key of the read's reference sequence in ref_map.
    ref_map: mapping of chrName -> sliceable reference sequence.  The slice
      must expose `start`; the code treats it as 1-based -- TODO confirm
      against the reference reader in use.

  Returns:
    dict keyed by reference position.  A mismatch stores the lower-cased
    read base, an insertion stores '+<size>' and a deletion '-<size>'.

  Raises:
    AssertionError: if the aligned position list and read.qlen disagree, or
      if not every insertion/deletion in the cigar was consumed.
  """

  edits_map = {}

  refSeq = ref_map[chrName][read.pos:read.aend]
  # subtracting this turns a reference position into an index into refSeq
  refOffset = refSeq.start - 1

  # first find snp edits
  refPos_list = read.get_reference_positions(full_length=True)
  assert len(refPos_list) == read.qlen
  for (qpos, rpos) in enumerate(refPos_list):
    # skipped unaligned bases in the read
    if rpos is None:
      continue

    rbase = str(refSeq[rpos - refOffset]).lower()
    qbase = read.seq[qpos].lower()

    if rbase != qbase:
      edits_map[rpos] = qbase

  # cigar op code 1 is an insertion, 2 is a deletion
  ins_list = [op for op in read.cigar if op[0] == 1]
  del_list = [op for op in read.cigar if op[0] == 2]
  # pop() for list pop's the end, so reverse to consume in cigar order
  ins_list.reverse()
  del_list.reverse()

  if ins_list or del_list:
    # each gap between consecutive aligned blocks is one indel
    for ((_, begin), (end, _)) in util.pairwise(read.get_blocks()):
      if begin == end:
        # insertion: the neighbouring blocks touch on the reference
        (_, size) = ins_list.pop()
        prefix = '+'
      else:
        # deletion: reference bases lie between the blocks
        (_, size) = del_list.pop()
        prefix = '-'
      edits_map[begin] = prefix + str(size)

  # ensure processed all insertions and deletions
  assert len(ins_list) == 0
  assert len(del_list) == 0

  return edits_map

#=========================================================================
# populate read aux info from tags
#=========================================================================
# record of the values getReadAuxInfo pulls out of a read's aux tags
ReadAuxInfo_t = namedtuple(
  'ReadAuxInfo',
  (
    'type',           # both aligners
    'ED',             # both aligners
    'AS',             # bowtie only: alignment score
    'PAS',            # bowtie only: alignment score of pair
    'numBpMismatch',  # bowtie only: num mismatching bp
    'numOpt',         # bwa only
    'numSubOpt',      # bwa only
  ),
)

def getReadAuxInfo(read):
  """Build a ReadAuxInfo_t record from the aux tags of a read.

  Tags are applied in the order they appear in read.tags, so a later tag
  overwrites whatever an earlier one stored in the same field.  Fields with
  no matching tag keep their defaults: numOpt=1, numSubOpt=0, the rest None.
  """

  # defaults for every field of the record
  fields = {
    'type' : None,
    'ED' : None,
    'AS' : None,
    'PAS' : None,
    'numBpMismatch' : None,
    'numOpt' : 1,
    'numSubOpt' : 0,
  }

  # tags whose value is copied verbatim into one field
  field_of_tag = {
    'X0' : 'numOpt',
    'X1' : 'numSubOpt',
    'NM' : 'ED',
    'AS' : 'AS',
    'YT' : 'type',
    'YS' : 'PAS',
    'XM' : 'numBpMismatch',
  }

  for (tag, val) in read.tags:
    if tag == 'XT':
      fields['type'] = val
      # mate rescued mappings don't have alternate alignments
      if val in ('M', 'N'):
        fields['numOpt'] = 1
        fields['numSubOpt'] = 0
    elif tag in field_of_tag:
      fields[field_of_tag[tag]] = val

  return ReadAuxInfo_t(**fields)

