import os
import sys
import getopt
import shutil
import numpy as np
import gc
import re
import cPickle as pickle
import json
import logging
import logging.config
from collections import Counter, defaultdict

from mlib.common import util, job, plotter
from mlib.config import Config, ConfigFactory
from mlib.worker.batch import Batch
from mlib.toolflow import picard
logger = None  # module logger; bound by setup()

#-------------------------------------------------------------------------
# set up logging
#-------------------------------------------------------------------------
def setup(level='DEBUG'):
  '''Configure root logging and bind the module-level `logger`.

  Installs two handlers on the root logger:
    - 'stdout': a logging.StreamHandler at `level`. NOTE(review): no
      'stream' is given, so StreamHandler uses its default (sys.stderr)
      despite the handler's name -- confirm this is intended.
    - 'debugfile': a logging.FileHandler at DEBUG writing 'debug.log' in the
      current working directory at call time (main() chdir's into the
      scratch directory before calling this).
  The root logger level is also set to `level`. A level above DEBUG
  therefore keeps DEBUG records out of debug.log as well.

  Args:
    level: logging level name for the 'stdout' handler and the root logger.
      Defaults to 'DEBUG'; pass e.g. 'INFO' for quieter output.
  '''
  loggingConfig_map = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
      'verbose': {
        'format' : '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
      },
      'simple': {
        'format': '%(levelname)s %(message)s'
      },
    },
    'handlers': {
      'stdout': {
        # single 'level' key: previously 'INFO' was listed and then silently
        # overridden by a duplicate 'DEBUG' key in the same dict literal
        'level': level,
        'class': 'logging.StreamHandler',
        'formatter': 'verbose',
      },
      'debugfile': {
        'level': 'DEBUG',
        'class': 'logging.FileHandler',
        'filename': os.path.join(os.getcwd(), 'debug.log'),
        'formatter': 'verbose',
      },
    },
    'loggers': {
      '': {
        'handlers': ['stdout', 'debugfile'],
        'level': level,
        'propagate': False,
      },
    }
  }
  logging.config.dictConfig(loggingConfig_map)
  global logger
  logger = logging.getLogger(__name__)

#-------------------------------------------------------------------------
# work flows
#-------------------------------------------------------------------------
def stats(scratch_path, config):
  '''Run the 'collect-stats' batch over every stats config.

  Args:
    scratch_path: scratch/build directory handed to Batch.
    config: Config object; wrapped in a ConfigFactory to produce job configs.
  '''
  statsJobs = ConfigFactory(config).getStatsConfigs()
  statsBatch = Batch(scratch_path, statsJobs, 'collect-stats', {'maxmem': 6})
  # checkfirst=True is forwarded to Batch.run (presumably skips work that is
  # already done -- confirm in mlib.worker.batch)
  statsBatch.run(checkfirst=True)

def estimate(scratch_path, config):
  '''Collect stats from a limited set of sample wells, then estimate params.

  Two batches run strictly in sequence. The second stage's configs are only
  requested after the first batch has run, which matters if they depend on
  its outputs.

  Args:
    scratch_path: scratch/build directory handed to each Batch.
    config: Config object; wrapped in a ConfigFactory to produce job configs.
  '''
  factory = ConfigFactory(config)
  # (batch uid, deferred config getter) -- getters are lambdas so each
  # stage's configs are built only when that stage is about to run
  stages = (
    ('collect-stats', lambda: factory.getStatsConfigs(limit=10)),
    ('estimate-params', lambda: factory.getEstimatorConfigs()),
  )
  for (uid, buildConfigs) in stages:
    stage = Batch(scratch_path, buildConfigs(), uid, {'maxmem': 6})
    stage.run(checkfirst=True)

