#!/usr/bin/env python

import component_skeleton.main
import subprocess
import os, re
from string import *
import datetime, time
from pylab import *
import sys

def findFraglen(readfiles):
    """
    Return the mean fragment length over all given read files.

    ``readfiles`` is a whitespace-separated string of read-file paths. For each path,
    the fragment length is read from a ``*.res`` file in the sibling directory
    ``<dir of read file>/../intermediate`` (the log of the FragmentLength component).
    The last whitespace-separated token of that file is parsed as a float and
    truncated to an int. If several ``.res`` files exist, the alphabetically first
    one is used.

    Raises:
        IOError: if no ``.res`` file is found for one of the read files.
        ValueError: if ``readfiles`` contains no paths.
    """
    fraglens = []
    for rf in readfiles.split():
        # get fragment length from the log of the FragmentLength component
        fraglenRoot = os.path.join(os.path.split(rf)[0], '..', 'intermediate')
        resfiles = sorted(f for f in os.listdir(fraglenRoot) if f.endswith('.res'))
        if not resfiles:
            # never fall back to another read file's value
            raise IOError('no .res fragment-length file found in %s (for %s)' % (fraglenRoot, rf))
        fraglenF = os.path.join(fraglenRoot, resfiles[0])
        with open(fraglenF) as fh:
            # the fragment length is the last whitespace-separated token of the file
            fraglen = int(float(fh.read().strip().split()[-1]))
        fraglens.append(fraglen)
        print('\nFound: %s with fragment length: %s\n' % (rf, fraglen))

    if not fraglens:
        raise ValueError('findFraglen: no read files given')

    print('\nFragment Length: %s' % fraglens)

    return float(sum(fraglens)) / len(fraglens)


def getBaseFreqs(f):
    """
    Compute strand-symmetric base frequencies for the sequences in FASTA file ``f``.

    Lines starting with '>' are skipped. Only the characters 'A', 'C', 'G' and 'T'
    are counted; anything else, including lower-case bases, is ignored.

    Returns (ATfreq, GCfreq): the mean of the A and T frequencies and the mean of the
    C and G frequencies. Both are rounded to 3 decimals and adjusted so that
    ATfreq + GCfreq == 0.5 in units of 0.001. When rounding leaves the pair one unit
    off, a randomly chosen one of the two (via ``rand``) is nudged.

    Raises:
        ValueError: if the file contains no A/C/G/T characters.
    """
    bd = {'A': 0., 'C': 0., 'G': 0., 'T': 0.}

    with open(f) as fh:
        for line in fh:
            if line.startswith('>'):
                continue
            for b in bd:
                bd[b] += line.count(b)

    print(bd)
    # normalize to frequencies
    tot = sum(bd.values())
    if tot == 0:
        raise ValueError('no A/C/G/T bases found in %s' % f)
    for i in bd:
        bd[i] /= tot

    print(bd)
    # take same frequencies for A and T or C and G respectively, because we are using double stranded DNA at the end.

    prec = 3
    scale = 10 ** prec
    half = scale // 2
    # Work in integer units of 10**-prec so the "sums to 0.5" check is exact.
    # Comparing floats against 0.5 and stepping by 0.001 could oscillate forever.
    at_units = int(round(round((bd['A'] + bd['T']) / 2., prec) * scale))
    gc_units = int(round(round((bd['C'] + bd['G']) / 2., prec) * scale))

    # round one frequency (AT or GC) down or up so that the pair sums to exactly 0.5
    while at_units + gc_units != half:
        step = -1 if at_units + gc_units > half else 1
        if rand(1) >= 0.5:
            at_units += step
        else:
            gc_units += step

    return at_units / float(scale), gc_units / float(scale)



