import os
import sys
import pysam
from collections import defaultdict
from bx.intervals.cluster import ClusterTree

from mlib.toolflow import ref
from mlib.toolflow import bowtie2
from mlib.toolflow import samtools
from mlib.common import util

from mlib.common import plotter

outDir_path = '/srv/gs1/projects/batzoglou/abishara/data/cancer.validation/ref'
refFasta_path = '/srv/gs1/projects/batzoglou/dkashef/phasing_project/GATK_resources/bundle/2.8/hg19/ucsc.hg19.fasta'
inputDir_path = '/home/abishara/save/primer-save3/deliver'
# Primer save files: every entry of inputDir_path whose name ends in 'txt',
# kept in os.listdir order.
inputFname_list = [
  entryName
  for entryName in os.listdir(inputDir_path)
  if entryName.endswith('txt')
]

#--------------------------------------------------------------------------
# helpers
#--------------------------------------------------------------------------
def getCallInfo(line):
  """Parse one tab-separated primer-save record.

  Column layout (0-based index):
    0 chrName  1 lbegin  2 lsize  3 lseq  4 rbegin  5 rsize  6 rseq
    7 pos      8 altAllele  9 prodSize  [10 floatMatches]

  The primer sequences (lseq, rseq) and any column after prodSize are not
  used by callers, so they are not parsed.

  Returns:
    (chrName, lbegin, lsize, rbegin, rsize, pos, altAllele, prodSize),
    with lbegin/lsize/rbegin/rsize/pos/prodSize converted to int.

  Raises:
    IndexError if the line has fewer than 10 tab-separated fields;
    ValueError if a numeric field is not an integer.
  """
  words = line.split('\t')
  # int() tolerates surrounding whitespace, so a trailing newline on
  # prodSize (when it is the last column) is harmless.
  return (
    words[0],        # chrName
    int(words[1]),   # lbegin
    int(words[2]),   # lsize
    int(words[4]),   # rbegin
    int(words[5]),   # rsize
    int(words[7]),   # pos
    words[8],        # altAllele
    int(words[9]),   # prodSize
  )

def getSnpCalls():
  """Return the set of candidate SNP calls from all primer save files.

  Each line of every file in inputFname_list (resolved under inputDir_path)
  that does not start with '#' is parsed with getCallInfo and contributes
  one ((chrName, pos), altAllele) entry.
  """
  calls = set()
  for name in inputFname_list:
    with open(os.path.join(inputDir_path, name)) as handle:
      for record in handle:
        if record.startswith('#'):
          continue
        (chrom, _, _, _, _, snpPos, alt, _) = getCallInfo(record)
        calls.add(((chrom, snpPos), alt))
  return calls

#--------------------------------------------------------------------------
# build reference
#--------------------------------------------------------------------------
def buildRef():

  
  # collect all products
  print 'load primer save files'
  snpCall_set = set()
  uid_map = {}
  coord_map = {}
  regions_map = defaultdict(lambda: ClusterTree(500, 1))
  uid = 0
  for fname in inputFname_list:
    path = os.path.join(inputDir_path, fname)
    with open(path) as f:
      for line in f:
        if line.startswith('#'):
          continue  
        uid += 1
        (
          chrName,
          lbegin,
          lsize,
          rbegin,
          rsize,
          pos,
          altAllele,
          prodSize,
        ) = getCallInfo(line)
      
        snpCall = (chrName, pos)
        rend = rbegin + rsize
        coord_map[uid] = (chrName, lbegin, rend)
        uid_map[uid] = snpCall
        regions_map[chrName].insert(
          lbegin,
          rend,
          uid,
        )

        try:
          assert snpCall not in snpCall_set
        except:
          print 'dup target', snpCall

        snpCall_set.add(snpCall)

  print 'total snps', len(snpCall_set)


  def getUnion(uid_list):
    intervals = map(
      lambda(x): coord_map[x],
      uid_list,
    )
    begin = min(map(
      lambda(chrName, begin, end): begin,
      intervals,
    ))
    end = max(map(
      lambda(chrName, begin, end): end,
      intervals,
    ))
    return (begin, end)

  # union overlaps
  print 'union overlaps'
  coord_list = []
  for chrName, regions in regions_map.items():
    for (_, _, uid_list) in regions.getregions():
      (begin, end) = getUnion(uid_list)
      coord_list.append(
        (chrName, begin, end)
      )

      if len(uid_list) > 1:
        print 'union', coord_list[-1]
        print '  - input', map(
          lambda(x): coord_map[x],
          uid_list,
        )

  abbrevFasta_path = os.path.join(
    outDir_path,
    'validate_ref.fasta',
  )

  print 'build abbrev ref'
  contigCoord_map = ref.buildAbbrevRef(
    coord_list,
    refFasta_path,
    abbrevFasta_path,
    MAX_EXTEND_LEN=500,
  )
  print 'create index'
  bowtie2.createFMIndex(abbrevFasta_path)

  print 'done'

