import os
import sys
import subprocess
import shlex
import time
import logging

import util

logger = logging.getLogger(__name__)

# when True, submit() runs each script directly on this machine instead of
# handing it to qsub (useful for debugging without a cluster)
local = False

class JobInfo(object):
  """Description of one script to be run as a cluster job.

  Holds the script path plus the qsub resource settings that submit()
  formats into the qsub command line, and the rerun-on-failure settings
  that waitAll() consults when a job finishes with a non-zero exit code.
  """

  def __init__(
    self,
    script=None,
    maxmem=8,
    walltimeStr='6:00:00',
    extraoptsStr='',
    tag='',
    rerunOnFail=False,
    rerunMem=None,
  ):
    """Record the job settings and validate the script.

    Args:
      script: path to an executable script, absolute or relative to the
        current working directory. Required despite the None default.
      maxmem: memory limit in GB, submitted as ``-l h_vmem=<maxmem>G``.
      walltimeStr: wall-clock limit, submitted as ``-l h_rt=<walltimeStr>``.
      extraoptsStr: extra qsub options, shell-split and placed before
        ``-j y`` on the qsub command line.
      tag: free-form label shown in log messages about this job.
      rerunOnFail: resubmit once, with ``rerunMem``, if the job fails.
      rerunMem: memory in GB for the rerun; required when rerunOnFail is set.

    Raises:
      AssertionError: if script is None, or rerunOnFail is set without
        rerunMem.
      SystemExit: code 1 if the script does not exist, code 2 if it is not
        executable.
    """
    assert script is not None
    assert not (rerunOnFail and rerunMem is None)

    # os.getcwd() is absolute, and os.path.join returns an absolute
    # ``script`` unchanged, so self.script is always an absolute path.
    self.script       = os.path.join(os.getcwd(), script)
    self.maxmem       = maxmem
    self.walltimeStr  = walltimeStr
    self.extraoptsStr = extraoptsStr
    self.tag          = tag
    self.rerunOnFail  = rerunOnFail
    self.rerunMem     = rerunMem
    # set by changeOptsForRerun(); submit() uses it to name the stdout file
    self.isRerun      = False

    # check existence and permissions
    if not os.path.isfile(self.script):
      logger.error('script {0} does not exist'.format(script))
      sys.exit(1)
    elif not os.access(self.script, os.X_OK):
      logger.error('script {0} does not have execute privileges'.format(script))
      sys.exit(2)

  def changeOptsForRerun(self):
    """Switch this job to its rerun settings.

    Marks the job as a rerun and uses rerunMem as the memory limit. It also
    disables rerunOnFail, so a failure of the rerun is final.
    """
    self.isRerun = True
    self.rerunOnFail = False
    self.maxmem = self.rerunMem

#-------------------------------------------------------------------------
# global variables for maintaining and submitting jobs
#-------------------------------------------------------------------------
# cluster job ids that were submitted and are still being polled by waitAll()
openPID_set = set()
# cluster job id -> the JobInfo that was submitted under that id
jobInfo_map = dict()

# running counter appended to each stdout file name so names stay distinct
jobid = 0

# stdout files go to ~/job.stdouts unless setup() points somewhere else
stdoutDir_path = os.path.join(os.path.expanduser('~'), 'job.stdouts')

#-------------------------------------------------------------------------
# methods
#-------------------------------------------------------------------------

# configuration of the job manager for the cluster
def setup(_stdoutDir_path):
  """Choose the directory for job stdout files and make sure it exists.

  Args:
    _stdoutDir_path: directory into which submit() writes each job's
      qsub ``-o`` output file.
  """
  global stdoutDir_path
  # rebind the module-level default; later submit() calls read it
  stdoutDir_path = _stdoutDir_path
  util.mkdir_p(stdoutDir_path)

def waitAll():
  """Block until every job in the module-level openPID_set has finished.

  Each open cluster job id is polled with ``qstat -j``, with a 15 second
  sleep between rounds. For each finished job, the matching JobInfo in
  jobInfo_map decides what happens:

    * non-zero exit code and rerunOnFail set: changeOptsForRerun(), then
      resubmit once via submit(). The new job id joins openPID_set.
    * non-zero exit code otherwise: logged as a final failure.
    * exit code 0: logged as a success.

  NOTE(review): the qacct exit-code lookup is disabled in getQstatStatus,
  so every finished job is currently reported with exit code 0 and the
  rerun path is never taken.
  """
  global openPID_set
  global jobInfo_map

  def getQstatStatus(jid):
    """Return ``(done, failed, exitcode)`` for cluster job ``jid``.

    A zero exit status from ``qstat -j`` is taken to mean the job is still
    queued or running. Any other status is treated as finished.
    """
    pp = subprocess.Popen(
      shlex.split('qstat -j {0}'.format(jid)),
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE,
    )
    # communicate() drains both pipes while waiting. A bare wait() can
    # deadlock once qstat's output fills the OS pipe buffer.
    pp.communicate()
    if pp.returncode == 0:
      # job not finished
      return (False, None, None)
    # TODO: the real result should come from `qacct -j <jid>`, whose
    # `exit_status` and `failed` lines hold the exit code and failure flag.
    # That lookup is currently disabled, so finished jobs report success.
    return (True, False, 0)

  while openPID_set:
    finishPID_set = set()
    rerunJobInfo_set = set()
    for jid in openPID_set:
      (done, failed, exitcode) = getQstatStatus(jid)
      if not done:
        continue
      finishPID_set.add(jid)
      jobInfo = jobInfo_map[jid]
      if exitcode != 0 and jobInfo.rerunOnFail:
        # first failure: queue one resubmission with the rerun settings
        logger.error('job failed: ({0},{1})'.format(jobInfo.tag, jobInfo.script))
        logger.error('  - exit code: {0}'.format(exitcode))
        logger.error('  - failed: {0}'.format(failed))
        logger.error('  - rerunning')
        rerunJobInfo_set.add(jobInfo)
      elif exitcode != 0:
        # failure with reruns off (or already used up) is final
        logger.error('job failed final: ({0},{1})'.format(jobInfo.tag, jobInfo.script))
        logger.error('  - exit code: {0}'.format(exitcode))
        logger.error('  - failed: {0}'.format(failed))
      else:
        logger.info('job success: ({0},{1})'.format(jobInfo.tag, jobInfo.script))

    # resubmit only after the scan, so openPID_set is not mutated while it
    # is being iterated. submit() adds the new job ids to openPID_set.
    for jobInfo in rerunJobInfo_set:
      jobInfo.changeOptsForRerun()
      submit(jobInfo)

    openPID_set -= finishPID_set
    if finishPID_set:
      logger.info('waiting on {0} open jobs'.format(len(openPID_set)))

    if openPID_set:
      time.sleep(15)