def runMotevo(motevo_path, seqs, train_set, bg_train_set, WM, interm, pickled_sitesdict, pickled_idstats, pickled_idcoords, statsfile, regcov_dir, instance_name, queue_name, genome, markovorder, env):
    """
    Run Motevo through ./runmotevo.py as a single SLURM batch job and wait for it to end.

    The job script ``motevo_command.sh`` and the job's ``motevo_job.stdout`` /
    ``motevo_job.stderr`` files are placed in the parent directory of ``interm``.
    The job is polled with ``scontrol`` every 30 s for at most 1200 polls. Once it has
    left its active states, its stdout and stderr files are echoed to this process's
    stdout.

    queue_name: comma-separated spec whose first three fields are read positionally as
        memory, time limit and CPU count, e.g. "membycore=8G, runtime=6:00:00, n_nodes=1".

    Returns True if the job ended in state COMPLETED (or scontrol no longer reports a
    state for it), False if Slurm reported any other final state (FAILED, CANCELLED,
    TIMEOUT, ...).
    Raises Exception if the job is still active after the polling window, and
    subprocess.CalledProcessError if sbatch fails.
    """

    # membycore=8G, runtime=6:00:00, n_nodes=1
    qspecs = queue_name.split(',')
    mem_limit = qspecs[0].split('=')[1]
    time_limit = qspecs[1].split('=')[1]
    n_nodes = qspecs[2].split('=')[1]

    outdir = os.path.split(interm)[0]
    stdoutpath = os.path.join(outdir, 'motevo_job.stdout')
    stderrpath = os.path.join(outdir, 'motevo_job.stderr')

    jobFileContent = '\n'.join([
        '#! /bin/bash',
        '#SBATCH -n %i' % int(n_nodes),  # number of cpus
        '#SBATCH --mem=%s' % mem_limit,
        '#SBATCH --time=%s' % time_limit,
        '#SBATCH -e %s' % stderrpath,
        '#SBATCH -o %s' % stdoutpath,
        '#SBATCH --job-name=%s' % instance_name,
        '#SBATCH --workdir=%s' % os.path.dirname(os.path.realpath(__file__)),
        'source %s' % env,
        './runmotevo.py %s %s %s %s %s %s %s %s %s %s %s %s %s' % (motevo_path, seqs, train_set, bg_train_set, WM, pickled_sitesdict, pickled_idstats, pickled_idcoords, statsfile, regcov_dir, genome, str(markovorder), interm)
    ])

    shellFilename = os.path.join(outdir, 'motevo_command.sh')
    with open(shellFilename, 'w') as outf:
        outf.write(jobFileContent)

    T1 = datetime.datetime.now()
    print(T1)
    sys.stdout.flush()

    ## submit job; sbatch prints e.g. "Submitted batch job 4014401"
    sub_message = subprocess.check_output(["sbatch", shellFilename], universal_newlines=True)
    jobid = sub_message.split()[-1]

    print('submitted %s' % jobid)
    sys.stdout.flush()

    # Slurm states in which the job has not yet reached its final outcome
    active_states = ('PENDING', 'RUNNING', 'CONFIGURING', 'COMPLETING', 'SUSPENDED',
                     'REQUEUED', 'RESIZING')

    ## wait until job finished
    wait_time = 30
    final_state = None
    job_finished = False
    for _ in range(1200):
        time.sleep(wait_time)
        # scontrol show jobid -dd 4014401
        p = subprocess.Popen(["scontrol", "show", "jobid", "-dd", jobid],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             universal_newlines=True)
        qstat_message, _err = p.communicate()
        m = re.search(r'JobState=(\S+)', qstat_message)
        final_state = m.group(1) if m else None
        # a missing state (job no longer known to scontrol) is treated as finished
        if final_state not in active_states:
            job_finished = True
            # presumably lets the job's output files settle before they are read
            time.sleep(wait_time)
            break

    if not job_finished:
        raise Exception("Your job is running for too long. Exiting.")

    succeeded = final_state in (None, 'COMPLETED')
    if final_state == 'COMPLETED':
        print('Job: ' + str(jobid) + ' finished with success.')
    elif final_state is None:
        print('Job: %s finished (final state no longer reported by scontrol).' % jobid)
    else:
        print('Job: %s ended with state %s.' % (jobid, final_state))

    # echo the job's output files (best effort, like `cat`)
    for label, path in (('stdout:', stdoutpath), ('stderr:', stderrpath)):
        print(label)
        sys.stdout.flush()
        try:
            with open(path) as fh:
                sys.stdout.write(fh.read())
        except IOError as err:
            print('could not read %s: %s' % (path, err))
        sys.stdout.flush()

    T2 = datetime.datetime.now()
    print(T2)
    print('Running Time for %s: %s' % (instance_name, str(T2 - T1)))
    sys.stdout.flush()

    return succeeded


