import os
import sys
import random
import math
import numpy as np
import json
import logging
from collections import Counter, namedtuple, defaultdict
from itertools import combinations, izip, product
from scipy.misc import logsumexp
from bx.intervals.intersection import IntervalTree
import copy

from ..common import util
from ..common.math import EmpiricalDist

from .candcloud import CandCloud, getFeatsStr
from .models import CloudModel

# module-level logger used by InferenceEngine for progress/diagnostic output
logger = logging.getLogger(__name__)

#=========================================================================
# aligner inference
#=========================================================================
class InferenceEngine(object):
  # edit logP cache 
  editLogPCache_map = {}

  def __init__(
    self,
    params_map,
    candCloud_list,
    simulate=False,
    **kwargs
  ):
    """Create the engine, load the cloud model and index the clouds.

    Args:
      params_map: dict of parameters; params_map['cloudModel'] is passed to
        CloudModel.load.
      candCloud_list: iterable of candidate cloud objects, indexed by
        __build__.
      simulate: stored as self.simulate (simulation mode flag).
      **kwargs: accepted for interface compatibility; not used.
    """
    # simulation mode
    self.simulate = simulate
    # { <cid> : CandCloud object }
    self.cloud_map = {}

    # possible mapping locations of reads: { <rid> : set(loc) }
    self.readScope_map = defaultdict(set)
    # every candidate mapping
    self.allMappings = set()
    # { <loc> : score }
    self.readCandScore_map = {}
    # candidate read ids and query ids to disambiguate between
    self.candRid_set = set()
    self.candQid_set = set()
    # read ids which have more than one optimal mapping
    self.ambigRid_set = set()

    # golden set: { <rid> : loc }
    self.oracleMappings = {}
    # read ids which are resolveable given the set of all candidates
    self.resolveRid_set = set()

    # true clouds, bogus clouds, and all cloud ids (including fixed clouds)
    self.trueCid_set = set()
    self.falseCid_set = set()
    self.allCid_set = set()

    # move cache so already-examined moves are not rescored
    self.moveid = 0
    # { (cid, tcid) : moveidx }; unseen pairs read as -1
    self.lastEval_map = defaultdict(lambda: -1)
    # { cid : moveidx }; unseen clouds read as 0
    self.lastMove_map = defaultdict(lambda: 0)

    # final set of mappings from the aligner
    self.finalMappings = None

    # load cloud model
    self.cloudModel = CloudModel.load(params_map['cloudModel'])

    # read pairing frequencies
    # FIXME these are estimated from short reads and are pretty much the
    # same for all sequencing runs
    #  - factor out to be estimated from passing over wells
    self.pairCpt = {
      # only one read of the pair is mapped
      '1m': 0.017418200641364043,
      # both reads mapped, but to different clouds
      '2m_2c': 0.0025782731882582587,
      # both reads mapped to the same cloud
      '2m_1c': 0.9800035261703777,
    }

    # build in memory structures needed throughout the run
    self.__build__(candCloud_list)

#-------------------------------------------------------------------------
# build in memory tables to compute MRF assignment updates
#-------------------------------------------------------------------------
  def __build__(self, candCloud_list):
    """Index the candidate clouds into the engine's lookup tables.

    Populates cloud_map, readScope_map, allMappings, oracleMappings,
    resolveRid_set and the true/false/all cloud id sets. It then freezes
    readScope_map, so later lookups of unknown rids raise KeyError, and
    derives candRid_set, candQid_set and ambigRid_set.

    Args:
      candCloud_list: iterable of cloud objects exposing cid,
        getCandMappings() and getAllMappings().
    """
    logger.info('building initial in memory structures')

    # build read scopes
    #-------------
    logger.info('  - build read scopes')
    for cloud in candCloud_list:
      self.cloud_map[cloud.cid] = cloud
      for cand in cloud.getCandMappings():
        self.readScope_map[cand.rid].add(cand)
        self.allMappings.add(cand)
        # save oracle information
        if not cand.isTrue:
          continue
        self.oracleMappings[cand.rid] = cand
        self.resolveRid_set.add(cand.rid)
      # a cloud is "true" when any of its mappings is flagged true
      hasTrueRead = True in {m.isTrue for m in cloud.getAllMappings()}
      self.allCid_set.add(cloud.cid)
      (self.trueCid_set if hasTrueRead else self.falseCid_set).add(cloud.cid)

    # freeze domains: from here on, unknown rids raise KeyError
    self.readScope_map.default_factory = None
    # every indexed read has at least one candidate mapping
    for scope in self.readScope_map.values():
      assert len(scope) > 0

    self.candRid_set = set(self.readScope_map.keys())
    self.candQid_set = {rid[0] for rid in self.candRid_set}
    self.ambigRid_set = {
      rid for rid in self.candRid_set if self.isAmbigRead(rid)
    }

    logger.info('  - {0} cands, of which {1} can be resolved'.format(
      len(self.candRid_set),
      len(self.resolveRid_set),
    ))
    logger.info('({0}, {1}) = (true, false) candidate clouds detected'.format(
      len(self.trueCid_set),
      len(self.falseCid_set),
    ))

    logger.info('  - finished build')

#-------------------------------------------------------------------------
# get (in)active cloud IDs
#-------------------------------------------------------------------------
  def getActiveCids(self, cloudFeats_map):
    return set(
      filter(
        lambda(cid): cloudFeats_map[cid].isActive,
        self.allCid_set,
      )
    )

  def getInactiveCids(self, cloudFeats_map):
    return set(filter(
      lambda(cid): not cloudFeats_map[cid].isActive,
      self.allCid_set,
    ))

#-------------------------------------------------------------------------
# obtain initial mappings
#-------------------------------------------------------------------------
  def getRandInitReads(self):
    mappings = {}
    for (rid, loc_set) in self.readScope_map.items():
      choice = random.sample(loc_set, 1)[0]
      mappings[rid] = choice
    return mappings
    
  def getTopASMappings(self):
    # first determine active and inactive clouds if we just follow topAS
    mappings = {}
    for (rid, loc_set) in self.readScope_map.items():
      assert (len(loc_set) > 0)
      choice = max(loc_set, key=lambda(loc): loc.AS)
      mappings[rid] = choice
    return mappings

  def getOracleMappings(self):
    return dict(self.oracleMappings)

  def filterAmbigMappings(self, mappings):

    cloudFeats_map = self.extractCloudFeats(mappings)
    activeCid_set = self.getActiveCids(cloudFeats_map)

    ambigRid_set = set()
    for rid in mappings:
      filtloc_list = filter(
        lambda(loc): loc.cid in activeCid_set,
        self.readScope_map[rid],
      )
      # ambiguous if no active mapping
      if len(filtloc_list) == 0:
        ambigRid_set.add(rid)
      else:
        (topAS, topCount) = sorted(
          Counter(map(lambda(loc): loc.AS, filtloc_list)).items(),
          reverse=True,
        )[0]
        # ambiguous if more than one top live score or active topAS not
        # chosen
        if topCount > 1 or mappings[rid].AS != topAS :
          ambigRid_set.add(rid)

    # update mappings to remove ambiguous reads
    filtmappings = dict(filter(
      lambda(k, v): k not in ambigRid_set,
      mappings.items(),
    ))
    return filtmappings