#--------------------------------------------------------------------------
# get annotations
#--------------------------------------------------------------------------
def getAnnot():
  snpCand_set = getSnpCalls()

  def getVcfIter(path):
    with open(path) as f:
       for line in f:
         if line.startswith('#'):
           continue
         words = line.strip().split('\t')
         chrName = words[0]
         pos = int(words[1]) - 1
         altAllele = words[4]
         snpCall = ((chrName, pos), altAllele)
         entries_list = words[-1].split(';')
         info_map = {}
         for entry_str in entries_list:
           (key, val_str) = entry_str.split('=')
           if key == 'type':
             val = str(val_str)
           elif key in ['na', 'nr']:
             val = int(val_str)
           else:
             val = float(val_str)
           info_map[key] = val
         info_map['alt'] = altAllele
         yield (snpCall, info_map)
    raise StopIteration
  
  vcfRootDir_path = '/srv/gs1/projects/batzoglou/abishara/scratch/cancer.rfa.new'
  dirs = filter(
    lambda(f): os.path.isdir(os.path.join(vcfRootDir_path,f)),
    os.listdir(vcfRootDir_path),
  )
  #snpAnnot_map = defaultdict(Counter)
  snpAnnot_map = defaultdict(list)
  print 'passing through all vcfs'
  for (j, dir) in enumerate(dirs):
    #if j > 0:
    #  print 'early exit'
    #  break
    print '{0}/{1} dirs'.format(j, len(dirs))
    vcfFragDir_path = os.path.join(vcfRootDir_path, dir)
    vcfFname_list = filter(
      lambda(f): f.endswith('vcf'),
      os.listdir(vcfFragDir_path),
    )
  
    step = max(1, len(vcfFname_list) / 20)
    for (i, fname) in enumerate(vcfFname_list):
      if i % step == 0:
        print '  - passed {0}/{1}'.format(i, len(vcfFname_list))
      path = os.path.join(vcfFragDir_path, fname)
      for (snpCall, info_map) in getVcfIter(path):
        if snpCall not in snpCand_set:
          continue

        snpAnnot_map[snpCall].append(info_map)


  return snpAnnot_map