def sim(scratch_path, config):
  '''Collect stencils from sample wells, then simulate wells from them.

  Args:
    scratch_path: scratch/build directory handed to each Batch.
    config: Config object; wrapped in a ConfigFactory to produce job configs.
  '''
  factory = ConfigFactory(config)

  def runStage(stageConfigs, uid):
    # one Batch per stage, same memory request for both
    stageBatch = Batch(scratch_path, stageConfigs, uid, dict(maxmem=6))
    stageBatch.run(checkfirst=True)

  # stage 1: stats/stencils from sample wells (stats configs, limit=100)
  runStage(factory.getStatsConfigs(limit=100), 'collect-stats')
  # stage 2: simulate wells using the stencils from stage 1
  runStage(factory.getSimConfigs(), 'sim-wells')

def align(
  scratch_path,
  config,
  mode,
):
  '''Run the alignment batch for `mode` ('sample' or 'sim').

  Exits with status 1 if config.params_path is not a file (run estimate
  first). In 'sim' mode it exits with status 2 when the batch has no jobs,
  and otherwise summarizes the results via evalAlignments afterwards.

  Args:
    scratch_path: scratch/build directory handed to Batch.
    config: Config object providing params_path / simAlignments_path.
    mode: alignment mode; also used to form the batch uid 'align-<mode>'.
  '''
  if not os.path.isfile(config.params_path):
    logger.error('parameters required to align, run estimate first')
    sys.exit(1)

  isSim = (mode == 'sim')
  alignBatch = Batch(
    scratch_path,
    ConfigFactory(config).getAlignConfigs(mode),
    'align-{0}'.format(mode),
    {'maxmem': 10},
  )
  if isSim and alignBatch.numJobs == 0:
    logger.error('no simulated wells to align, run sim first')
    sys.exit(2)
  alignBatch.run(checkfirst=True)

  if isSim:
    # simulated wells have ground truth, so compile accuracy summaries
    evalAlignments(config.simAlignments_path)

def _tallyAlignmentResults(barcodeResults_list, tallies_map):
  '''Append one results file's per-barcode values to tallies_map (in place).

  Every (key, value) pair of each per-barcode map is appended to
  tallies_map[key]. In addition, keys starting with 'rreads' get a
  'pct_<key>' entry equal to value / 'total_resolve_reads'. Keys starting
  with 'aallreads' get a 'pct_<key>' entry equal to value / 'total_reads'.
  A barcode whose relevant total is 0 has no defined fraction, so its
  'pct_' entry is skipped (previously this raised ZeroDivisionError).

  Args:
    barcodeResults_list: list of per-barcode dicts loaded from one
      *results.json file; each must contain 'total_resolve_reads' and
      'total_reads'.
    tallies_map: defaultdict(list) accumulating values across files.
  '''
  for resultsFrag_map in barcodeResults_list:
    totalResolveReads = resultsFrag_map['total_resolve_reads']
    totalReads = resultsFrag_map['total_reads']
    for (k, v) in resultsFrag_map.items():
      # add raw counts
      tallies_map[k].append(v)
      # pick the denominator for the percentage value, if any
      if k.startswith('rreads'):
        total = totalResolveReads
      elif k.startswith('aallreads'):
        total = totalReads
      else:
        continue
      if total:
        tallies_map['pct_' + k].append(1. * v / total)


def _logMethodMeans(tallies_map, keyFmt, method_list):
  '''Log np.mean(tallies_map[keyFmt.format(k)]) for each (name, k) pair.'''
  for (name, k) in method_list:
    logger.info('  - {0:10s}: {1:4.4f}'.format(
      name,
      np.mean(tallies_map[keyFmt.format(k)]),
    ))


