import os
import sys
from collections import defaultdict, Counter
import logging

from batch import Worker

from ..common import util
from ..align.candcloud import CloudFeat_t
from ..bam.cloud_iter import CloudFinder
from ..toolflow import bowtie2

logger = logging.getLogger(__name__)

class WellStats(Worker):
  """Worker that aligns one well's reads and records per-cloud statistics.

  ``run`` aligns the well's paired fastq files with bowtie2, groups the
  resulting alignments into clouds per barcode using
  ``self.config.cloudFinder_cls``, filters the clouds, and writes four
  files into ``self.config.wellStats_path``: pickled read stencils, pickled
  cloud features, pickled raw per-barcode counts and a BED file of cloud
  intervals.
  """

  @staticmethod
  def _cloudFeats(cloud):
    """Return a CloudFeat_t describing `cloud`.

    Only the fields available from the cloud itself are filled in;
    numFixed, numCand, isActive and edits_map are left as None.
    """
    return CloudFeat_t(
      size=cloud.getSize(),
      begin=cloud.minpos,
      end=cloud.maxpos,
      numReads=cloud.numreads,
      numFixed=None,
      numCand=None,
      lbarValid=cloud.lpart,
      rbarValid=cloud.rpart,
      numMixedBar=cloud.mixed,
      isActive=None,
      edits_map=None,
    )

  @staticmethod
  def _bedLine(cloud, barcode):
    """Return the tab-separated BED line (newline terminated) for `cloud`.

    Fields are: chrName, minpos, maxpos, the literal '-', and
    '<barcode>;<numreads>'.
    """
    label = '{0};{1}'.format(barcode, cloud.numreads)
    return '{0}\t{1}\t{2}\t-\t{3}\n'.format(
      cloud.chrName,
      cloud.minpos,
      cloud.maxpos,
      label,
    )

  @staticmethod
  def _isOutlierCloud(technology, cloud):
    """Return True if `cloud` has extreme properties for `technology`.

    Technologies other than 'moleculo' and '10x' are never filtered.
    Raises ZeroDivisionError if the density test is reached for a cloud
    whose getSize() is 0.
    """
    # NOTE(review): the factors 100 and 88 are presumably read lengths in
    # bases, making the ratios a per-base coverage -- TODO confirm.
    if technology == 'moleculo':
      return bool(
        cloud.mixed or
        cloud.getSize() > 15000 or
        1. * cloud.numreads * 100 / cloud.getSize() > 6
      )
    if technology == '10x':
      return 1. * cloud.numreads * 88 / cloud.getSize() > 2
    return False

  @staticmethod
  def _hasLowMappability(reads):
    """Return True if too few of `reads` have a good mapping quality.

    A cloud is rejected only when BOTH fractions are low: fewer than 60%
    of reads with mapq >= 20 and fewer than 80% with mapq >= 10.
    Raises ZeroDivisionError if `reads` is empty.
    """
    mapq_list = [r.mapq for r in reads]
    totalReads = len(mapq_list)
    totalHQ = sum(1 for mq in mapq_list if mq >= 20)
    totalMQ = sum(1 for mq in mapq_list if mq >= 10)
    return (
      1. * totalHQ / totalReads < 0.6 and
      1. * totalMQ / totalReads < 0.8
    )

  @staticmethod
  def _pairedQnames(reads):
    """Return the set of qnames for which both mates occur in `reads`."""
    # collapse duplicate alignments of the same mate before counting, so a
    # count of 2 means both read1 and read2 are present
    mate_set = set((r.qname, r.is_read1) for r in reads)
    qnameCounter = Counter(qname for (qname, _isRead1) in mate_set)
    return set(
      qname for (qname, numMates) in qnameCounter.items() if numMates == 2
    )

  @staticmethod
  def _cloudStencil(cf, cloud):
    """Return the stencil dict for `cloud`.

    The result has keys 'lpart', 'rpart', 'size' and 'reads'; 'reads' holds
    one tuple per read pair:
    (r1 offset, r1 is_reverse, r2 offset, r2 is_reverse, barcode flag),
    with offsets relative to cloud.minpos and the barcode flag taken from
    cf.isBarcode on the first-seen read of the pair.

    Every read in the cloud must have its mate in the cloud, and the first
    read returned by cloud.getReads() must start at cloud.minpos; both are
    checked with assertions.
    """
    reads = list(cloud.getReads())
    ridRead_map = dict(((r.qname, r.is_read1), r) for r in reads)

    readStencil_list = []
    seenQname_set = set()
    for (i, read) in enumerate(reads):
      # each pair is emitted once, from whichever mate is seen first
      if read.qname in seenQname_set:
        continue
      seenQname_set.add(read.qname)

      mrid = (read.qname, not read.is_read1)
      assert mrid in ridRead_map
      mate = ridRead_map[mrid]

      # ensure first read at beginning
      if i == 0:
        assert read.pos == cloud.minpos

      isBarcode = cf.isBarcode(read)
      readStencil_list.append((
        read.pos - cloud.minpos,
        read.is_reverse,
        mate.pos - cloud.minpos,
        mate.is_reverse,
        isBarcode,
      ))

    return {
      'lpart' : cloud.lpart,
      'rpart' : cloud.rpart,
      'size'  : cloud.getSize(),
      'reads' : readStencil_list,
    }

  def run(self):
    """Align the well, call clouds, and write the per-well statistics.

    Reads 'laneID', 'wellID', 'fq1_path' and 'fq2_path' from
    self.config.inputs.  Writes 'pass1_align.bam' in the current working
    directory and the files '<laneID>_<wellID>.stencil.p', '.stats.p',
    '.rawstats.p' and '.clouds.bed' under self.config.wellStats_path.

    Returns None.
    """
    inputs_map = self.config.inputs
    bowtieIndex_path = self.config.get('referenceBowtieIndex_path')

    # built up front so that missing lane/well ids fail before the
    # alignment is run
    outPrefix = '{0}_{1}'.format(
      inputs_map['laneID'],
      inputs_map['wellID'],
    )

    # pass1: align the paired fastqs of this well
    #----------------------
    logger.info('pass1 alignments')

    pass1Bam_path = 'pass1_align.bam'
    # NOTE: bowtie2.genAlignmentsMP was left commented out here as an
    # alternative entry point
    bowtie2.genAlignments(
      bowtieIndex_path,
      inputs_map['fq1_path'],
      inputs_map['fq2_path'],
      pass1Bam_path,
      alignFlags='-D 20',
    )

    # save cloud info from pass1
    #----------------------
    logger.info('calling clouds')
    cf = self.config.cloudFinder_cls(
      pass1Bam_path,
      multiMapMode=True,
    )

    totalClouds = 0
    cloudStencil_list = []
    cloudFeats_list = []
    wellCounts_map = defaultdict(Counter)
    bedLines_list = []

    for (barcode, cloud_iter) in cf.getCloudsByBarcode():
      for cloud in cloud_iter:
        # snapshot of the reads before any filtering of the cloud
        reads = list(cloud.getReads())

        # raw counts cover every cloud, including those filtered below
        counts = wellCounts_map[barcode]
        counts['sr-bases'] += sum(r.aend - r.pos for r in reads)
        counts['lr-bases'] += cloud.getSize()
        counts['num-rc'] += 1

        totalClouds += 1

        # filter extreme outlier cloud properties
        if self._isOutlierCloud(self.config.technology, cloud):
          continue

        # the BED file lists every cloud that survives the outlier filter,
        # even if a later filter rejects it
        bedLines_list.append((
          cloud.chrName,
          cloud.minpos,
          cloud.maxpos,
          self._bedLine(cloud, barcode),
        ))

        # filter low cloud mappability
        if self._hasLowMappability(reads):
          continue

        # filter all reads which don't have pairs in this cloud
        pairedQname_set = self._pairedQnames(reads)
        if len(pairedQname_set) < 5:
          continue
        cloud.filtQnames(pairedQname_set)
        if cloud.invalid:
          continue

        # save stencil and features for this cloud
        cloudStencil_list.append(self._cloudStencil(cf, cloud))
        cloudFeats_list.append(self._cloudFeats(cloud))

    # save outputs
    #----------------------
    logger.info('total clouds: {0}'.format(totalClouds))
    logger.info('  - num stencils saved: {0}'.format(
      len(cloudStencil_list)))

    def outPath(suffix):
      return os.path.join(self.config.wellStats_path, outPrefix + suffix)

    util.writePickle(outPath('.stencil.p'), cloudStencil_list)
    util.writePickle(outPath('.stats.p'), cloudFeats_list)
    util.writePickle(outPath('.rawstats.p'), wellCounts_map)
    with open(outPath('.clouds.bed'), 'w') as f:
      f.write('#chrom\tchromStart\tchromEnd\tname\tscore\n')
      # tuples sort by (chrName, minpos, maxpos) before the line text
      for (_chrom, _begin, _end, bedline) in sorted(bedLines_list):
        f.write(bedline)
    return
  