#-------------------------------------------------------------------------
# compute scores
#-------------------------------------------------------------------------
  @staticmethod
  def computeSiteLogPs(
    numRef,
    numAlt,
    logP_v=None,
    logP_vp=None,
  ):
    prob_map = {
      "logP(v)"  : math.log(0.001),
      "logP(v')" : math.log(1 - 0.001),
      "logP(err)"  : math.log(0.01),
      "logP(err')" : math.log(1 - 0.01),
    }
    assert (
      (logP_v == None and logP_vp == None) or
      (logP_v != None and logP_vp != None)
    )
    if logP_v != None:
      prob_map["logP(v)"]  = logP_v
      prob_map["logP(v')"] = logP_vp

    # P(r, v)
    logP_v = (
      prob_map['logP(v)'] + 
      prob_map['logP(err)'] * numRef +
      prob_map["logP(err')"] * numAlt
    )
    # P(r, v')
    logP_vp = (
      prob_map["logP(v')"] + 
      prob_map['logP(err)'] * numAlt + 
      prob_map["logP(err')"] * numRef
    )

    return (logP_v, logP_vp)

  @staticmethod
  def computeEditsLogP(edits_map):
    # score edits
    # all positions treated as independent
    logP = 0.0
    for (pos, baseCounter) in edits_map.items():
      total = sum(baseCounter.values())
      numRef = baseCounter['ref']
      numAlt = total - numRef
      # do not counts sites with no alternate allele actively displayed
      if numAlt == 0:
        continue
      k = (numRef, numAlt)
      if k in InferenceEngine.editLogPCache_map:
        _logP = InferenceEngine.editLogPCache_map[k]
        logP += _logP
        continue

      # do updating
      # accumulate with logsumexp
      _logP = logsumexp(InferenceEngine.computeSiteLogPs(numRef, numAlt))
      InferenceEngine.editLogPCache_map[k] = _logP
      logP += _logP
    return logP

  # compute predictive posterior with observed {ref, alt} counts
  @staticmethod
  def computePostPredLogP(numRef, numAlt):
    assert numRef >= 0 and numAlt >= 0
    (logP_v, logP_vp) = InferenceEngine.computeSiteLogPs(numRef, numAlt)
    logZ = logsumexp([logP_v, logP_vp])
    logP_v = logP_v - logZ
    logP_vp = logP_vp - logZ
    return (logP_v, logP_vp)

  # cloud
  def computeCloudLogP(self, cloudFeats):
    if cloudFeats.numReads == 0:
      return 0

    logP = 0.0
    # score edits
    logP += InferenceEngine.computeEditsLogP(cloudFeats.edits_map)
    # score other cloud features
    logP += math.log(self.cloudModel.getProb(cloudFeats))

    return logP
    
  def computeCloudScore(self, cloudFeats):
    return 1. * self.computeCloudLogP(cloudFeats)

  # pair
  def getPairType(self, loc, mateLoc):
    return '1m' if mateLoc == None else \
      '2m_2c' if mateLoc.cid != loc.cid else '2m_1c'

  def computePairProb(self, pairType):
    return self.pairCpt[pairType]

  def computePairScore(self, pairType):
    logP = math.log(self.pairCpt[pairType])
    return logP

#-------------------------------------------------------------------------
# inference procedure step
#-------------------------------------------------------------------------

  # get alignments which have a better alignment score in the target cloud
  def getImproved(self, loc_set, tcid):
    """Find, for each mapping, a strictly better-scoring candidate in tcid.

    For every loc in loc_set, this looks at the candidate mappings of the same
    read (readScope_map) that lie in cloud tcid. It keeps the one with the
    highest alignment score (AS) when that score is strictly greater than
    loc.AS.

    Args:
      loc_set: iterable of current mappings.
      tcid: target cloud id.
    Returns:
      set of improved target mappings (at most one per input loc).
    """
    improved_set = set()
    for loc in loc_set:
      targetCands = [
        tloc for tloc in self.readScope_map[loc.rid] if tloc.cid == tcid
      ]
      if not targetCands:
        continue
      best = max(targetCands, key=lambda tloc: tloc.AS)
      if best.AS > loc.AS:
        # diagnostics go through the module logger instead of stdout
        logger.debug('improved!')
        logger.debug('  - sloc %s', loc)
        logger.debug('  - tloc %s', best)
        improved_set.add(best)
    return improved_set
     

  # get all submoves for this set of active mappings
  def getSubmoves(
    self,
    srcCid,
    activeMappings,
  ):
    targetMove_map = defaultdict(set)
    targetCloudCounter = Counter()
    for loc in activeMappings:
      rid = loc.rid
      seenCid_set = set()
      # traverse possible targets in reverse order of alignment score so
      # if there are multiple mappings to target cloud will pick the one
      # with highest alignment score
      for tloc in sorted(
        self.readScope_map[rid],
        key=lambda(l):l.AS,
        reverse=True,
      ):
        if tloc.cid in seenCid_set:
          continue
        seenCid_set.add(tloc.cid)
        targetCloudCounter[tloc.cid] += 1
        targetMove_map[tloc.cid].add(tloc)

    # filter for largest moves
    targetCid_set = set(map(
      lambda(tcid, _): tcid,
      targetCloudCounter.most_common()[:20],
    ))
    # NOTE can skip evaluating moves if nothing has changed since the last
    # time we evaluated
    targetMove_map = dict(filter(
      lambda(tcid, _): (
        tcid == srcCid or (
          tcid in targetCid_set and 
          max(
            self.lastMove_map[srcCid],
            self.lastMove_map[tcid],
          ) > self.lastEval_map[(srcCid, tcid)]
        )
      ),
      targetMove_map.items(),
    ))
    if srcCid in targetMove_map:
      del targetMove_map[srcCid]
    return targetMove_map

  # evaluate move for a particular source cloud
  def moveCloud(
    self,
    srcCid,
    prevFieldScore,
    prevmappings,
    cloudFeats_map,
    pairFeats_map,
  ):

    srcCloud = self.cloud_map[srcCid]
    srcCloudFeats = cloudFeats_map[srcCid]
    #logger.debug('* eval move of {0} {1} cid {2}'.format(
    #  'active' if srcCloudFeats.isActive else 'inactive',
    #  'true' if srcCid in self.trueCid_set else 'false',
    #  srcCid,
    #))

    candsScore_map = {}
    candsUpdateCloudFeats_map = {}
    candsUpdatePairFeats_map = {}
    candsUpdateMappings_map = {}

    # get this clouds submoves with the current active mappings
    srcActiveMappings = srcCloud.getActiveCands(prevmappings)
    submove_map = self.getSubmoves(srcCid, srcActiveMappings)

    loggy = False
    #numIncorr = sum(map(
    #  lambda(loc): not loc.isTrue,
    #  srcActiveMappings,
    #))
    #numTrueGen = sum(map(
    #  lambda(loc): loc.rid in self.resolveRid_set,
    #  srcActiveMappings,
    #))
    #logger.debug('  - curr (tot:{0}, incorr:{1}, trugen:{2})'.format(
    #  len(srcActiveMappings),
    #  numIncorr,
    #  numTrueGen,
    #))

    # return previous field configuration if no possible submoves
    if len(submove_map) == 0:
      return  (prevFieldScore, {}, {}, {})

    # score each submove
    for (targetCid, newLoc_set) in submove_map.items():

      numCorr = sum(map(
        lambda(loc): loc.isTrue,
        newLoc_set,
      ))
      #logger.debug('  - eval move to {0} cid {1}'.format(
      #  'true' if targetCid in self.trueCid_set else 'false',
      #  targetCid,
      #))
      #logger.debug('    - (tot:{0}, corr:{1})'.format(
      #  len(newLoc_set),
      #  numCorr,
      #))
      targetCloud = self.cloud_map[targetCid]
      touchedCids = set([
        srcCid,
        targetCid,
      ])

      # get active mappings for target cloud
      targetActiveMappings = targetCloud.getActiveCands(prevmappings)
      activeMappings = targetActiveMappings | srcActiveMappings

      # update mappings to reflect proposed submove
      propmappings = dict(map(
        lambda(loc): (loc.rid, loc),
        activeMappings,
      ))
      for loc in newLoc_set:
        propmappings[loc.rid] = loc

      scoreDelta = 0 
      # cloud score delta
      propCloudFeats_map = {}
      for cid in touchedCids:
        cloud = self.cloud_map[cid]
        propCloudFeats_map[cid] = cloud.getFeats(propmappings)
        prevScore = self.computeCloudScore(cloudFeats_map[cid])
        postScore = self.computeCloudScore(propCloudFeats_map[cid])
        scoreDelta -= prevScore
        scoreDelta += postScore
        #if loggy:
        #  logger.debug('prev cloud {0} score {1} feats {2}'.format(
        #    cid,
        #    prevScore,
        #    getFeatsStr(cloudFeats_map[cid]),
        #  ))
        #  if cloudFeats_map[cid].numReads == 0:
        #    logger.debug('  - empty!')
        #  else:
        #    logger.debug('  - (edits, cloudprop) = ({0}, {1})'.format(
        #      InferenceEngine.computeEditsLogP(cloudFeats_map[cid].edits_map),
        #      math.log(self.cloudModel.getProb(cloudFeats_map[cid])),
        #    ))
        #  logger.debug('post cloud {0} score {1} feats {2}'.format(
        #    cid,
        #    postScore,
        #    getFeatsStr(propCloudFeats_map[cid]),
        #  ))
        #  if propCloudFeats_map[cid].numReads == 0:
        #    logger.debug('  - empty!')
        #  else:
        #    logger.debug('  - (edits, cloudprop) = ({0}, {1})'.format(
        #      InferenceEngine.computeEditsLogP(propCloudFeats_map[cid].edits_map),
        #      math.log(self.cloudModel.getProb(propCloudFeats_map[cid])),
        #    ))

      # read and pair score delta
      prevPairScore = 0
      postPairScore = 0
      seenQname_set = set()
      propPairFeats_map = {}
      for loc in newLoc_set:
        prevLoc = prevmappings[loc.rid]

        if loc.qname in seenQname_set:
          continue
        seenQname_set.add(loc.qname)

        (_, fwd) = loc.rid
        mateRid = (loc.qname, not fwd)
        mateLoc = None if mateRid not in propmappings else propmappings[mateRid]
        propPairType = self.getPairType(loc, mateLoc)
        propPairFeats_map[loc.qname] = propPairType
        scoreDelta -= self.computePairScore(pairFeats_map[loc.qname])
        scoreDelta += self.computePairScore(propPairType)

        prevPairScore += self.computePairScore(pairFeats_map[loc.qname])
        postPairScore += self.computePairScore(propPairType)

      #if loggy:
      #  logger.debug('prev pair score: {0}'.format(prevPairScore))
      #  logger.debug('post pair score: {0}'.format(postPairScore))
      #  logger.debug('score delta: {0}\n\n'.format(scoreDelta))

      #print '**** score delta: {0}'.format(scoreDelta)

      candsScore_map[targetCid] = scoreDelta
      candsUpdatePairFeats_map[targetCid] = propPairFeats_map
      candsUpdateCloudFeats_map[targetCid] = propCloudFeats_map
      candsUpdateMappings_map[targetCid] = propmappings

      self.lastEval_map[(srcCid, targetCid)] = self.moveid

    # choose the maximal submove
    (destCid, scoreDelta) = max(candsScore_map.items(), key=lambda(s): s[1])
    updatedMappings_map = candsUpdateMappings_map[destCid]
    updatedPairFeats_map = candsUpdatePairFeats_map[destCid]
    updatedCloudFeats_map = candsUpdateCloudFeats_map[destCid]
    newFieldScore = scoreDelta + prevFieldScore

    # debug
    destCloud = self.cloud_map[destCid]
    if (srcCid in self.trueCid_set and destCid not in self.trueCid_set):
      moveType = 'bad'
    elif (srcCid not in self.trueCid_set and destCid in self.trueCid_set):
      moveType = 'good'
    else:
      moveType = 'neutral'
    #logger.debug('  - {0} move to cid {1}\n\n'.format(moveType, destCid))

    if destCid != srcCid:
      self.moveid += 1
      self.lastMove_map[destCid] = self.moveid
      self.lastMove_map[srcCid] = self.moveid

    # update mappings to move this cloud
    return (
      newFieldScore, 
      updatedMappings_map, 
      updatedPairFeats_map,
      updatedCloudFeats_map,
    )