def splitRegCovs(regcov_dir, interm, fpj):
    """
    Split the files in ``regcov_dir`` into chunk files for a SLURM array job.

    Writes ``<interm>/regcovfile.1``, ``<interm>/regcovfile.2``, ... Each chunk file
    lists the full paths of up to ceil(fpj) files from ``regcov_dir``, one per line.
    Files are taken in sorted name order, so the chunking is reproducible.

    Returns (froot, count): the chunk-file prefix and the number of chunk files
    written (0 when ``regcov_dir`` is empty).
    Raises ValueError if ``fpj`` is not positive.
    """
    import math

    if fpj <= 0:
        raise ValueError('files per job must be positive, got %r' % (fpj,))
    # a fractional fpj rounds up: each chunk holds ceil(fpj) files
    chunk_size = int(math.ceil(fpj))

    froot = os.path.join(interm, 'regcovfile')
    names = sorted(os.listdir(regcov_dir))

    count = 0
    for start in range(0, len(names), chunk_size):
        count += 1
        with open(froot + '.%i' % count, 'w') as o:
            for f in names[start:start + chunk_size]:
                o.write(os.path.join(regcov_dir, f) + '\n')

    return froot, count



def create_slurm_submission_script(regcov_root, pickled_sd, pickled_ids, pickled_idc, plotdir, fraglen, minpost, peakstats_root, tfbsstats_root, count, instance_name, queue_name, environment):
    """
    Write the SLURM array-job script that runs combinePosteriors_faster.py per chunk.

    The script ``command.sh`` is written to the parent directory of ``plotdir``. The
    job's ``stdout``/``stderr`` files go to the same directory. Each array task
    (1..count) requests one CPU. Memory and time limits come from the first two
    comma-separated fields of ``queue_name``, e.g.
    "membycore=8G, runtime=6:00:00, n_nodes=1". A third field must be present but is
    not used.

    Returns the path of the written script.
    """
    specs = queue_name.split(',')

    def spec_value(position):
        # the text after '=' in the position-th comma-separated field
        return specs[position].split('=')[1]

    mem_limit = spec_value(0)
    time_limit = spec_value(1)
    spec_value(2)  # n_nodes must be given, but every array task runs on one CPU

    scratch_dir = os.path.split(plotdir)[0]
    stderr_file = os.path.join(scratch_dir, 'stderr')
    stdout_file = os.path.join(scratch_dir, 'stdout')

    component_dir = os.path.dirname(os.path.realpath(__file__))
    prog = os.path.join(component_dir, 'combinePosteriors_faster.py')

    command_args = [prog, pickled_ids, pickled_idc, pickled_sd, regcov_root,
                    tfbsstats_root, peakstats_root, plotdir, str(fraglen), str(minpost)]

    script_lines = [
        '#!/bin/bash',
        '#SBATCH -n 1',
        '#SBATCH --mem=%s' % (mem_limit,),
        '#SBATCH --time=%s' % (time_limit,),
        '#SBATCH -e %s' % (stderr_file,),
        '#SBATCH -o %s' % (stdout_file,),
        '#SBATCH --job-name=%s-combPost' % (instance_name,),
        '#SBATCH --array=1-%i' % (count,),
        '#SBATCH --workdir=%s' % (os.path.dirname(os.path.realpath(__file__)),),
        'source %s' % (environment,),
        ' '.join('%s' % (arg,) for arg in command_args),
    ]

    shellFilename = os.path.join(scratch_dir, 'command.sh')
    with open(shellFilename, 'w') as handle:
        handle.write('\n'.join(script_lines))

    return shellFilename


