import logging
import random
import numpy as np
from collections import defaultdict, Counter
from bx.intervals.cluster import ClusterTree
from scipy.misc import logsumexp

from ..bam.read import getReadQualityInfo
from ..common import plotter
from ..common import util

logger = logging.getLogger(__name__)

#-------------------------------------------------------------------------
# generate overall results
#-------------------------------------------------------------------------
def getRidsFromCids(cid_set, mappings):
  """Return the set of read ids whose mapped location is in one of the given clouds.

  cid_set  -- container of cloud ids; only membership (`in`) is used
  mappings -- dict of read id -> location object exposing a `.cid` attribute
  """
  # a set comprehension replaces the map/filter pair that relied on
  # tuple-unpacking lambdas (syntax removed from the language by PEP 3113)
  return {rid for (rid, loc) in mappings.items() if loc.cid in cid_set}

# mapq cutoffs used below to bucket reads into medium / high mapping quality
_MMQ_THRESHOLD = 10
_HMQ_THRESHOLD = 20

def _chooseOptimalLoc(loc_set):
  """Pick one minimum-edit-distance entry of loc_set, ties broken at random.

  loc_set -- non-empty collection of (ED, mapq, correct) tuples
  returns the chosen (ED, mapq, correct) tuple
  """
  minED = min(entry[0] for entry in loc_set)
  # materialize a list: random.choice needs an indexable sequence
  optLoc_list = [entry for entry in loc_set if entry[0] == minED]
  return random.choice(optLoc_list)

def _splitRidsByMapq(mapq_map):
  """Bucket the read ids of a rid -> mapq dict by mapping quality.

  returns (zeroMQ, medMQ, highMQ) sets holding the rids with mapq == 0,
  mapq >= 10 and mapq >= 20 respectively (every high rid is also in med)
  """
  zeroRid_set = set()
  medRid_set = set()
  highRid_set = set()
  for (rid, mapq) in mapq_map.items():
    if mapq == 0:
      zeroRid_set.add(rid)
    if mapq >= _MMQ_THRESHOLD:
      medRid_set.add(rid)
    if mapq >= _HMQ_THRESHOLD:
      highRid_set.add(rid)
  return (zeroRid_set, medRid_set, highRid_set)