def _plotSimulationHistograms(plots_path, tallies_map):
  '''Write the percent-correct histograms (via plotter) into plots_path.'''
  with util.cd(plots_path):
    plotter.plotHistograms(
      [
        ('naive',    tallies_map['pct_rreads_topAS_correct_mmq']),
        ('RFA',      tallies_map['pct_rreads_rfa_correct_mmq']),
        ('oracle',   tallies_map['pct_rreads_oracle_correct_mmq']),
      ],
      {
        'xlabel' : 'percent correct',
        'pct' : True,
        'ylabel' : 'freq',
        'fname'  : 'pct_correct_rslv',
        'numBins' : 50,
        'save' : 'png',
        'naive' : {
          'alpha' : 0.5,
          'color' : 'b',
        },
        'RFA' : {
          'alpha' : 0.5,
          'color' : 'g',
        },
        'oracle' : {
          'alpha' : 0.5,
          'color' : 'r',
        },
      }
    )

    plotter.plotHistograms(
      [
        ('baseline', tallies_map['pct_aallreads_pass1_correct_mmq']),
        ('naive',    tallies_map['pct_aallreads_topAS_correct_mmq']),
        ('RFA',      tallies_map['pct_aallreads_rfa_correct_mmq']),
        ('oracle',   tallies_map['pct_aallreads_oracle_correct_mmq']),
      ],
      {
        'xlabel' : 'percent correct',
        'pct' : True,
        'ylabel' : 'freq',
        'fname'  : 'pct_correct_all',
        'numBins' : 50,
        'save' : 'png',
        'baseline' : {
          'alpha' : 0.2,
          'color' : '0.5',
        },
        'naive' : {
          'alpha' : 0.5,
          'color' : 'b',
        },
        'RFA' : {
          'alpha' : 0.5,
          'color' : 'g',
        },
        'oracle' : {
          'alpha' : 0.5,
          'color' : 'r',
        },
      }
    )


def evalAlignments(results_path):
  '''Summarize simulated-alignment results found in results_path.

  Loads every file whose name ends in 'results.json' and tallies raw and
  percentage values across all wells. It logs the per-method means of
  those percentages, then writes histogram plots to
  <results_path>/simulation.plots.

  Args:
    results_path: directory containing the per-well *results.json files.
  '''
  # sorted for a deterministic processing order
  fname_list = sorted(
    f for f in os.listdir(results_path) if f.endswith('results.json')
  )
  logger.info('parsing {0} alignment results'.format(len(fname_list)))

  # tally up results from all wells
  tallies_map = defaultdict(list)
  for fname in fname_list:
    path = os.path.join(results_path, fname)
    _tallyAlignmentResults(util.jsonLoadASCII(path), tallies_map)

  # cumulative results for all reads
  logger.info('all reads accuracy results (mapq >= 10)')
  _logMethodMeans(tallies_map, 'pct_aallreads_{0}_correct_mmq', [
    ('baseline', 'pass1'),
    ('naive',    'topAS'),
    ('rfa',      'rfa'),
    ('oracle',   'oracle'),
  ])

  # cumulative results for multimapped reads in abbreviated reference;
  # baseline has no counterpart here
  logger.info('multimapped reads (within abbreviated reference) accuracy results (mapq >= 10)')
  logger.info('NOTE baseline not applicable')
  _logMethodMeans(tallies_map, 'pct_rreads_{0}_correct_mmq', [
    ('naive',    'topAS'),
    ('rfa',      'rfa'),
    ('oracle',   'oracle'),
  ])

  # cumulative error rate within abbreviated reference
  logger.info('multimapped reads (within abbreviated reference) error rate results (mapq >= 10)')
  logger.info('NOTE baseline, oracle not applicable')
  _logMethodMeans(tallies_map, 'pct_rreads_{0}_incorrect_hmq', [
    ('naive',    'topAS'),
    ('rfa',      'rfa'),
  ])

  # TODO compute {naive,rfa,oracle} gains over baseline approach
  # TODO compute rfa gap from oracle

  # produce simulation plots
  plots_path = os.path.join(results_path, 'simulation.plots')
  util.mkdir_p(plots_path)
  logger.info('generating simulation plots')
  logger.info('  - dumping to: {0}'.format(plots_path))
  _plotSimulationHistograms(plots_path, tallies_map)