def _parse_array_task_ids(spec):
    """Expand a SLURM ArrayTaskId value ('3', '1-4', '1-9:2', '1,3-5', '2-10%4') into task ids."""
    ids = []
    # drop an optional '%N' concurrency throttle, then handle comma-separated parts
    for part in spec.split('%')[0].split(','):
        if not part:
            continue
        bounds, _, step = part.partition(':')
        if '-' in bounds:
            first, last = bounds.split('-', 1)
            ids.extend(range(int(first), int(last) + 1, int(step or 1)))
        else:
            ids.append(int(bounds))
    return ids


def _classify_task_state(state):
    """Map a SLURM JobState to 'pending', 'running', 'completed' or 'failed' (None if unknown)."""
    if state is None:
        return None
    if state == 'PENDING':
        return 'pending'
    if state in ('COMPLETED', 'COMPLETING'):
        return 'completed'
    if state in ('RUNNING', 'CONFIGURING', 'SUSPENDED', 'STOPPED', 'REQUEUED',
                 'REQUEUE_FED', 'REQUEUE_HOLD', 'RESIZING', 'SIGNALING', 'STAGE_OUT'):
        # still active: the task has not reached its final outcome yet
        return 'running'
    # any other state (FAILED, CANCELLED, TIMEOUT, OUT_OF_MEMORY, NODE_FAIL, ...) ended without success
    return 'failed'


def check_array_job_status(NUMBER_OF_JOBS, job_id, poll_interval=30, max_polls=1200):
    """
    Poll ``scontrol`` until all tasks of SLURM array job ``job_id`` have finished.

    NUMBER_OF_JOBS: number of array tasks, numbered 1..NUMBER_OF_JOBS.
    poll_interval: seconds to sleep before each poll (and once more after success).
    max_polls: maximum number of polls before giving up.

    Returns 0 when all tasks completed. Returns -1 in any of these cases:
      - a task ended unsuccessfully while none were running;
      - the reported states could not be matched to all tasks for more than three
        consecutive polls;
      - the job did not finish within ``max_polls`` polls.
    """
    # With -o, scontrol prints one line per array task (or per range of still-pending
    # tasks), e.g.:
    # 'JobId=55226883 ArrayJobId=55226883 ArrayTaskId=1 JobName=... JobState=FAILED Reason=NonZeroExitCode ...'

    # current state of every array task (None until scontrol reports it)
    job_status_dict = {task: None for task in range(1, NUMBER_OF_JOBS + 1)}

    job_finished = False
    inconsistent_status = 0
    for _ in range(max_polls):
        time.sleep(poll_interval)
        p = subprocess.Popen(['scontrol', 'show', 'jobid', '-d', '-o', str(job_id)],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             universal_newlines=True)
        job_state, _err = p.communicate()

        # loop over array job records
        for record in job_state.strip().split('\n'):
            task_match = re.search(r'\bArrayTaskId=(\S+)', record)
            state_match = re.search(r'\bJobState=(\S+)', record)
            if not (task_match and state_match):
                # e.g. empty output when scontrol does not know the job (anymore)
                continue
            for task in _parse_array_task_ids(task_match.group(1)):
                if task in job_status_dict:
                    job_status_dict[task] = state_match.group(1)

        categories = [_classify_task_state(s) for s in job_status_dict.values()]
        n_pending = categories.count('pending')
        n_running = categories.count('running')
        n_failed = categories.count('failed')
        n_completed = categories.count('completed')

        if n_pending + n_running + n_failed + n_completed != NUMBER_OF_JOBS:
            if inconsistent_status > 2:
                print('Something is going wrong with monitoring array job.')
                print(job_status_dict)
                return -1
            # if job states do not add up, run scontrol once or twice more, as it might be faulty
            inconsistent_status += 1
            continue
        inconsistent_status = 0

        if n_running == 0:
            if n_failed > 0:
                print('At least one array job has failed.')
                print(job_status_dict)
                return -1
            if n_completed == NUMBER_OF_JOBS:
                print('Array job done!')
                job_finished = True
                time.sleep(poll_interval)
                break

    if not job_finished:
        print("Job took too long. Broke before being done.")
        return -1
    return 0