def getResults(
  aligner,
  pass1Bam_path,
  pass2Bam_path,
  rfaMapq_map,
  oracleMapq_map,
):
  """Compute summary counts comparing the rfa, topAS and oracle mappings.

  aligner        -- aligner object; its mappings, cloud/read id sets and
                    cloud helper methods are read below
  pass1Bam_path  -- passed to getReadQualityInfo for the pass1 baseline
  pass2Bam_path  -- passed to getReadQualityInfo for the pass2 baseline
  rfaMapq_map    -- dict rid -> mapq for the rfa mappings
  oracleMapq_map -- dict rid -> mapq for the oracle mappings

  Returns a dict of statistic name -> integer count.  Every entry is also
  written to the module logger at INFO level.

  For reads with several minimum-edit-distance locations in the pass1 /
  pass2 info, one location is drawn with random.choice, so the baseline
  counts depend on the state of the `random` module.
  """

  # helpers
  #-----------------
  def getCorrectRids(mappings):
    return {rid for (rid, loc) in mappings.items() if loc.isTrue}

  # rfa results
  rfaMappings = aligner.finalMappings
  rfaCloudFeats_map = aligner.extractCloudFeats(rfaMappings)

  pass1ReadInfo_map = getReadQualityInfo(pass1Bam_path)
  pass2ReadInfo_map = getReadQualityInfo(pass2Bam_path)
  assert len(pass1ReadInfo_map) == len(pass2ReadInfo_map)

  # compute correct sets for pass1
  #-----------------
  pass1MMQRid_set = set()
  pass1HMQRid_set = set()
  pass1CorrectRid_set = set()
  for (rid, loc_set) in pass1ReadInfo_map.items():
    # skip unmapped reads
    if loc_set is None:
      continue
    # choose an optimal location at random
    (_, mapq, correct) = _chooseOptimalLoc(loc_set)
    if correct:
      pass1CorrectRid_set.add(rid)
    if mapq >= _MMQ_THRESHOLD:
      pass1MMQRid_set.add(rid)
    if mapq >= _HMQ_THRESHOLD:
      pass1HMQRid_set.add(rid)

  # compute correctness for pass2 - resolve reads
  # and for topAS reads
  #-----------------
  topASMMQRid_set = set()
  topASHMQRid_set = set()

  pass2MMQRid_set = set()
  pass2HMQRid_set = set()
  pass2CorrectRid_set = set()
  for (rid, loc_set) in pass2ReadInfo_map.items():
    # skip unmapped reads
    if loc_set is None:
      continue
    # choose an optimal location at random
    (_, mapq, correct) = _chooseOptimalLoc(loc_set)

    if rid in aligner.candRid_set:
      # candidate reads only contribute their mapq to the topAS sets
      if mapq >= _MMQ_THRESHOLD:
        topASMMQRid_set.add(rid)
      if mapq >= _HMQ_THRESHOLD:
        topASHMQRid_set.add(rid)
    else:
      # the pass2 baseline only covers reads not touched by RFA
      if correct:
        pass2CorrectRid_set.add(rid)
      if mapq >= _MMQ_THRESHOLD:
        pass2MMQRid_set.add(rid)
      if mapq >= _HMQ_THRESHOLD:
        pass2HMQRid_set.add(rid)

  # compute mapq sets for rfa and oracle
  #-----------------
  (rfaZeroMQRid_set, rfaMedMQRid_set, rfaHighMQRid_set) = \
    _splitRidsByMapq(rfaMapq_map)
  (oracleZeroMQRid_set, oracleMedMQRid_set, oracleHighMQRid_set) = \
    _splitRidsByMapq(oracleMapq_map)

  # compute overall results
  #-----------------
  # reads with a fixed mapping that is not the true one, over all clouds
  fixedIncorrectRid_set = set()
  for cid in aligner.allCid_set:
    cloud = aligner.cloud_map[cid]
    fixedIncorrectRid_set.update(
      loc.rid for loc in cloud.getFixedMappings() if not loc.isTrue
    )

  # topAS
  #-----------------
  topASMappings = aligner.getTopASMappings()
  # NOTE(review): the topAS exclude set is not used by any statistic below;
  # the call is kept only in case getExcludeCids has side effects
  aligner.getExcludeCids(topASMappings)
  topASCorrectRid_set = getCorrectRids(topASMappings)
  topASIncorrectRid_set = aligner.candRid_set - topASCorrectRid_set

  # oracle
  #-----------------
  oracleMappings = aligner.getOracleMappings()
  oracleExcludeCid_set = aligner.getExcludeCids(oracleMappings)
  oracleExcludeRid_set = getRidsFromCids(oracleExcludeCid_set, oracleMappings)
  oracleAllCorrectRid_set = getCorrectRids(oracleMappings)
  oracleCorrectRid_set = oracleAllCorrectRid_set - oracleExcludeRid_set

  # rfa
  #-----------------
  (rfaHomoTarget_map, rfaAmbigCid_set) = aligner.getHomoCloudTargets(rfaMappings)
  rfaExcludeCid_set = aligner.getExcludeCids(rfaMappings)
  rfaExcludeRid_set = getRidsFromCids(rfaExcludeCid_set, rfaMappings)
  rfaAllCorrectRid_set = getCorrectRids(rfaMappings)
  rfaCorrectRid_set = rfaAllCorrectRid_set - rfaExcludeRid_set
  rfaAllIncorrectRid_set = aligner.candRid_set - rfaAllCorrectRid_set
  rfaIncorrectRid_set = rfaAllIncorrectRid_set - rfaExcludeRid_set
  rfaIncorrectHmqRid_set = rfaIncorrectRid_set & rfaHighMQRid_set

  rfaHomoCid_set = set(rfaHomoTarget_map.keys())
  rfaActiveCid_set = aligner.getActiveCids(rfaCloudFeats_map)
  rfaAmbigCloudRid_set = getRidsFromCids(rfaAmbigCid_set, rfaMappings)

  rfaFakeCid_set = rfaActiveCid_set & aligner.falseCid_set
  rfaFakeCloudRid_set = getRidsFromCids(rfaFakeCid_set, rfaMappings)

  results_map = {

    'clouds_true' : len(aligner.trueCid_set),
    'clouds_fake' : len(aligner.falseCid_set),

    'rreads_rfa_exclude'               : len(rfaExcludeRid_set),
    'rreads_rfa_exclude_correct'       : len(rfaExcludeRid_set & rfaAllCorrectRid_set),
    'rreads_rfa_exclude_correct_hmq'   : len(rfaExcludeRid_set & rfaAllCorrectRid_set & rfaHighMQRid_set),
    'rreads_rfa_exclude_correct_mmq'   : len(rfaExcludeRid_set & rfaAllCorrectRid_set & rfaMedMQRid_set),
    'rreads_rfa_exclude_incorrect'     : len(rfaExcludeRid_set & rfaAllIncorrectRid_set),
    'rreads_rfa_exclude_incorrect_hmq' : len(rfaExcludeRid_set & rfaAllIncorrectRid_set & rfaHighMQRid_set),
    'rreads_rfa_all_hmq_inambig'    : len(rfaAmbigCloudRid_set & rfaHighMQRid_set),
    # NOTE(review): despite the "hmq" in the key, no high-mapq filter is
    # applied here -- confirm whether that is intended
    'rreads_rfa_resolve_hmq_inambig': len(rfaAmbigCloudRid_set & aligner.resolveRid_set),
    'rreads_rfa_correct'            : len(rfaCorrectRid_set),
    'rreads_rfa_correct_0mq'        : len(rfaCorrectRid_set & rfaZeroMQRid_set),
    'rreads_rfa_correct_mmq'        : len(rfaCorrectRid_set & rfaMedMQRid_set),
    'rreads_rfa_correct_hmq'        : len(rfaCorrectRid_set & rfaHighMQRid_set),
    'rreads_rfa_correct_hmq_inambig': len(rfaCorrectRid_set & rfaHighMQRid_set & rfaAmbigCloudRid_set),
    'rreads_rfa_incorrect_hmq'      : len(rfaIncorrectHmqRid_set),
    'rreads_rfa_incorrect_hmq_inambig'    : len(rfaIncorrectHmqRid_set & rfaAmbigCloudRid_set),
    'rreads_rfa_incorrect_hmq_notruecand' : len(rfaIncorrectHmqRid_set - aligner.resolveRid_set),
    'rreads_rfa_incorrect_hmq_notruecand_infake' :
      len((rfaIncorrectHmqRid_set - aligner.resolveRid_set) & rfaFakeCloudRid_set),
    'rreads_rfa_incorrect_hmq_yestruecand' : len(rfaIncorrectHmqRid_set & aligner.resolveRid_set),
    'rreads_rfa_incorrect_hmq_yestruecand_infake' :
      len((rfaIncorrectHmqRid_set & aligner.resolveRid_set) & rfaFakeCloudRid_set),

    'rreads_topAS_correct'        : len(topASCorrectRid_set),
    'rreads_topAS_correct_mmq'    : len(topASCorrectRid_set & topASMMQRid_set),
    'rreads_topAS_correct_hmq'    : len(topASCorrectRid_set & topASHMQRid_set),
    'rreads_topAS_incorrect_hmq'  : len(topASIncorrectRid_set & topASHMQRid_set),

    'rreads_oracle_exclude'               : len(oracleExcludeRid_set),
    'rreads_oracle_exclude_correct'       : len(oracleExcludeRid_set & oracleAllCorrectRid_set),
    'rreads_oracle_exclude_correct_hmq'   : len(oracleExcludeRid_set & oracleAllCorrectRid_set & oracleHighMQRid_set),
    'rreads_oracle_exclude_correct_mmq'   : len(oracleExcludeRid_set & oracleAllCorrectRid_set & oracleMedMQRid_set),
    'rreads_oracle_correct'     : len(oracleCorrectRid_set),
    'rreads_oracle_correct_0mq' : len(oracleCorrectRid_set & oracleZeroMQRid_set),
    'rreads_oracle_correct_mmq' : len(oracleCorrectRid_set & oracleMedMQRid_set),
    'rreads_oracle_correct_hmq' : len(oracleCorrectRid_set & oracleHighMQRid_set),

    'total_fixed_incorrect_reads' : len(fixedIncorrectRid_set),
    'total_resolve_reads'     : len(aligner.resolveRid_set),
    'total_multimapped_reads' : len(aligner.candRid_set),
    'total_rfa_hmq_reads'     : len(rfaHighMQRid_set),

    'cclouds_rfa_correct_homo'  : len(rfaHomoCid_set & aligner.trueCid_set),
    'cclouds_rfa_incorrect_homo': len(rfaHomoCid_set & aligner.falseCid_set),
    'cclouds_rfa_incorrect' : len(rfaActiveCid_set & aligner.falseCid_set),
    'cclouds_rfa_correct_exclude'   : len(rfaActiveCid_set & aligner.trueCid_set & rfaExcludeCid_set),
    'cclouds_rfa_incorrect_exclude' : len(rfaActiveCid_set & aligner.falseCid_set & rfaExcludeCid_set),

    'total_rfa_ambig_clouds'     : len(rfaAmbigCid_set),
    'total_rfa_homo_clouds'      : len(rfaHomoCid_set),
    'total_rfa_exclude_clouds'   : len(rfaExcludeCid_set),

    'aallreads_pass1_correct'      : len(pass1CorrectRid_set),
    'aallreads_pass1_correct_hmq'  : len(pass1CorrectRid_set & pass1HMQRid_set),
    'aallreads_pass1_correct_mmq'  : len(pass1CorrectRid_set & pass1MMQRid_set),
    'aallreads_topAS_correct'      : len(pass2CorrectRid_set | topASCorrectRid_set),
    'aallreads_topAS_correct_hmq'  : len((pass2CorrectRid_set & pass2HMQRid_set) | (topASCorrectRid_set & topASHMQRid_set)),
    'aallreads_topAS_correct_mmq'  : len((pass2CorrectRid_set & pass2MMQRid_set) | (topASCorrectRid_set & topASMMQRid_set)),
    'aallreads_rfa_correct'        : len(pass2CorrectRid_set | rfaCorrectRid_set),
    'aallreads_rfa_correct_hmq'    : len((pass2CorrectRid_set & pass2HMQRid_set) | (rfaCorrectRid_set & rfaHighMQRid_set)),
    'aallreads_rfa_correct_mmq'    : len((pass2CorrectRid_set & pass2MMQRid_set) | (rfaCorrectRid_set & rfaMedMQRid_set)),
    'aallreads_oracle_correct'        : len(pass2CorrectRid_set | oracleCorrectRid_set),
    'aallreads_oracle_correct_hmq'    : len((pass2CorrectRid_set & pass2HMQRid_set) | (oracleCorrectRid_set & oracleHighMQRid_set)),
    'aallreads_oracle_correct_mmq'    : len((pass2CorrectRid_set & pass2MMQRid_set) | (oracleCorrectRid_set & oracleMedMQRid_set)),

    'total_reads' : len(pass1ReadInfo_map),
  }

  # report through the module logger (sorted for a stable, readable order)
  # instead of dumping raw tuples to stdout
  for key in sorted(results_map):
    logger.info('%s: %s', key, results_map[key])

  return results_map

