import os
import sys
import subprocess
import shlex
import logging

from ..common import util

# module-level logger, named after this module
logger = logging.getLogger(__name__)

#runCount = 0

# directory holding the picard *jar files, taken from $PICARDPATH when this
# module is imported; None when the variable is not set
picardJar_path = os.getenv('PICARDPATH')

def checkPicard():
  """Exit the process unless picardJar_path is a usable picard *jar directory.

  Checks the module-level picardJar_path (populated from $PICARDPATH).
  Returns None when the path is set and contains MergeSamFiles.jar.

  Exits (via sys.exit, after logging the reason at error level):
    status 1 -- picardJar_path is None, i.e. $PICARDPATH was not set
    status 2 -- MergeSamFiles.jar is not a file inside that directory
  """
  # None is a singleton, so test identity rather than equality
  if picardJar_path is None:
    logger.error('$PICARDPATH not set')
    logger.error('  - specify directory to set of picard *jar files in shell env')
    sys.exit(1)

  # MergeSamFiles.jar is the only jar probed; its presence is taken as
  # evidence that the directory is a picard *jar directory
  test_path = os.path.join(picardJar_path, 'MergeSamFiles.jar')
  if not os.path.isfile(test_path):
    logger.error('$PICARDPATH={0} not a valid picard *jar directory'.format(picardJar_path))
    logger.error('  - required jar {0} not found'.format(test_path))
    sys.exit(2)

def __launchCmd__(cmd):
  """Run a picard command line and fail if it exits with a non-zero status.

  checkPicard() is called first, so the process exits if the picard jar
  directory is not configured.  The command's stdout and stderr are captured
  and logged at debug level.

  Args:
    cmd: the full command line as a single string.  It is tokenized with
      shlex.split and executed without a shell, so any argument containing
      whitespace must already be quoted or escaped by the caller.

  Raises:
    AssertionError: the command exited with a non-zero status.
  """
  checkPicard()
  logger.info('cmd: {0}'.format(cmd))
  pp = subprocess.Popen(
    shlex.split(cmd),
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
  )
  # communicate() drains both pipes while waiting for the child.  Calling
  # wait() first and reading afterwards can deadlock as soon as the child
  # writes more than the OS pipe buffer holds.
  out, err = pp.communicate()
  logger.debug(out)
  logger.debug(err)
  # raised explicitly instead of using an `assert` statement so the check
  # still runs under `python -O`; the exception type is kept for callers
  if pp.returncode != 0:
    raise AssertionError(
      'command exited with status {0}: {1}'.format(pp.returncode, cmd)
    )

def __getTmpDir__(uid):
  """Return the path of the picard temp directory for uid, creating it.

  The directory is named tmp-picard-<uid> and lives under the current
  working directory.  The name depends on uid alone, so repeated calls with
  the same uid (from the same working directory) yield the same path.
  """
  #global runCount
  #runCount += 1
  dir_name = 'tmp-picard-{0}'.format(uid)
  #dir_name = 'tmp-picard{0}-{1}'.format(uid, runCount)
  scratch_path = os.path.join(os.getcwd(), dir_name)
  # presumably behaves like `mkdir -p` (no error if it already exists) --
  # confirm against common.util
  util.mkdir_p(scratch_path)
  return scratch_path

def mergeBams(
  inputBam_list,
  outputBam_path,
  order="coordinate",
):
  """Merge several input files into one by running picard MergeSamFiles.jar.

  Args:
    inputBam_list: iterable of paths to existing files; each one is passed
      to picard as its own INPUT= argument, in iteration order.
    outputBam_path: value passed as OUTPUT=.
    order: value passed as SORT_ORDER= (default "coordinate").

  Raises:
    AssertionError: an entry of inputBam_list is not an existing file.

  Note:
    The command is launched through __launchCmd__, which tokenizes it on
    whitespace; paths are inserted unquoted.
  """
  input_args = []
  for path in inputBam_list:
    # raised explicitly instead of using an `assert` statement so the check
    # still runs under `python -O` and the error names the offending path
    if not os.path.isfile(path):
      raise AssertionError('input file not found: {0}'.format(path))
    input_args.append('INPUT={0} '.format(path))
  inputBam_str = ''.join(input_args)
  cmd = 'java -Xmx2g -jar {0}/MergeSamFiles.jar \
    SORT_ORDER="{1}" \
    {2} \
    OUTPUT={3} \
    TMP_DIR={4}'.format(
      picardJar_path,
      order,
      inputBam_str,
      outputBam_path,
      __getTmpDir__('merge'),
    )
  __launchCmd__(cmd)

def sortSam(
  input_path,
  output_path,
  order="coordinate",
):
  """Sort a file by running picard SortSam.jar.

  Args:
    input_path: value passed as INPUT=.
    output_path: value passed as OUTPUT=.
    order: value passed as SORT_ORDER= (default "coordinate").

  The command is launched through __launchCmd__ with a temp directory
  obtained from __getTmpDir__('sort').
  """
  scratch_dir = __getTmpDir__('sort')
  template = 'java -Xmx2g -jar {0}/SortSam.jar \
    TMP_DIR={1} \
    INPUT={2}  \
    OUTPUT={3} \
    SORT_ORDER="{4}"'
  __launchCmd__(
    template.format(picardJar_path, scratch_dir, input_path, output_path, order)
  )

def createBAMIndex(bam_path):
  """Build an index for a BAM file by running picard BuildBamIndex.jar.

  Args:
    bam_path: path to an existing file whose name ends in '.bam'.

  Returns:
    The index path, which is bam_path with '.bai' appended.

  Raises:
    AssertionError: bam_path does not end in '.bam' or is not an existing
      file.
  """
  # raised explicitly instead of using `assert` statements so the checks
  # still run under `python -O` and the error says what was wrong
  if not bam_path.endswith('.bam'):
    raise AssertionError('not a .bam path: {0}'.format(bam_path))
  if not os.path.isfile(bam_path):
    raise AssertionError('bam file not found: {0}'.format(bam_path))
  output_path = bam_path + '.bai'
  cmd = 'java -Xmx2g -jar {0}/BuildBamIndex.jar \
    TMP_DIR={1} \
    INPUT={2} \
    OUTPUT={3}'.format(
      picardJar_path,
      __getTmpDir__('index'),
      bam_path,
      output_path,
    )
  __launchCmd__(cmd)
  return output_path