def run_slurm_bulk_job(slurm_submission_script, NUMBER_OF_JOBS):
    """
    Submit an array-job script with sbatch and block until its tasks have finished.

    Returns the status from ``check_array_job_status``: 0 when all NUMBER_OF_JOBS
    tasks completed, -1 on task failure, inconsistent monitoring output or timeout.
    Raises subprocess.CalledProcessError if sbatch fails.
    """
    # sbatch prints e.g. "Submitted batch job 4014401"
    sub_message = subprocess.check_output(['sbatch', slurm_submission_script], universal_newlines=True)
    job_id = sub_message.split()[-1]

    print('Job submitted with ID %i' % int(job_id))

    # propagate the status so callers can detect a failed array job
    return check_array_job_status(NUMBER_OF_JOBS, job_id)


def combineFiles(peakstats_root, tfbsstats_root, count, peakstats, TFBSstats):
    """
    Concatenate the per-chunk outputs of the array job into the final tables.

    Writes a header line to ``peakstats`` and ``TFBSstats``. It then appends the
    contents of ``<peakstats_root>.i`` and ``<tfbsstats_root>.i`` for i = 1..count,
    in order. With count == 0 only the headers are written.
    Raises IOError if a chunk file is missing.
    """
    with open(peakstats, 'w') as op, open(TFBSstats, 'w') as ot:
        op.write('#chrom\tstart\tend\tpeakID\tzscore\tquality\tsummed_posterior\n')
        ot.write('#chrom\tstart\tend\tpeakID\tdistance\tposterior\tTFBS_coverage\n')

        for i in range(1, int(count) + 1):
            with open(peakstats_root + '.%i' % i) as chunk:
                op.writelines(chunk)
            with open(tfbsstats_root + '.%i' % i) as chunk:
                ot.writelines(chunk)