#-------------------------------------------------------------------------
# get regions of interest 
#-------------------------------------------------------------------------
def getRegions(
  aligner,
  rfaMappings,
  rfaMapq_map,
  fixedCloudInfo_list,
):
  """Collect read-level and cloud-level regions of interest for the rfa mappings.

  aligner             -- aligner object; cloud_map, candRid_set,
                         resolveRid_set, trueCid_set and the cloud helper
                         methods are read below
  rfaMappings         -- mappings handed to the aligner's cloud helpers
  rfaMapq_map         -- dict rid -> mapq, looked up for reads that are in
                         aligner.candRid_set
  fixedCloudInfo_list -- iterable of 5-tuples
                         (cid, (chrName, begin, end), _, _, _)

  Returns a dict of region key -> defaultdict(chromosome name -> ClusterTree).
  Read-level trees are built with ClusterTree(20, 1), cloud-level trees with
  ClusterTree(1, 1); each inserted interval carries the cloud id.
  """

  # region key legend
  # read = read locations
  # cloud = cloud locations
  # {0,m}mq = {zero, medium-or-better} mapq
  # (in)corr = (in)correct
  # {y,n}tc = {yes, no} true read cand
  # yfc = yes fake cloud
  readKeys_list = [
    'read-0mq',
    'read-mmq',
    'read-mmq-corr',
    'read-mmq-incorr',
    'read-mmq-incorr-ytc',
    'read-mmq-incorr-ntc',
    'read-mmq-incorr-ytc-yfc',
    'read-mmq-incorr-ntc-yfc',
  ]
  cloudKeys_list = [
    'cloud-active',
    'cloud-active-exclude',
    'cloud-correct',
    'cloud-correct-exclude',
    'cloud-incorrect',
    'cloud-incorrect-exclude',
  ]
  regionsInfo_map = {}
  for rk in readKeys_list:
    regionsInfo_map[rk] = defaultdict(lambda: ClusterTree(20, 1))
  for ck in cloudKeys_list:
    regionsInfo_map[ck] = defaultdict(lambda: ClusterTree(1, 1))

  # helpers
  #-----------------
  def addRegion(regionKey, chrName, begin, end, cid):
    regionsInfo_map[regionKey][chrName].insert(begin, end, cid)

  def addReadRegion(regionKey, loc, cid):
    addRegion(regionKey, loc.chrName, loc.pos, loc.aend, cid)

  rfaCloudFeats_map = aligner.extractCloudFeats(rfaMappings)
  rfaActiveCid_set = aligner.getActiveCids(rfaCloudFeats_map)
  rfaExcludeCid_set = aligner.getExcludeCids(rfaMappings)

  # NOTE(review): the homo-cloud targets are not used in this function; the
  # call is kept only in case getHomoCloudTargets has side effects
  aligner.getHomoCloudTargets(rfaMappings)

  # classify realigned regions
  #-----------------
  for cid in rfaActiveCid_set:
    cloud = aligner.cloud_map[cid]
    cloudFeats = rfaCloudFeats_map[cid]
    isExcluded = cid in rfaExcludeCid_set
    # a cloud is fake when none of its mappings is a true one
    falseCloud = not any(loc.isTrue for loc in cloud.getAllMappings())

    # skip excluded clouds for read data
    if not isExcluded:
      for loc in cloud.getActiveMappings(rfaMappings):
        # reads outside the candidate set have no rfa mapq; they are
        # treated as mapq 60
        if loc.rid in aligner.candRid_set:
          mapq = rfaMapq_map[loc.rid]
        else:
          mapq = 60

        if mapq == 0:
          addReadRegion('read-0mq', loc, cid)
        if mapq < 10:
          continue

        addReadRegion('read-mmq', loc, cid)
        if loc.isTrue:
          addReadRegion('read-mmq-corr', loc, cid)
          continue

        addReadRegion('read-mmq-incorr', loc, cid)
        if loc.rid in aligner.resolveRid_set:
          addReadRegion('read-mmq-incorr-ytc', loc, cid)
          if falseCloud:
            addReadRegion('read-mmq-incorr-ytc-yfc', loc, cid)
        else:
          addReadRegion('read-mmq-incorr-ntc', loc, cid)
          if falseCloud:
            addReadRegion('read-mmq-incorr-ntc-yfc', loc, cid)

    # every active cloud is recorded under 'cloud-active' and under exactly
    # one of 'cloud-correct' / 'cloud-incorrect'; excluded clouds are
    # additionally recorded under the matching '-exclude' key
    if cid in aligner.trueCid_set:
      statusKey = 'cloud-correct'
    else:
      statusKey = 'cloud-incorrect'
    for cloudKey in ('cloud-active', statusKey):
      addRegion(cloudKey, cloud.chrName, cloudFeats.begin, cloudFeats.end, cid)
      if isExcluded:
        addRegion(
          cloudKey + '-exclude',
          cloud.chrName,
          cloudFeats.begin,
          cloudFeats.end,
          cid,
        )

  # update active cloud regions with fixed clouds
  for info in fixedCloudInfo_list:
    (cid, (chrName, begin, end), _1, _2, _3) = info
    addRegion('cloud-active', chrName, begin, end, cid)
    addRegion('read-mmq', chrName, begin, end, cid)

  return regionsInfo_map