#-------------------------------------------------------------------------
# evaluate current set of mappings
#-------------------------------------------------------------------------
  def evaluate(self, mappings, cloudFeats_map):
    """Log accuracy statistics for a set of mappings; returns None.

    Compares mappings against the oracle information gathered in __build__.
    It logs counts of correct reads, active false clouds and "hmq" reads. A
    read is hmq when, among its candidate mappings in active clouds, the
    highest alignment score (AS) occurs exactly once.

    Args:
      mappings: dict of { <rid> : loc }.
      cloudFeats_map: { <cid> : cloud features }, used to find active clouds.
    """
    logger.info('evaluate mappings')

    def uniqueMax(values):
      # True iff values is non-empty and its maximum occurs exactly once
      if not values:
        return False
      return values.count(max(values)) == 1

    activeCid_set = self.getActiveCids(cloudFeats_map)

    correctRid_set = set()
    hmqRid_set = set()
    falseCloudRid_set = set()
    for rid in mappings:
      chosen = mappings[rid]
      if chosen.isTrue:
        correctRid_set.add(rid)
      if chosen.cid in self.falseCid_set:
        falseCloudRid_set.add(rid)
      # only locations in active clouds decide "high quality" mappings
      activeScores = [
        cand.AS
        for cand in self.readScope_map[rid]
        if cand.cid in activeCid_set
      ]
      if uniqueMax(activeScores):
        hmqRid_set.add(rid)

    numResolvable = len(self.resolveRid_set)
    numCorrect = 1. * len(correctRid_set)
    pctCorrect = numCorrect / numResolvable if numResolvable else 0

    activeFalseCid_set = activeCid_set & self.falseCid_set
    incorrectRid_set = self.candRid_set - correctRid_set
    trueCloudRid_set = self.candRid_set - falseCloudRid_set
    hmqIncorrect_set = incorrectRid_set & hmqRid_set
    hmqRslvIncorrect_set = hmqIncorrect_set & self.resolveRid_set

    logger.info('num correct:{0}'.format(numCorrect))
    logger.info('pct correct (resolveable):{0}'.format(pctCorrect))

    # number of active false clouds
    logger.info('number active false clouds:{0}'.format(
      len(activeFalseCid_set)))

    # number of (hmq) {correct, incorrect} reads in clouds
    logger.info('hmq correct:{0}'.format(len(correctRid_set & hmqRid_set)))
    logger.info('hmq incorrect:{0}'.format(len(hmqIncorrect_set)))

    # number of hmq reads with no true candidate
    logger.info('hmq incorrect no true cand:{0}'.format(len(
      hmqRid_set & (self.candRid_set - self.resolveRid_set))))

    # number incorrect hmq, truecand in false clouds
    logger.info('hmq rslv incorrect in false cloud:{0}'.format(len(
      hmqRslvIncorrect_set & falseCloudRid_set)))

    # number incorrect hmq, truecand in true clouds
    logger.info('hmq rslv incorrect in true cloud:{0}'.format(len(
      hmqRslvIncorrect_set & trueCloudRid_set)))

#-------------------------------------------------------------------------
# extract all cloud and pair features for mappings
#-------------------------------------------------------------------------
  def extractCloudFeats(self, mappings):
    cloudFeats_map = {}
    for cid in self.allCid_set:
      cloudFeats_map[cid] = self.cloud_map[cid].getFeats(mappings)
    return cloudFeats_map

  def extractPairFeats(self, mappings):
    mappings_set = set(mappings.values())
    pairFeats_map = {}
    for qid in self.candQid_set:
      r1rid = (qid, True)
      r2rid = (qid, False)
      r1Mappings = set()
      r2Mappings = set()
      if r1rid in self.candRid_set:
        r1Mappings = mappings_set & self.readScope_map[r1rid]
      if r2rid in self.candRid_set:
        r2Mappings = mappings_set & self.readScope_map[r2rid]
      assert max(len(r1Mappings), len(r2Mappings)) <= 1
      if 0 in [len(r1Mappings), len(r2Mappings)]:
        pairFeats_map[qid] = '1m'
      else:
        r1Cid = iter(r1Mappings).next().cid
        r2Cid = iter(r2Mappings).next().cid
        pairFeats_map[qid] = '2m_1c' if r1Cid == r2Cid else '2m_2c'

    return pairFeats_map