#-------------------------------------------------------------------------
# main
#-------------------------------------------------------------------------
def main(argv):
  '''Command-line entry point for rfa.py.

  Parses <cmd> (argv[1]) and options (argv[2:]) and validates them. It then
  creates and enters the scratch directory, sets up logging, picard and the
  job system, loads the config, and dispatches to the selected workflow.

  Exit statuses for rejected input:
    0 after printing help; 1 for a missing/unknown command, a missing
    config or an invalid --mode value; 2 for an option-parsing error or a
    nonexistent config file; 3 when --mode is used with a cmd other than
    align.

  Args:
    argv: full argument vector, argv[0] being the program name.
  '''
  help_str = '''
  Usage: rfa.py <cmd> [opts]

  cmd:
    * stats           generate stats for all wells
    * sim             generate simulated wells
    * estimate        estimate models from sample wells
    * align           align sample wells with RFA
  all opts:
    -b, --build       scratch directory name to use, default: CWD/scratch.rfa
    -c, --config      (required) path to configuration *json file to use
        --local       run all jobs locally (and serially)

  <align> opts:
        --mode        {sample, sim} default: sample

    -h, --help        print this message
  '''
  argShortName_str = 'hc:b:'
  argLongName_list = [
    'help',
    'build=',
    'config=',
    'local',
    'mode=',
  ]
  cmd_list = ['sim', 'estimate', 'align', 'stats']

  # parse cmd
  #----------------------
  if len(argv) < 2:
    print('no <cmd> provided')
    sys.exit(1)
  cmd = argv[1]
  if cmd in ['-h', '--help', 'help']:
    print(help_str)
    sys.exit(0)
  print('CMD: {0}'.format(cmd))
  # explicit check rather than assert: asserts vanish under `python -O`
  if cmd not in cmd_list:
    print('unknown <cmd> {0}'.format(cmd))
    print(help_str)
    sys.exit(1)

  # parse options + setup
  #----------------------
  try:
    opts, _ = getopt.getopt(argv[2:], argShortName_str, argLongName_list)
  except getopt.GetoptError as err:
    print(str(err))
    sys.exit(2)

  mode = 'sample'
  scratch_path = None
  config_path = None
  for flag, value in opts:
    if flag in ('-h', '--help'):
      print(help_str)
      sys.exit(0)
    elif flag in ('-c', '--config'):
      config_path = value
    elif flag in ('-b', '--build'):
      scratch_path = value
    elif flag == '--local':
      job.local = True
    elif flag == '--mode':
      mode = value
      if cmd != 'align':
        print('--mode only to be specified in cmd <align>')
        print(help_str)
        sys.exit(3)

  if config_path is None:
    print('config must be specified')
    sys.exit(1)
  if mode not in ['sample', 'sim']:
    print('alignment mode {0} invalid'.format(mode))
    sys.exit(1)
  if not os.path.isfile(config_path):
    print('config file {0} does not exist'.format(config_path))
    sys.exit(2)

  # use default scratch path if none specified
  if scratch_path is None:
    scratch_path = 'scratch.rfa'
    print('build directory unspecified')
    print(' - using CWD/{0}'.format(scratch_path))

  # resolve before chdir so relative paths keep their meaning
  scratch_path = os.path.abspath(scratch_path)
  config_path = os.path.abspath(config_path)

  # setup scratch + config
  #----------------------
  util.mkdir_p(scratch_path)
  # setup() writes debug.log into the CWD, so enter the scratch dir first
  os.chdir(scratch_path)
  setup()
  picard.checkPicard()
  job.setup(os.path.join(scratch_path, 'job.stdouts'))
  config = Config.fromJson(config_path)

  # run
  #----------------------
  if cmd == 'estimate':
    estimate(scratch_path, config)
  elif cmd == 'sim':
    sim(scratch_path, config)
  elif cmd == 'align':
    align(scratch_path, config, mode)
  elif cmd == 'stats':
    stats(scratch_path, config)

if __name__ == '__main__':
  # script entry point: pass the complete argv (program name included)
  main(argv=sys.argv)