#--------------------------------------------------------------------------
# filter snp calls
#--------------------------------------------------------------------------
def filterCalls(snpAnnot_map):
  snpCand_set = getSnpCalls()

  numlowqual = 0
  numvalidate = 0
  numskipped = 0
  validateSnp_set = set()
  CUTOFF = -0.013
  for snpCall in snpCand_set:
    if snpCall not in snpAnnot_map:
      numskipped += 1
      continue
    exclude = 0
    hqcall = 0
    lqcall = 0
    totcall = 0
    lqcloud = 0
    annot_list = snpAnnot_map[snpCall]
    for info_map in annot_list:
      totcall += 1
      hqcall += (
        info_map['prc'] > CUTOFF and
        info_map['pr'] > 0.90 and 
        not info_map['exclude']
      )
      lqcall += (
        info_map['prc'] < CUTOFF and
        info_map['pr'] > 0.90
      )
      lqcloud += (
        info_map['prc'] < CUTOFF
      )
      exclude += (info_map['exclude'])

    validate = (
      hqcall >= 2 #and 
      #1. * lqcall / (lqcall + hqcall) < 0.5
      #1. * exclude / (hqcall + exclude) < 0.2 
      #exclude < 2 
      #hqcall - exclude >= 2
    )
    if validate:
      validateSnp_set.add(snpCall)
      numvalidate += 1

  print 'total variants', len(snpCand_set)
  print 'number to validate', numvalidate 
  print 'number skipped ', numskipped 
  print 'number lowqual', numlowqual


  pileupBed_fname = 'shit/pileup.bed'
  pileupOut_path = 'shit/pileup.out.txt'
  validationBam_path = '/srv/gs1/projects/batzoglou/abishara/data/cancer.validation/bam/l1/merged/merged.sorted.bam'

  print 'creating pileup bed'
  mpileupRegions_map = defaultdict(lambda: ClusterTree(10, 1))
  snpCall_map = {}
  for snpCall in snpCand_set:
    (snpPos, altAllele) = snpCall
    (chrName, pos) = snpPos
    snpCall_map[snpPos] = altAllele
    mpileupRegions_map[chrName].insert(
      pos - 1,
      pos + 1,
      0,
    )

  # dump pileup regions to bed
  #util.writeBedFile(
  #  mpileupRegions_map,
  #  pileupBed_fname,
  #)

  ## launch single pileup job
  #print 'launching pileup'
  #samtools.pileup(
  #  refFasta_path,
  #  pileupBed_fname,
  #  validationBam_path,
  #  pileupOut_path,
  #)

  print 'parsing pileup'
  pileup_map = samtools.parsePileup(pileupOut_path)

  #print 'pileup results', sorted(pileup_map.keys())[:20]
  #print 'snpCands results', sorted(snpCand_set)[:20]
  #die

  saved = 0

  results_map = {}
  numskipped = 0
  for site in snpCall_map.keys():
    pAlt = snpCall_map[site]
    (chrName, pos) = site
    snpCall = (site, pAlt)

    # FIXME the *bam doesn't have chr* prefix, inconsistent
    #pidx = (chrName.strip('chr'), pos)
    pidx = (chrName, pos)
    if pidx not in snpCall_map:
      continue
    #print 'here', pidx
    if pidx not in pileup_map:
      numskipped += 1
      continue
    info_map = pileup_map[pidx]

    results_map[snpCall] = pileup_map[pidx]

  print 'number no pilup results', numskipped
  def cumsum(cnt, base):
    return sum(map(
      lambda(k, v): v,
      filter(
        lambda(k, v): k >= base,
        cnt.items(),
      ),
    ))

  lowqual = 0
  validated = 0
  validated_set = set()
  lowqual_set = set()
  pileup_set = set(results_map.keys())
  error_set = set()
  aaf_list = []
  greyaaf_list = []
  for snpCall, pinfo_map in results_map.items():

    ((chrName, pos), altAllele) = snpCall
    total = sum(pinfo_map['baseCounter'].values())
    mq0total = pinfo_map['allMQCounter'][0]
    if (1. * mq0total / total > 0.5):
      lowqual += 1
      #lowqual_set.add(snpCall)

    suptotal = cumsum(pinfo_map['baseMQCounter'][altAllele], 1)
    aaf = (1. * suptotal / (total - mq0total))
    if aaf > 0.10:
      validated += 1
      validated_set.add(snpCall)
      aaf_list.append(aaf)
    elif aaf < 0.02:
      error_set.add(snpCall)
    else:
      greyaaf_list.append(aaf)

  print 'number lowqual', lowqual
  print 'our validate lowqual', len(validateSnp_set & lowqual_set)
  print 'number validated', validated
  print 'pass all rate', 1. * len(validated_set - lowqual_set) / len(pileup_set - lowqual_set)
  print 'pass filt rate', 1. * len((validateSnp_set & validated_set) - lowqual_set) / len((validateSnp_set & pileup_set) - lowqual_set)
  print 'pass error rate', 1. * len((validateSnp_set & error_set) - lowqual_set) / len((validateSnp_set & pileup_set) - lowqual_set)
  print 'number of out ', len((validateSnp_set & pileup_set) - lowqual_set)
  print 'number pass ', len((validateSnp_set & validated_set) - lowqual_set)
  print 'number error ', len((validateSnp_set & error_set) - lowqual_set)
  print 'number error + lowqual ', len((validateSnp_set & error_set & lowqual_set) )

  print 'these are incorrect', list(error_set)
  print 'these are correct', list((validateSnp_set & validated_set))[:30]

  plotter.plotHistograms(
    [
      ('validated', aaf_list),
      ('uncertain', greyaaf_list),
    ],
    {
      'xlabel' : 'aaf',
      'ylabel' : 'freq',
      'fname'  : 'aaf_plot',
      'numBins': 20,
    }
  )

#--------------------------------------------------------------------------
# count pileups of reads
#--------------------------------------------------------------------------
def pileups():
  """Placeholder for counting read pileups; currently does nothing."""

#buildRef()
# getAnnot() scans every VCF, so its result is stored via util.writePickle
# on the first run and reloaded from cache_fname afterwards.
cache_fname = 'shit/cache.p'
if os.path.isfile(cache_fname):
  snpAnnot_map = util.loadPickle(cache_fname)
else:
  snpAnnot_map = getAnnot()
  util.writePickle(cache_fname, snpAnnot_map)

filterCalls(snpAnnot_map)