#-------------------------------------------------------------------------
# read mapq query computation
#-------------------------------------------------------------------------
  def getMapqNaive(self, mappings, cloudFeats_map):
    # { <rid> : mapq }
    mapq_map = {}
    activeCid_set = self.getActiveCids(cloudFeats_map)
    for (rid, loc) in mappings.items():
      # obtain number of optimal and suboptimal reads in active clouds
      activeLoc_list = filter(
        lambda(loc): loc.cid in activeCid_set,
        self.readScope_map[rid],
      )

      # read may have no candidates in active clouds, count as 0 optimal
      # alignments
      if len(activeLoc_list) == 0:
        numOpt = 0
        numSubOpt = 0
      else:
        EDCounter = Counter(
          map(
            lambda(loc): loc.ED,
            activeLoc_list,
          )
        )
        listy = sorted(EDCounter.items(), reverse=True)
        (optED, numOpt) = listy[0]
        if len(listy) > 1:
          (soptED, numSOpt) = listy[1]
        else:
          (soptED, numSOpt) = (None, None)
        numSubOpt = sum(EDCounter.values()) - numOpt

      mapq = None
      if numOpt == 0:
        mapq = 0
      if numOpt == 1:
        if (
          numSubOpt <= 8 and
          (
            soptED == None or
            soptED > optED + 1
          )
        ):
          mapq = 60
        else:
          mapq = 3

      elif numOpt == 2:
        mapq = 3
      elif numOpt in [3,4]:
        mapq = 1
      else:
        mapq = 0

      mapq_map[rid] = mapq

    return mapq_map

  def __getEditSiteProbs__(
    self,
    mappings,
    cloudQuality_map,
    cloudFeats_map,
    activeCid_set,
    activeReadScope_map,
    eintervals_map,
  ):
    """Compute P(X_l | R_l-) for the edit sites of the active clouds.

    Clouds in activeCid_set are processed unless their estimated SNP rate
    (cloudQuality_map[cid] / log(0.001)) exceeds 15 per 1000. Each edit
    position of cloudFeats.edits_map with at least one non-ref observation
    is a site X_l, and for it:
      1. enumerate alternative mapping vectors R_l for the reads covering
         X_l (getAltMappings). Score each vector by the site probabilities
         of every edit site it overlaps, using predictive posteriors that
         exclude the vector's own reads, then normalize into
         P(R_l | R_l-);
      2. marginalize computePostPredLogP(numRef, numAlt) at X_l over that
         distribution.

    Args:
      mappings: dict { <rid> : loc } of current mappings.
      cloudQuality_map: { <cid> : quality value }.
      cloudFeats_map: { <cid> : cloud features } (edits_map, begin, end).
      activeCid_set: ids of the clouds to process.
      activeReadScope_map: { <rid> : candidate locs }; reads missing from
        it are treated as fixed.
      eintervals_map: { <chrName> : interval tree }; find(begin, end)
        yields (editSite, ref loc set, alt loc set) hits.
    Returns:
      dict { (chrName, pos) : (siteLogP_v, siteLogP_vp, finalNumAlt,
      finalNumRef, altAllele, cid) }.
    """

    # NOTE(review): getRids and getCids are not called anywhere in this method
    def getRids(loc_set):
      return set(map(
        lambda(loc): loc.rid,
        loc_set,
      ))
    def getCids(loc_list):
      return map(
        lambda(loc): loc.cid,
        loc_list,
      )
    # flatten one level of nesting: [[a, b], [c]] -> [a, b, c]
    def flatten(l):
      return [item for sublist in l for item in sublist]
    # enumerate alternate mapping vectors for provided reads
    # NOTE large domains are not randomly sampled: only the first vectors of
    # the product are yielded (see MAX_SIZE below)
    def getAltMappings(inputLoc_set):
      """Yield alternative mapping vectors for the reads in inputLoc_set.

      Reads are grouped greedily by the target cloud that the most
      remaining reads can move to. Each group either moves as a unit to a
      cloud shared by all of its members or stays in the source cloud scid.
      Reads absent from activeReadScope_map or with no alternative location
      are fixed, and single-read groups never move. Each yielded value is a
      flat list of locs, one per read.
      """
      MAX_SIZE = 100
      # source cloud id, taken from an arbitrary loc; presumably every loc
      # in inputLoc_set belongs to the same cloud -- TODO confirm
      scid = iter(inputLoc_set).next().cid
      # determine move groups
      # { <tcid> : rids that can move to tcid }
      tcidRid_map = defaultdict(set)
      # { <rid> : tcids the read can move to }
      ridTcidSet_map = defaultdict(set)
      # { <rid> : { <tcid> : loc } }
      ridTcidLoc_map = defaultdict(dict)
      # reads that cannot move
      fixedRid_set = set()
      for loc in inputLoc_set:
        rid = loc.rid
        # skip fixed reads
        if rid not in activeReadScope_map:
          fixedRid_set.add(rid)
          continue
        multi = False
        for tloc in activeReadScope_map[rid]:
          if tloc == loc:
            continue
          multi = True
          tcid = tloc.cid
          tcidRid_map[tcid].add(rid)
          ridTcidSet_map[rid].add(tcid)
          ridTcidLoc_map[rid][tcid] = tloc
        # reads with no alternative location are fixed as well
        if not multi:
          fixedRid_set.add(rid)

      # greedily peel off the largest group of not-yet-grouped reads that
      # share a target cloud
      # NOTE(review): the loop only terminates if inputLoc_set holds exactly
      # one loc per rid
      ridGroups_list = []
      seenRid_set = set(fixedRid_set)
      while len(seenRid_set) != len(inputLoc_set):
        (tcid, rid_set) = sorted(
          tcidRid_map.items(),
          key=lambda(x): len(x[1]),
          reverse=True,
        )[0]
        ridGroup = rid_set
        ridGroups_list.append(set(ridGroup))
        seenRid_set.update(rid_set)
        # remove grouped reads from every candidate target (in place)
        for (tcid, rid_set) in tcidRid_map.items():
          rid_set -= seenRid_set

      # add in fixed group
      if len(fixedRid_set) > 0:
        ridGroups_list.append(fixedRid_set)

      # every read lands in exactly one group
      assert sum(map(
        lambda(x): len(x),
        ridGroups_list,
      )) == len(inputLoc_set)

      # update domains to include source cloud
      for loc in inputLoc_set:
        ridTcidLoc_map[loc.rid][loc.cid] = loc

      # per group, the list of joint placements (one loc per read) it may take
      domains_list = []
      domainSize = 1
      for ridGroup in ridGroups_list:
        # don't let single read groups move....
        rootRid = iter(ridGroup).next()
        if len(ridGroup) == 1:
          domains_list.append([[ridTcidLoc_map[rootRid][scid]]])
          continue

        # clouds every read of the group can move to, plus the source cloud
        tcidGroup = set(ridTcidSet_map[rootRid])
        for rid in ridGroup:
          tcidGroup &= ridTcidSet_map[rid]
        tcidGroup.add(scid)

        samples_list = []
        for tcid in tcidGroup:
          sample = []
          for rid in ridGroup:
            sample.append(ridTcidLoc_map[rid][tcid])
          samples_list.append(sample)
        domains_list.append(samples_list)
        domainSize *= len(samples_list)

      if domainSize > MAX_SIZE:
        logger.warn('big domain of size {0}'.format(domainSize))

      # NOTE(review): `i > MAX_SIZE` lets MAX_SIZE + 1 vectors through;
      # `i >= MAX_SIZE` would cap the output at MAX_SIZE
      for (i, sample) in enumerate(product(*domains_list)):
        if i > MAX_SIZE:
          break
        yield flatten(sample)

      # NOTE(review): from Python 3.7 (PEP 479) raising StopIteration inside a
      # generator becomes RuntimeError; a plain `return` is equivalent here
      raise StopIteration


    # check point inside interval
    # half-open: begin <= pt < end
    def insideInterval(pt, begin, end):
      return (pt >= begin and pt < end)

    # caches for already computed values
    # { (numRef, numAlt) : computePostPredLogP(numRef, numAlt) }
    postPredCache_map = {}
    # P(X_l | R_l-) for each X_l
    # {<(chrName, pos)> : (logP_v, logP_vp, numalt, numref, altAllele, cid)}
    editSiteInfo_map = {}

    ## FIXME remove
    #return editSiteInfo_map

    # mappings of distributions of P(R_l | R_l-) for each R_l
    # {<tuple(loc_set)> : [(ploc_set, logP)]
    #readGroupDist_map = defaultdict(dict)

    # pass through all clouds to build P(X_l | R_l-)
    #-------------
    # each (chrName, pos) site must belong to exactly one active cloud
    seenSite_set = set()
    # progress-log stride (integer division under Python 2)
    step = max(1, len(activeCid_set) / 30)
    logger.info('passing through all active clouds')
    for (i, cid) in enumerate(activeCid_set):

      cloud = self.cloud_map[cid]
      cloudFeats = cloudFeats_map[cid]
      chrName = cloud.chrName

      # quality / log(0.001) is used as an estimated SNP rate; clouds above
      # 15 per 1000 are skipped
      qual = cloudQuality_map[cid]
      snpRate = qual / math.log(0.001)
      if snpRate > 15. / 1000:
        logger.info('  - skipping low quality cloud: {0}'.format(cid))
        logger.info('{0}:{1}-{2}'.format(
          cloud.chrName, 
          cloudFeats.begin,
          cloudFeats.end))
        logger.info('snp rate:{0}'.format(snpRate * 1000))
        continue

      if i % step == 0:
        logger.info('  - {0}, {1}'.format(i, len(activeCid_set)))

      activeMappings = cloud.getActiveMappings(mappings)
      # presumably { <edit pos> : set(loc) } of the active reads at each edit
      # position -- TODO confirm against CandCloud.getEditsReadsMap
      editsReads_map = cloud.getEditsReadsMap(activeMappings)
      #print 'cid', cid
      #print 'this cloud num edits', len(cloudFeats.edits_map)
      for (XeditPos, baseCounter) in cloudFeats.edits_map.items():
        total = sum(baseCounter.values())
        finalNumRef = baseCounter['ref']
        finalNumAlt = total - finalNumRef

        # only sites showing an alternate allele are scored
        if finalNumAlt == 0:
          continue

        XeditSite = (chrName, XeditPos)
        assert XeditSite not in seenSite_set
        seenSite_set.add(XeditSite)

        # obtain alternate allele
        # (the most common non-'ref' key of baseCounter)
        (altAllele, cnt) = filter(
          lambda(k,v): k != 'ref',
          baseCounter.most_common(),
        )[0]
        assert cnt > 0

        loc_set = editsReads_map[XeditPos]
        #rgid = tuple(sorted(loc_set))

        # first compute P(R_l | R_l-) for all possible vectors R_l
        # NOTE can downsample if space too big
        #-------------
        #if rgid not in readGroupDist_map:
        # [(ploc_set, unnormalized logP)] for each alternative vector
        siteLogP_list = []
        for altMappings in getAltMappings(loc_set):
          ploc_set = set(altMappings)
          pSiteLogP = 0.
          # obtain all candidate edit sites at this location
          # (deduplicated by site key, hit[0])
          pSiteHits = []
          pseenSite_set = set()
          for loc in ploc_set:
            hits = eintervals_map[loc.chrName].find(
              loc.pos,
              loc.aend - 1,
            )
            for hit in hits:
              if hit[0] not in pseenSite_set:
                pSiteHits.append(hit)
              pseenSite_set.add(hit[0])
          # compute predictive posterior P(X|R_l-)
          posPredPost_map = {}
          for (peditSite, rloc_set, aloc_set) in pSiteHits:
            # correct for double counting in MAP assignment
            # (this vector's own reads are removed from the site support)
            numRef = len(rloc_set - ploc_set)
            numAlt = len(aloc_set - ploc_set)
            k = (numRef, numAlt)
            if k in postPredCache_map:
              (logP_v, logP_vp) = postPredCache_map[k]
            else:
              (logP_v, logP_vp) = InferenceEngine.computePostPredLogP(numRef, numAlt)
              postPredCache_map[k] = (logP_v, logP_vp)

            posPredPost_map[peditSite] = (logP_v, logP_vp)

          # check that all edits present in R_l are present as site hits
          readEditSite_set = set()
          for loc in ploc_set:
            for pos in loc.edits_map:
              editSite = (loc.chrName, pos)
              readEditSite_set.add(editSite)
          assert readEditSite_set.issubset(set(posPredPost_map.keys()))

          # compute P(R_l = ploc | R_l-)
          for (peditSite, rloc_set, aloc_set) in pSiteHits:
            # NOTE(review): numRef is recomputed but not used below; only
            # numAlt feeds totalReads
            numRef = len(rloc_set - ploc_set)
            numAlt = len(aloc_set - ploc_set)
            (pchrName, ppos) = peditSite
            (logP_v, logP_vp) = posPredPost_map[peditSite]
            obsNumRef = 0
            obsNumAlt = 0
            # for each read count obs{ref, alt}
            for loc in ploc_set:
              if loc.chrName == pchrName:
                if ppos in loc.edits_map:
                  obsNumAlt += 1
                elif insideInterval(ppos, loc.pos, loc.aend):
                  obsNumRef += 1

            # skip sites with no edits
            # (only alt support is counted; the ref terms are commented out)
            totalReads = (
              #numRef + obsNumRef + 
              numAlt + obsNumAlt
            )
            if totalReads == 0:
              continue

            # marginal over {v, v'} of this vector's observations at the
            # site, with the predictive posterior as prior
            pSiteLogP += logsumexp(InferenceEngine.computeSiteLogPs(
              obsNumRef,
              obsNumAlt,
              logP_v=logP_v,
              logP_vp=logP_vp,
            ))
          siteLogP_list.append(
            (ploc_set, pSiteLogP)
          )
        # renormalize
        # (so the alternative vectors' probabilities sum to 1)
        logZ = logsumexp(map(
          lambda(_, logP): logP,
          siteLogP_list,
        ))
        siteLogP_list = map(
          lambda(ploc_set, logP): (ploc_set, logP - logZ),
          siteLogP_list,
        )
        #readGroupDist_map[rgid] = siteLogP_list
        #mappingsDistLogP_list = readGroupDist_map[rgid]

        mappingsDistLogP_list = siteLogP_list

        # compute P(X_l | R_l-) using P(R_l | R_l-)
        #-------------
        siteLogP_v_list = []
        siteLogP_vp_list = []

        (chrName, pos) = XeditSite
        for (ploc_set, readLogP) in mappingsDistLogP_list:
          # obtain mappings active at X_l for this assignment
          # NOTE(review): siteLoc_set is not used below
          siteLoc_set = set(filter(
            lambda(loc): loc in loc_set,
            ploc_set,
          ))
          # count number {ref, alt} in this vector
          numRef = 0
          numAlt = 0
          for loc in ploc_set:
            if loc.chrName == chrName:
              if pos in loc.edits_map:
                numAlt += 1
              elif insideInterval(pos, loc.pos, loc.aend):
                numRef += 1

          (_logP_v, _logP_vp) = InferenceEngine.computePostPredLogP(
            numRef,
            numAlt,
          )
          siteLogP_v_list.append(_logP_v  + readLogP)
          siteLogP_vp_list.append(_logP_vp  + readLogP)

        # mixture over the alternative vectors
        siteLogP_v  = logsumexp(siteLogP_v_list)
        siteLogP_vp = logsumexp(siteLogP_vp_list)
        editSiteInfo_map[XeditSite] = (
          siteLogP_v,
          siteLogP_vp,
          finalNumAlt,
          finalNumRef,
          altAllele,
          cid,
        )

    return editSiteInfo_map

  def __getReadMapq__(
    self,
    mappings,
    activeReadScope_map,
    eintervals_map,
    editSiteInfo_map,
  ):
    """Compute a phred-scaled mapping quality for every read.

    Every candidate mapping in a read's active scope is scored by summing,
    over the candidate edit sites it overlaps, the logsumexp of
    `InferenceEngine.computeSiteLogPs` evaluated with that site's
    (logP_v, logP_vp) from `editSiteInfo_map` (None when the site has no
    entry). The chosen mapping's score is normalized against all candidates
    to obtain p, and mapq = min(60, int(-10 * log10(1 - p))), or 60 when
    p == 1.

    Args:
      mappings: map of { <rid> : loc } of the current chosen mappings.
      activeReadScope_map: map of { <rid> : [loc, ...] } of candidate
        mappings restricted to live/active clouds.
      eintervals_map: map of { <chrName> : IntervalTree } whose entries are
        (editSite, ref_set, alt_set) tuples.
      editSiteInfo_map: map of { <editSite> : (logP_v, logP_vp, ...) }.

    Returns:
      map of { <rid> : mapq } with integer mapq in [0, 60]. A read whose
      chosen mapping is not among its active candidates gets mapq 0.
    """
    # { <rid> : mapq }
    mapq_map = {}

    def mappingSiteLogP(tloc):
      # score one candidate mapping by every candidate edit site it overlaps
      siteHits = eintervals_map[tloc.chrName].find(
        tloc.pos,
        tloc.aend - 1,
      )
      siteLogP = 0.0
      hitSite_set = set()
      for (editSite, _ref_set, _alt_set) in siteHits:
        hitSite_set.add(editSite)
        (eChrName, epos) = editSite
        assert tloc.chrName == eChrName
        # obtain P(X_l | R_l-); sites without computed info pass None
        if editSite in editSiteInfo_map:
          (logP_v, logP_vp) = editSiteInfo_map[editSite][:2]
        else:
          (logP_v, logP_vp) = (None, None)
        # a single read contributes exactly one ref or one alt observation
        projAlt = (epos in tloc.edits_map)
        siteLogP += logsumexp(InferenceEngine.computeSiteLogPs(
          not projAlt,
          projAlt,
          logP_v=logP_v,
          logP_vp=logP_vp,
        ))
      # every edit carried by this mapping must be a known candidate site
      readEdits_set = set(
        (tloc.chrName, pos) for pos in tloc.edits_map.keys()
      )
      assert readEdits_set.issubset(hitSite_set)
      return siteLogP

    logger.info('passing through all multimapped reads')
    step = max(1, len(mappings) // 10)
    for (i, (rid, loc)) in enumerate(mappings.items()):
      if i % step == 0:
        logger.info('  - {0}, {1}'.format(i, len(mappings)))

      # read may have no candidates in active clouds, count as 0 optimal
      # alignments
      activeLoc_list = activeReadScope_map[rid]
      if loc not in activeLoc_list:
        mapq_map[rid] = 0
        continue

      # compute logP of each alignment's edits given R-i
      srcLogP = None
      allTargetLogP = []
      for tloc in activeLoc_list:
        siteLogP = mappingSiteLogP(tloc)
        if tloc == loc:
          srcLogP = siteLogP
        allTargetLogP.append(siteLogP)

      assert srcLogP is not None
      logZ = logsumexp(allTargetLogP)
      p = np.exp(srcLogP - logZ)
      if p == 1.:
        mapq = 60
      else:
        mapq = min(60, int(-10 * math.log(1 - p, 10)))

      assert 0 <= mapq <= 60
      mapq_map[rid] = mapq

    return mapq_map

  @staticmethod
  def getCloudPenalty(
    cloud,
    activeMappings,
    cloudFeats,
  ):
    if not cloudFeats.isActive:
      return -9000.

    # eliminate edits
    cloudLogP = InferenceEngine.computeEditsLogP(cloudFeats.edits_map)
    # factor in penalty of unpaired reads in this cloud
    qidCounter = Counter(map(lambda(loc): loc.qid, activeMappings))
    numUnpaired = len(filter(
      lambda(qid, cnt): cnt == 1,
      qidCounter.items()
    ))
    cloudLogP += math.log(0.01) * numUnpaired

    # normalize by the length of the clouds
    qual = 1. * cloudLogP / cloudFeats.size

    # adjust for baseline due to sequencing error
    basequal = math.log(0.01) * cloudFeats.numReads / cloudFeats.size
    aqual = min(0, qual - basequal)
    #penalty = aqual 
    penalty = qual

    return penalty

  @staticmethod
  def getFixedEditInfo(
    cloud,
    cloudFeats,
  ):
    editSiteInfo_map = {}
    for (editPos, baseCounter) in cloudFeats.edits_map.items():
      editSite = (cloud.chrName, editPos)
      total = sum(baseCounter.values())
      numRef = baseCounter['ref']
      numAlt = total - numRef
      if numAlt == 0:
        continue
      # obtain alternate allele
      (altAllele, cnt) = filter(
        lambda(k,v): k != 'ref',
        baseCounter.most_common(),
      )[0]
      assert cnt > 0

      # compute P(X|R) for each site
      (logP_v, logP_vp) = InferenceEngine.computePostPredLogP(numRef, numAlt)
      editSiteInfo_map[editSite] = (
        logP_v,
        logP_vp,
        numAlt,
        numRef,
        altAllele,
        cloud.cid,
      )
    return editSiteInfo_map

  def __getCloudQuality__(
    self,
    mappings,
    cloudFeats_map,
    activeCid_set,
  ):
    """Score every known cloud with `InferenceEngine.getCloudPenalty`.

    Args:
      mappings: map of { <rid> : loc } of the current chosen mappings.
      cloudFeats_map: map of { <cid> : cloud features }.
      activeCid_set: set of cloud IDs considered active.

    Returns:
      map of { <cid> : quality } covering `self.allCid_set`; inactive clouds
      are assigned -9000.
    """
    # { <cid> : quality (length-normalized logP from getCloudPenalty) }
    cloudq_map = {}

    logger.info('passing through all clouds')
    numClouds = len(self.allCid_set)
    step = max(1, numClouds // 10)
    for (i, cid) in enumerate(self.allCid_set):
      if i % step == 0:
        logger.info('  - {0}, {1}'.format(i, numClouds))

      # mark inactive clouds as very low quality
      if cid not in activeCid_set:
        cloudq_map[cid] = -9000.
        continue

      cloud = self.cloud_map[cid]
      cloudq_map[cid] = InferenceEngine.getCloudPenalty(
        cloud,
        cloud.getActiveMappings(mappings),
        cloudFeats_map[cid],
      )

    return cloudq_map

  def __getCloudMapq__(
    self,
    mappings,
    activeCid_set,
  ):
    """Approximate each cloud's log-probability of being correctly placed.

    Uses the homologous target lists from `getHomoCloudTargets`, where the
    first entry of each list holds the source cloud's own logP; that logP
    is normalized against every entry in the list.

    Returns:
      map of { <cid> : logP } covering `self.allCid_set`: -9000. for
      inactive clouds, 0. for clouds with no homologous targets, otherwise
      the normalized source logP.
    """
    homoTargets, _ambig = self.getHomoCloudTargets(mappings)

    logger.info('passing through all clouds')
    allCids = self.allCid_set
    progressStep = max(1, len(allCids) / 10)

    cloudLogP_map = {}
    for idx, cid in enumerate(allCids):
      if idx % progressStep == 0:
        logger.info('  - {0}, {1}'.format(idx, len(allCids)))

      if cid not in activeCid_set:
        # inactive clouds get (effectively) zero probability of existing
        cloudLogP_map[cid] = -9000.
      elif not homoTargets[cid]:
        # nothing homologous to compete with: treat the cloud as unique
        cloudLogP_map[cid] = 0.
      else:
        candidates = homoTargets[cid]
        srcLogP = candidates[0][2]
        logZ = logsumexp([entry[2] for entry in candidates])
        cloudLogP_map[cid] = srcLogP - logZ

    return cloudLogP_map

  def getComputedQueries(
    self,
    mappings,
    uid=None,
  ):
    """Compute all probability queries for the given read assignment.

    Args:
      mappings: map of { <rid> : loc } of the current chosen mappings.
      uid: optional cache key. When not None, the edit-site and read-mapq
        results are loaded from '<uid>_editinfo.p' / '<uid>_readmapq.p' if
        those files exist, and written there after being computed otherwise.

    Returns:
      tuple (editSiteInfo_map, mapq_map, cloudQuality_map, cloudMapq_map).
    """
    logger.info('computing all probability queries')

    def cachedQuery(suffix, compute):
      # load a pickled result keyed by uid when available; otherwise compute
      # it and, when uid is given, write it back. Load and write use the same
      # `uid is not None` test so a cache that is read can also be written.
      # NOTE(review): these pickles are treated as trusted local caches;
      # never point uid at externally supplied files.
      cache_fname = '{0}_{1}'.format(uid, suffix)
      if uid is not None and os.path.isfile(cache_fname):
        logger.info('  - loaded from cache')
        return util.loadPickle(cache_fname)
      result = compute()
      if uid is not None:
        logger.info('  - dumping to cache')
        util.writePickle(cache_fname, result)
      return result

    # initial book keeping
    #-------------
    logger.info('building bookeeping structures')

    logger.info('  - extracting cloud features')
    cloudFeats_map = self.extractCloudFeats(mappings)

    logger.info('  - active read scopes')
    # build new read scope reduced to only active clouds
    activeCid_set = self.getActiveCids(cloudFeats_map)

    def isLiveMapping(loc):
      # must be in an active cloud; checked first so inactive clouds never
      # need a feature lookup
      if loc.cid not in activeCid_set:
        return False
      cf = cloudFeats_map[loc.cid]
      # must lie within the cloud's coordinates, with 200 positions of slack
      return loc.pos >= cf.begin - 200 and loc.aend <= cf.end + 200

    # { <rid> : [loc, ...] } candidate mappings restricted to live clouds
    activeReadScope_map = {}
    for rid in mappings:
      activeReadScope_map[rid] = [
        loc for loc in self.readScope_map[rid] if isLiveMapping(loc)
      ]

    # build intervals of all candidate variants in active clouds
    logger.info('  - active edit sites')
    eintervals_map = defaultdict(IntervalTree)
    seenSite_set = set()
    for cid in activeCid_set:
      cloud = self.cloud_map[cid]
      edits_map = cloudFeats_map[cid].edits_map
      activeMappings = cloud.getActiveMappings(mappings)
      editsReads_map = cloud.getEditsReadsMap(activeMappings)
      for (pos, baseCounter) in edits_map.items():
        total = sum(baseCounter.values())
        numRef = baseCounter['ref']
        numAlt = total - numRef
        # partition the reads covering this position into ref / alt
        ref_set = set()
        alt_set = set()
        for loc in editsReads_map[pos]:
          if pos in loc.edits_map:
            alt_set.add(loc)
          else:
            ref_set.add(loc)

        site = (cloud.chrName, pos)
        # no site may be registered twice
        assert site not in seenSite_set
        seenSite_set.add(site)

        # the read partition must agree with the cloud's base counts
        assert numRef >= 0 and numAlt >= 0
        assert len(ref_set) == numRef
        assert len(alt_set) == numAlt
        eintervals_map[cloud.chrName].insert(
          pos - 1,
          pos + 1,
          (site, ref_set, alt_set),
        )
    logger.info('  - done')

    # compute cloud quality query P(R_c = R_map) for all active clouds
    #-------------
    logger.info('computing cloud quality scores P(R_c)')
    cloudQuality_map = self.__getCloudQuality__(
      mappings,
      cloudFeats_map,
      activeCid_set,
    )

    # build P(X_l | R_l-) for all sites
    #-------------
    logger.info('computing P(X_l | R_l-) for all sites')
    editSiteInfo_map = cachedQuery(
      'editinfo.p',
      lambda: self.__getEditSiteProbs__(
        mappings,
        cloudQuality_map,
        cloudFeats_map,
        activeCid_set,
        activeReadScope_map,
        eintervals_map,
      ),
    )

    # compute read quality query P(R_n | R_l-) for all reads
    #-------------
    logger.info('computing P(R_n | R_l-) for all reads')
    mapq_map = cachedQuery(
      'readmapq.p',
      lambda: self.__getReadMapq__(
        mappings,
        activeReadScope_map,
        eintervals_map,
        editSiteInfo_map,
      ),
    )

    # approximate cloud mapq query P(E_c | R_c- = R_map) for all active clouds
    #-------------
    logger.info('computing cloud mappability scores P(E_c | R_c-)')
    cloudMapq_map = self.__getCloudMapq__(
      mappings,
      activeCid_set,
    )

    return (
      editSiteInfo_map,
      mapq_map,
      cloudQuality_map,
      cloudMapq_map,
    )

#-------------------------------------------------------------------------
# cloud mapq query computation
#-------------------------------------------------------------------------
  def getHomoCloudTargets(self, mappings):
  
    cloudFeats_map = self.extractCloudFeats(mappings)

    # map of all homo cloud IDs to the set of potential homo target clouds
    # and their corresponding cloud probs
    homoCloudTarget_map = defaultdict(list)
    # set of all cloud IDs containing sufficient number of ambig reads
    ambigCid_set = set()

    for (i, (srcCid, srcCloud)) in enumerate(self.cloud_map.items()):
      activeMappings = srcCloud.getActiveCands(mappings)
      cloudAmbigLoc_set = map(
        lambda(loc): loc.rid in self.ambigRid_set,
        activeMappings,
      )
      if len(cloudAmbigLoc_set) >= 20:
        ambigCid_set.add(srcCid)

      currFeats = cloudFeats_map[srcCid]

      # come up with target clouds to move groups of reads to
      targetMove_map = defaultdict(set)
      targetCloudCounter = Counter()
      for loc in activeMappings:
        rid = loc.rid
        seenCid_set = set()
        # traverse possible targets in reverse order of alignment score so
        # if there are multiple mappings to target cloud will pick the one
        # with highest alignment score
        for tloc in sorted(
          self.readScope_map[rid],
          key=lambda(l):l.AS,
          reverse=True,
        ):
          if tloc.cid in seenCid_set:
            continue
          seenCid_set.add(tloc.cid)
          targetCloudCounter[tloc.cid] += 1
          targetMove_map[tloc.cid].add(tloc)

      # skip clouds with no moves
      if len(targetMove_map) == 0:
        continue

      # exclude the source cloud
      del targetCloudCounter[srcCid]
      del targetMove_map[srcCid]

      targetCloudCounter = Counter(dict(filter(
        lambda(tcid, readCount): (
          (not cloudFeats_map[tcid].isActive) and
          (1. * readCount / cloudFeats_map[srcCid].numReads > 0.90)
        ),
        targetCloudCounter.items(),
      )))
      # obtain all rids the homo targets have in common
      def getRids(loc_set):
        return set(map(lambda(loc): loc.rid, loc_set))

      commonRid_set = set(getRids(activeMappings))
      homoTargetCid_set = set()
      for (targetCid, _) in targetCloudCounter.most_common():
        targetRid_set = getRids(targetMove_map[targetCid])
        if 1. * len(targetRid_set & commonRid_set) / cloudFeats_map[srcCid].numReads <= 0.90:
          break

        homoTargetCid_set.add(targetCid)
        commonRid_set &= targetRid_set

      if len(homoTargetCid_set) > 0:

        # rescore the source cloud with only the common mappings
        def filtLocs(loc_set, rid_set):
          return set(filter(
            lambda(loc): loc.rid in rid_set,
            loc_set,
          ))

        filtSrcMappings = filtLocs(activeMappings, commonRid_set)
        updatedSrcEdits = srcCloud.getEditsCounterMap(filtSrcMappings)
        srcLogP = InferenceEngine.computeEditsLogP(updatedSrcEdits)

        homoCloudTarget_map[srcCid].append(
          (srcCid, None, srcLogP)
        )
        for targetCid in homoTargetCid_set:

          targetCloud = self.cloud_map[targetCid]
          filtTargetMappings = filtLocs(targetMove_map[targetCid], commonRid_set)
          targetEdits = targetCloud.getEditsCounterMap(filtTargetMappings)

          targetLogP = InferenceEngine.computeEditsLogP(targetEdits)

          newLoc_set = targetMove_map[targetCid]
          propmappings = dict(mappings)
          for loc in newLoc_set:
            propmappings[loc.rid] = loc

          propFeats = targetCloud.getFeats(propmappings)

          # score the target cloud
          homoCloudTarget_map[srcCid].append(
            (targetCid, propFeats, targetLogP)
          )
      else:
        homoCloudTarget_map[srcCid] = []

    return (homoCloudTarget_map, ambigCid_set)

  def getExcludeCids(self, mappings):

    def getConfEdits(feats):
      confEdits_list = []
      for (pos, baseCounter) in feats.edits_map.items():
        total = sum(baseCounter.values())
        numRef = baseCounter['ref']
        numAlt = total - numRef
        (logP_v, logP_vp) = InferenceEngine.computeSiteLogPs(numRef, numAlt)
        logZ = logsumexp([logP_v, logP_vp])
        logP_v = logP_v - logZ
        P_v = np.exp(logP_v)
        entry = (
          P_v,
          numRef,
          numAlt,
          pos,
        )
        if P_v > 0.5:
          confEdits_list.append(entry)
      return confEdits_list

    cloudFeats_map = self.extractCloudFeats(mappings)
    (homoCloudTarget_map, _) = self.getHomoCloudTargets(mappings)
    excludeCid_set = set()
    # compute sequence identity probability using edit distance
    for (scid, homoTarget_list) in homoCloudTarget_map.items():

      if len(homoTarget_list) == 0:
        continue

      srcFeats = cloudFeats_map[scid]
      srcLogP = homoTarget_list[0][2]
      logZ = logsumexp(map(
        lambda(_1, _2, logP): logP,
        homoTarget_list,
      ))
      prob = np.exp(srcLogP - logZ)
      exclude = (
        # src has required mappability
        prob < 0.55 or
        # src has 4 or more mixed barcodes
        (srcFeats.numMixedBar > 3)
      )
      if exclude:
        excludeCid_set.add(scid)

    return excludeCid_set

  def isAmbigRead(self, rid):
    topCount = sorted(
      Counter(map(
        lambda(loc): loc.AS,
        self.readScope_map[rid]
      )).items(),
      reverse=True,
    )[0][1]
    return (topCount > 1)

#-------------------------------------------------------------------------
# main run loop
#-------------------------------------------------------------------------
  def run(self, backdoor=False):
    """Search for a high-scoring read assignment by simulated annealing.

    Starts from `getTopASMappings()`. The initial temperature T0 is chosen
    from a sample of cloud moves that lower the field score, so that the
    mean such move is accepted with probability 0.1. All clouds are then
    swept at four temperatures linearly spaced from T0 down to 0, followed
    by one extra sweep at T = 0. A move is always accepted when it does not
    lower the field score, and otherwise with probability exp(delta / T).
    The final assignment is stored in `self.finalMappings`.

    Args:
      backdoor: unused; kept for interface compatibility.
    """
    # seed with some top alignment score mappings
    mappings = self.getTopASMappings()

    cloudFeats_map = self.extractCloudFeats(mappings)
    pairFeats_map = self.extractPairFeats(mappings)

    logger.info('eval initialization')
    self.evaluate(mappings, cloudFeats_map)

    # need to compute base score of the current mapping
    baseCloudScore_map = \
      {cid : self.computeCloudScore(cloudFeats) \
        for (cid, cloudFeats) in cloudFeats_map.items()}
    baseCloudGlobalScore = sum(baseCloudScore_map.values())
    basePairGlobalScore = sum(
      self.computePairScore(ptype) for ptype in pairFeats_map.values()
    )
    fieldScore = (
      baseCloudGlobalScore +
      basePairGlobalScore
    )
    step = max(1, len(self.cloud_map) // 10)

    # choose initial temperature
    #--------------------
    # get score deltas of some score-lowering transitions (sample stops once
    # the index exceeds `step`)
    Ds = []
    for (i, srcCid) in enumerate(self.allCid_set):
      (nfieldScore, _, _, _) = self.moveCloud(
        srcCid,
        fieldScore,
        mappings,
        cloudFeats_map,
        pairFeats_map,
      )
      if nfieldScore < fieldScore:
        Ds.append(nfieldScore - fieldScore)
      if i > step:
        break

    logger.info('number of initial transitions: {0}'.format(len(Ds)))

    # tune for acceptance probability 0.1 of the mean sampled delta. With no
    # score-lowering samples, np.mean([]) would give NaN (and a warning),
    # which already rejects every such move, so fall back to a greedy T0 = 0.
    if Ds:
      T0 = np.mean(np.array(Ds)) / np.log(0.1)
    else:
      T0 = 0.

    # loop over all clouds until cooled
    #--------------------
    # decay and freeze from t0
    Ts = np.append(np.linspace(T0, 0, 4), 0)
    for T in Ts:

      totalUphill = 0
      numAccept = 0
      for (i, srcCid) in enumerate(sorted(self.allCid_set)):
        if i % step == 0:
          logger.info('  - ({0}/{1})'.format(i, len(self.cloud_map)))

        # get new field assignment
        (
          nfieldScore,
          updatedMappings_map,
          updatedPairFeats_map,
          updatedCloudFeats_map,
        ) = self.moveCloud(
          srcCid,
          fieldScore,
          mappings,
          cloudFeats_map,
          pairFeats_map,
        )
        uncond = (nfieldScore >= fieldScore)
        accept = (
          uncond or
          (
            T != 0. and
            np.exp((nfieldScore - fieldScore) / T) >
            random.random()
          )
        )
        if not uncond:
          totalUphill += 1
          numAccept += accept
        # update state
        if accept:
          fieldScore = nfieldScore
          mappings.update(updatedMappings_map)
          cloudFeats_map.update(updatedCloudFeats_map)
          pairFeats_map.update(updatedPairFeats_map)

      logger.info('finish cloud iteration')
      logger.info('  - T={0}, accepted {1}/{2} score-lowering moves'.format(
        T,
        numAccept,
        totalUphill,
      ))
      self.evaluate(mappings, cloudFeats_map)

    self.finalMappings = mappings

  def saveBackdoor(self, backdoor_fname):
    """Pickle a compact description of `self.finalMappings` to a file.

    Each read ID maps to a (chrName, pos, aend, AS) hint tuple; `backdoor`
    matches these hints against the read scope to restore the mappings.
    """
    hints = dict(
      (rid, (loc.chrName, loc.pos, loc.aend, loc.AS))
      for (rid, loc) in self.finalMappings.items()
    )
    logger.info('saving alignments')
    util.writePickle(backdoor_fname, hints)

  def backdoor(self, backdoor_fname):
    """Restore `self.finalMappings` from hints pickled by `saveBackdoor`.

    For every saved read ID, the mapping in the read's scope whose
    (chrName, pos, aend, AS) equals the saved hint is chosen; if several
    match, the last one iterated wins. Reads with no matching mapping are
    omitted.
    """
    logger.info('loading alignments')
    savedHints = util.loadPickle(backdoor_fname)
    self.finalMappings = {}
    for rid, hint in savedHints.items():
      matching = [
        loc for loc in self.readScope_map[rid]
        if (loc.chrName, loc.pos, loc.aend, loc.AS) == hint
      ]
      if matching:
        self.finalMappings[rid] = matching[-1]