def submit(jobInfo):
  """Run jobInfo.script locally, or submit it to the cluster with qsub.

  Local mode (module flag ``local`` is true):
    The script runs to completion on this machine. TMPDIR is set to
    ``<cwd>/local.scratch/<script name>.scratch``, and a non-zero exit
    status is logged. Nothing is added to the open-job bookkeeping.

  Cluster mode:
    The script is submitted with qsub using the job's maxmem
    (``h_vmem``, GB), walltimeStr (``h_rt``) and extraoptsStr. Stdout and
    stderr are merged (``-j y``) into ``stdoutDir_path/<script name>.<n>``,
    with a ``rerun__`` prefix for reruns. The job id parsed from qsub's
    reply is recorded in openPID_set and jobInfo_map for waitAll().

  Raises:
    SystemExit: code 1 if qsub's reply is not recognised.
  """
  global jobid

  # use for running and debugging locally
  if local:
    base = os.path.basename(jobInfo.script)
    tmpdir = os.path.join('local.scratch', base + '.scratch')
    util.mkdir_p(tmpdir)
    # export an absolute path so TMPDIR stays valid if the script chdirs
    tmpdir_abs = os.path.join(os.getcwd(), tmpdir)
    os.environ['TMPDIR'] = tmpdir_abs
    logger.info('creating TMPDIR for local: {0}'.format(tmpdir_abs))
    cmd = jobInfo.script
    logger.info('running job in local mode')
    logger.info('  - ' + cmd)
    # pass the path as one argv entry so spaces/quotes in it survive
    pp = subprocess.Popen([cmd])
    pp.wait()
    if pp.returncode != 0:
      logger.error('job failed: ({0},{1})'.format(jobInfo.tag, jobInfo.script))
      logger.error('  - exit code: {0}'.format(pp.returncode))
    return

  def getQsubIDFromOut(out, err=''):
    """Return the integer job id from qsub's 'Your job <id> ...' reply."""
    # explicit check instead of assert: asserts are stripped under -O
    if not out.startswith('Your job'):
      logger.error('unexpected qsub return {0!r} (stderr: {1!r})'.format(out, err))
      sys.exit(1)
    # keep only the numeric part of the id token. NOTE(review): SGE array
    # jobs reply 'Your job-array <id>.<range> ...' -- confirm for other
    # schedulers.
    return int(out.split()[2].split('.')[0])

  # create stdout path
  script_fname = os.path.basename(jobInfo.script)
  stdout_fname = script_fname + '.' + str(jobid)
  if jobInfo.isRerun:
    stdout_fname = 'rerun__' + stdout_fname
  util.mkdir_p(stdoutDir_path)
  stdout_path = os.path.join(stdoutDir_path, stdout_fname)
  jobid += 1

  extraopts = jobInfo.extraoptsStr or ''
  # Build argv as a list so paths containing spaces or quotes are passed
  # intact. Only the free-form extra options are shell-split.
  qsub_args = (
    ['qsub', '-o', stdout_path,
     '-l', 'h_vmem={0}G'.format(jobInfo.maxmem),
     '-l', 'h_rt={0}'.format(jobInfo.walltimeStr)]
    + shlex.split(extraopts)
    + ['-j', 'y', jobInfo.script]
  )
  cmd = 'qsub -o {0} -l h_vmem={1}G -l h_rt={2} {3} -j y {4}'.format(
    stdout_path,
    jobInfo.maxmem,
    jobInfo.walltimeStr,
    extraopts,
    jobInfo.script,
  )
  logger.info('submitting job')
  logger.info('  - ' + cmd)
  pp = subprocess.Popen(
    qsub_args,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    universal_newlines=True,  # text output on both Python 2 and 3
  )
  # communicate() avoids the pipe-buffer deadlock of wait() and collects
  # stderr so a qsub failure can be reported
  out, err = pp.communicate()
  jid = getQsubIDFromOut(out, err)
  logger.info('  - returned jid {0}'.format(jid))

  # obtain cluster assigned jobid for polling progress
  openPID_set.add(jid)
  jobInfo_map[jid] = jobInfo

  # sleep for short bit just in case
  time.sleep(0.4)