#-------------------------------------------------------------------------
# get debug info
#-------------------------------------------------------------------------
def getDebug(
  aligner,
  rfaMappings,
  rfaMapq_map,
  rfaCloudQuality_map,
  rfaCloudMapq_map,
  fixedCloudInfo_list,
):
  """Build one debug record (a dict) per fixed cloud and per aligner cloud.

  aligner             -- aligner object; its cloud/read id sets, cloud_map,
                         oracleMappings, readScope_map and scoring helpers
                         are read below
  rfaMappings         -- mappings handed to the aligner's cloud helpers
  rfaMapq_map         -- dict rid -> mapq, indexed for every active cand
  rfaCloudQuality_map -- dict cid -> value stored under 'cloudLogP'
  rfaCloudMapq_map    -- dict cid -> value stored under 'cloudmapq'
  fixedCloudInfo_list -- iterable of 5-tuples
                         (cid, (chrName, begin, end), rid_set, cloudLogP, _)

  Returns a list of dicts: first one per fixed cloud (a reduced key set),
  then one per cid in aligner.allCid_set (the full key set).
  """

  # evaluate incorrect alignments in active clouds
  #-----------------
  rfaCloudFeats_map = aligner.extractCloudFeats(rfaMappings)
  rfaActiveCid_set = aligner.getActiveCids(rfaCloudFeats_map)

  (rfaHomoTarget_map, rfaAmbigCid_set) = aligner.getHomoCloudTargets(rfaMappings)
  rfaHomoCid_set = set(rfaHomoTarget_map.keys())
  rfaExcludeCid_set = aligner.getExcludeCids(rfaMappings)

  def getFeatsTuple(cf):
    # summarize a cloud feature object as a flat tuple; the last two fields
    # are the number of edit sites with P_v > 0 and the list of
    # (P_v, numRef, numAlt, pos) entries for sites with P_v > 0.5
    confEdits_list = []
    numEdits = 0
    for (pos, baseCounter) in cf.edits_map.items():
      numRef = baseCounter['ref']
      numAlt = sum(baseCounter.values()) - numRef
      (logP_v, logP_vp) = aligner.computeSiteLogPs(numRef, numAlt)
      # normalize the first log-probability against the pair
      P_v = np.exp(logP_v - logsumexp([logP_v, logP_vp]))
      if P_v > 0.5:
        confEdits_list.append((P_v, numRef, numAlt, pos))
      if P_v > 0:
        numEdits += 1

    return (
      aligner.computeCloudLogP(cf),
      aligner.computeEditsLogP(cf.edits_map),
      cf.size,
      cf.numReads,
      cf.numFixed,
      cf.numCand,
      cf.lbarValid,
      cf.rbarValid,
      cf.numMixedBar,
      numEdits,
      confEdits_list,
    )

  cloudInfo_list = []

  # fixed clouds only carry coordinates, a read set and a log-probability
  for info in fixedCloudInfo_list:
    (cid, (chrName, begin, end), rid_set, cloudLogP, _) = info
    cloudInfo_list.append({
      'cid' : cid,
      'coords_str' : '{0}:{1}-{2}'.format(chrName, begin, end),
      'fixed' : True,
      'active' : True,
      # FIXME tag needs to be implemented
      'mixed' : False,
      'cloudLogP' : cloudLogP,
      'numReads' : len(rid_set),
    })

  for cid in aligner.allCid_set:
    cloud = aligner.cloud_map[cid]
    cloudFeats = rfaCloudFeats_map[cid]
    staticCloudFeats = cloud.getFeats()
    activeMappings = cloud.getActiveCands(rfaMappings)

    # clouds the oracle assigns this cloud's resolvable active reads to
    oracleTargetCounter = Counter()
    for loc in activeMappings:
      if loc.rid in aligner.resolveRid_set:
        oracleTargetCounter[aligner.oracleMappings[loc.rid].cid] += 1

    targetMove_map = aligner.getSubmoves(cid, activeMappings)
    targetCounter = Counter({
      tcid: len(moves) for (tcid, moves) in targetMove_map.items()
    })

    # classify into
    # - no valid target
    # - single active target
    # - single inactive target
    # - multiple targets
    ftype = 'nofake'
    if cid in aligner.falseCid_set:
      numTargets = len(oracleTargetCounter)
      if numTargets == 0:
        ftype = '0-target'
      elif numTargets == 1:
        # next(iter(...)) fetches the single key without indexing keys()
        targetCid = next(iter(oracleTargetCounter))
        if rfaCloudFeats_map[targetCid].isActive:
          ftype = '1-target-active'
        else:
          ftype = '1-target-inactive'
      else:
        ftype = '1+-target'

    # for each homocloud get the features of all possible targets
    homoTargetFeats_list = []
    if cid in rfaHomoCid_set:
      homoTargetFeats_list = [
        (tcid, None if cf is None else getFeatsTuple(cf), tlogp)
        for (tcid, cf, tlogp) in rfaHomoTarget_map[cid]
      ]

    # for each incorrect read save read scope features
    incorrReadFeats_list = []
    numCorrectHq = 0
    numCorrectMq = 0
    for loc in activeMappings:
      mapq = rfaMapq_map[loc.rid]
      if loc.isTrue:
        if mapq >= 10:
          numCorrectMq += 1
        # NOTE(review): "Hq" uses a cutoff of 30 here while getResults uses
        # 20 for its high-mapq sets -- confirm this is intended
        if mapq >= 30:
          numCorrectHq += 1
        continue
      trueCandGen = (loc.rid in aligner.resolveRid_set)
      scopeFeats_list = sorted(
        (
          loc2.ED,
          loc2.cid,
          loc2.isTrue,
          loc2.cid in rfaActiveCid_set,
          loc2 == loc,
        )
        for loc2 in aligner.readScope_map[loc.rid]
      )
      incorrReadFeats_list.append(
        (trueCandGen, mapq, loc.truePos, scopeFeats_list)
      )

    coords_str = '{0}:{1}-{2}'.format(
      cloud.chrName,
      staticCloudFeats.begin,
      staticCloudFeats.end,
    )

    cloudInfo_list.append({
      'cid' : cid,
      'fixed' : False,
      'active' : cid in rfaActiveCid_set,
      'mixed' : (cloudFeats.numMixedBar > 0),
      'ambig' : cid in rfaAmbigCid_set,
      'fake'  : cid in aligner.falseCid_set,
      'homo'  : cid in rfaHomoCid_set,
      'exclude' : cid in rfaExcludeCid_set,
      'ftype' : ftype,
      'coords_str' : coords_str,
      'numCorrectHq' : numCorrectHq,
      'numCorrectMq' : numCorrectMq,
      'srcFeats' : getFeatsTuple(cloudFeats),
      'targetCount_list' : targetCounter.most_common(),
      'oracleTargetCount_list' : oracleTargetCounter.most_common(),
      'homoTargetFeats_list' : homoTargetFeats_list,
      'incorrReadFeats_list': incorrReadFeats_list,
      'cloudLogP' : rfaCloudQuality_map[cid],
      'cloudmapq' : rfaCloudMapq_map[cid],
    })

  return cloudInfo_list