def execute(cf):
    """
    This component gives true regions (determined by a posterior cut-off on TFBS).
    It produces some plots:
        -histogram of region posteriors (one with summed posteriors and one with maximum TFBS posterior per region)
        -plots peak coverage (from RegionCoverage) plots with TFBSs (above 0.5 posterior cut-off)

    Steps:
      1. Run Motevo as a SLURM job.
      2. Split the region-coverage files into chunks.
      3. Score the chunks with a SLURM array job.
      4. Concatenate the chunk results into the ``peakstats`` and ``TFBSstats``
         outputs.
      5. Remove intermediate files and write the timings to ``log_file``.

    Returns 0 on success. Returns 1 if the Motevo job or the array job reported a
    failure; intermediate files are then left in place.
    """

    ##Ports and parameters
    regions = cf.get_input("regions")  # sequences of candidate regions
    train_set = cf.get_input("train_set")  # train_set and bg_train_set are used to estimate a prior for the WM that is then used to predict sites on the regions.
    bg_train_set = cf.get_input("bg_train_set")
    regcov_dir = cf.get_input("RegCov_dir")
    WM = cf.get_input("WM")
    statsfile = cf.get_input("statsfile")

    peakstats = cf.get_output("peakstats")
    TFBSstats = cf.get_output("TFBSstats")
    interm = cf.get_output("intermediate")
    log_file = cf.get_output("log_file")

    plotdir = os.path.join(os.path.split(interm)[0], 'peak_plots')  # cf.get_output("peak_plots")

    genome = cf.get_parameter("genome", "string")
    minpost = cf.get_parameter("minposterior", "float")
    motevo_path = cf.get_parameter("motevo_path", "string")
    markovorder = cf.get_parameter("markovorder", "int")
    read_files = cf.get_parameter("read_files", "string")
    fpj = cf.get_parameter("files_per_job", "float")
    queue_name_motevo = cf.get_parameter("queue_name_motevo", "string")
    queue_name = cf.get_parameter("queue_name", "string")
    environment = cf.get_parameter("environment", "string")

    instance_name = cf.get_metadata("instanceName")

    T1 = datetime.datetime.now()

    # Creating a plot for every region with the predicted TFBS can result in many files (>1GB) and it also takes time to create, archive and remove again.
    # So I just create a flag to stop producing them.
    # There is also the same flag in combinePosteriors script. To switch on plotting, also switch the flag there!
    do_plots = False

    ##Main function
    os.mkdir(interm)
    if do_plots:
        os.mkdir(plotdir)

    fraglen = findFraglen(read_files)

    # get parameter file and predicted sites for best WM
    pickled_sd = os.path.join(interm, 'sitesDict')
    pickled_ids = os.path.join(interm, 'IDstats')
    pickled_idc = os.path.join(interm, 'IDcoords')

    retval = runMotevo(motevo_path, regions, train_set, bg_train_set, WM, interm, pickled_sd, pickled_ids, pickled_idc, statsfile, regcov_dir, instance_name, queue_name_motevo, genome, markovorder, environment)
    if not retval:
        return 1

    T2 = datetime.datetime.now()

    # create plots and build peakstats and TFBSstats files
    regcov_root, count = splitRegCovs(regcov_dir, interm, fpj)

    peakstats_root = os.path.join(interm, 'peakstatsfile')
    tfbsstats_root = os.path.join(interm, 'tfbsstatsfile')

    if count > 0:
        slurm_submission_script = create_slurm_submission_script(regcov_root, pickled_sd, pickled_ids, pickled_idc, plotdir, fraglen, minpost, peakstats_root, tfbsstats_root, count, instance_name, queue_name, environment)
        return_status = run_slurm_bulk_job(slurm_submission_script, count)
        # a falsy status (0, or None when no status is reported) means success
        if return_status:
            print('Computing peak posteriors failed (status %s); not combining partial results.' % return_status)
            return 1
    else:
        # with no chunks the array would be '--array=1-0', which is not a valid task range
        print('No region coverage files found in %s; writing header-only tables.' % regcov_dir)

    combineFiles(peakstats_root, tfbsstats_root, count, peakstats, TFBSstats)

    if do_plots:
        import shutil
        # make archive ../peak_plots.tar.gz next to plotdir, then remove the directory (best effort)
        archive = os.path.join('..', os.path.split(plotdir)[1] + '.tar.gz')
        subprocess.call(['tar', '-czf', archive, '.'], cwd=plotdir)
        shutil.rmtree(plotdir, ignore_errors=True)

    # clean up: remove pickled dictionaries and file chunks (best effort, like `rm`)
    chunk_prefixes = ('peakstatsfile', 'tfbsstatsfile', 'regcovfile')
    leftovers = [pickled_sd, pickled_ids, pickled_idc]
    leftovers += [os.path.join(interm, fname) for fname in os.listdir(interm)
                  if fname.startswith(chunk_prefixes)]
    for path in leftovers:
        try:
            os.remove(path)
        except OSError as err:
            print('could not remove %s: %s' % (path, err))

    T3 = datetime.datetime.now()

    timetext = '\n'.join(['Running time:',
                          '\t-Predicting sites on given regions and loading data into dictionaries: %s' % (T2 - T1),
                          '\t-Computing peak posteriors: %s' % (T3 - T2),
                          ])

    with open(log_file, 'w') as lf:
        lf.write(timetext)

    print('Running time: %s' % str(T3 - T1))

    return 0


# Hand `execute` to the component framework's entry point. The call is not guarded
# by `if __name__ == '__main__':`, so it runs as soon as this module is loaded.
component_skeleton.main.main(execute)
                                                                 
