#!/usr/bin/env python

import os, re
from string import *
from pylab import *
import component_skeleton.main
import datetime, time
import subprocess

def findFraglen(readfiles):
    """
    Return the mean fragment length over all read files in `readfiles`.

    `readfiles` is a whitespace-separated string of read file paths. For each
    path, the directory `<dir of read file>/../intermediate` is searched for a
    file whose name ends in '.res'; the last whitespace-separated token of that
    file is taken as the fragment length (converted via float, then truncated
    to int).

    Raises IOError when a read file has no '.res' file next to it. Previously
    this case either crashed with a NameError or silently reused the fragment
    length of the preceding read file.
    """

    fraglens = []
    for rf in readfiles.split():
        ##get Fragment length from the log of FragmentLength component
        fraglenRoot = os.path.join(os.path.split(rf)[0], '..', 'intermediate')

        # Match on the literal '.res' suffix of the file name; the first hit
        # in directory-listing order is used.
        resfiles = [f for f in os.listdir(fraglenRoot) if f.endswith('.res')]
        if not resfiles:
            raise IOError('No .res file with the fragment length found in %s (read file: %s)'
                          % (fraglenRoot, rf))

        with open(os.path.join(fraglenRoot, resfiles[0])) as fh:
            fraglen = int(float(fh.read().strip().split()[-1]))
        fraglens.append(fraglen)
        print('\nFound: %s with fragment length: %s\n' % (rf, fraglen))

    print('\nFragment Length: %s' % fraglens)

    return mean(fraglens)


def createJobFiles(indir, interm, fpj):
    """
    Split the entries of `indir` into list files of at most `fpj` paths each.

    The list files are written to `interm` as covfiles.1, covfiles.2, ...
    (the numeric suffix is meant to match the array task id); each line is
    the path of one entry of `indir`, joined onto `indir`.

    Returns (count, root) where `count` is the number of list files written
    and `root` is the path prefix `<interm>/covfiles` shared by all of them.
    An empty `indir` yields count 0 and no files.

    Raises ValueError if `fpj` is smaller than 1.
    """

    if fpj < 1:
        raise ValueError('files per job must be at least 1, got %s' % fpj)

    fname = 'covfiles'
    root = os.path.join(interm, fname)

    paths = [os.path.join(indir, f) for f in os.listdir(indir)]

    count = 0  # number of list files written so far
    for start in range(0, len(paths), fpj):
        count += 1
        # 'with' guarantees every list file, including the last one, is
        # flushed and closed before the jobs that read it are submitted.
        with open('%s.%i' % (root, count), 'w') as fh:
            fh.writelines(p + '\n' for p in paths[start:start + fpj])

    return count, root



def create_slurm_submission_script(count, infileroot, outfileroot, plotdir, fraglen, order, width, instance_name, environment, mem_limit, time_limit):
    """
    Write the sbatch script for the array job and return its path.

    The script is written to `command.sh` in the parent directory of
    `plotdir`; the job's stderr and stdout files are placed in that same
    directory. The job is an array of tasks 1..`count`, each sourcing
    `environment` and then running ./MMEM2.py with `infileroot`,
    `outfileroot`, `plotdir`, `fraglen`, `order` and `width` as arguments.
    The job's working directory is the directory containing this source file,
    which is what the relative ./MMEM2.py path is resolved against.

    `mem_limit` and `time_limit` are inserted verbatim into the --mem and
    --time directives; `instance_name` becomes the job name.
    """

    scratch_dir = os.path.split(plotdir)[0]
    stderr_file = os.path.join(scratch_dir, 'stderr')
    stdout_file = os.path.join(scratch_dir, 'stdout')

    prog = './MMEM2.py'
    component_dir = os.path.dirname(os.path.realpath(__file__))

    script_lines = [
        '#!/bin/bash',
        '#SBATCH -n 1',
        '#SBATCH --mem=%s' % mem_limit,
        '#SBATCH --time=%s' % time_limit,
        '#SBATCH -e %s' % stderr_file,
        '#SBATCH -o %s' % stdout_file,
        '#SBATCH --job-name=%s' % instance_name,
        '#SBATCH --array=1-%i' % count,
        # The short option -D is used instead of the long --workdir spelling:
        # newer Slurm releases only accept --chdir, while -D is understood by
        # old and new sbatch alike.
        '#SBATCH -D %s' % component_dir,
        'source %s' % environment,
        'python %s %s %s %s %s %s %s' % (prog, infileroot, outfileroot, plotdir,
                                         str(fraglen), str(order), str(width)),
    ]

    shellFilename = os.path.join(scratch_dir, 'command.sh')
    with open(shellFilename, 'w') as outf:
        # Terminate the last command with a newline as well.
        outf.write('\n'.join(script_lines) + '\n')

    return shellFilename


def _expand_array_task_ids(spec):
    """
    Expand the value of an ArrayTaskId field into a list of task numbers.

    Accepts a single id ('7'), a range ('1-10'), a comma-separated mix
    ('1,4-6'), an optional step ('1-9:2') and a trailing throttle suffix
    ('1-10%5', the part after '%' is ignored). Raises ValueError for anything
    that does not parse as such.
    """
    task_ids = []
    for part in spec.split('%', 1)[0].split(','):
        bounds, _, step = part.partition(':')
        first, _, last = bounds.partition('-')
        first = int(first)
        last = int(last) if last else first
        task_ids.extend(range(first, last + 1, int(step) if step else 1))
    return task_ids


def _parse_scontrol_record(record):
    """
    Extract (task_ids, state) from one line of one-line-per-job scontrol output.

    Fields are matched on their exact key, so 'ArrayTaskId'/'JobState'
    appearing inside another field (e.g. a path) is not picked up. A missing
    or unparsable ArrayTaskId gives an empty id list; a missing JobState
    gives None.
    """
    task_ids, state = [], None
    for field in record.split():
        key, _, value = field.partition('=')
        if key == 'ArrayTaskId':
            try:
                task_ids = _expand_array_task_ids(value)
            except ValueError:
                task_ids = []
        elif key == 'JobState':
            state = value
    return task_ids, state


def check_array_job_status(NUMBER_OF_JOBS, job_id):
    """
    Poll scontrol until the array job `job_id` has finished.

    scontrol is queried every 30 seconds, at most 1200 times (10 hours).
    It prints one line per array task (or per group of tasks), e.g.
    'JobId=55226883 ArrayJobId=55226883 ArrayTaskId=1 JobName=... JobState=FAILED Reason=NonZeroExitCode ...'
    from which the ArrayTaskId and JobState fields are read.

    Returns 0 once all `NUMBER_OF_JOBS` tasks are COMPLETED/COMPLETING.
    Returns -1 if
      - no task is running and at least one task is FAILED,
      - the PENDING/RUNNING/FAILED/COMPLETED counts do not add up to
        `NUMBER_OF_JOBS` for more than three polls in a row, or
      - the polling limit is reached before the job is done.
    """

    # Current state of every array task; None until scontrol reported it.
    job_status_dict = dict((task, None) for task in range(1, NUMBER_OF_JOBS + 1))

    # Must be initialised here: the loop below may run out without ever
    # seeing the job finish.
    job_finished = False
    inconsistent_status = 0

    # wait for jobs to complete
    for _ in range(1200):
        time.sleep(30)
        p = subprocess.Popen(['scontrol', 'show', 'jobid', '-d', '-o', job_id],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             universal_newlines=True)
        job_state, err = p.communicate()

        # loop over array jobs; lines without a usable task id or state
        # (e.g. empty output) leave the recorded states untouched
        for record in job_state.strip().split('\n'):
            task_ids, state = _parse_scontrol_record(record)
            if state is None:
                continue
            for task in task_ids:
                job_status_dict[task] = state

        states = list(job_status_dict.values())
        n_pending = states.count('PENDING')
        n_running = states.count('RUNNING')
        n_failed = states.count('FAILED')
        n_completed = states.count('COMPLETED') + states.count('COMPLETING')

        if n_pending + n_running + n_failed + n_completed != NUMBER_OF_JOBS:
            if inconsistent_status > 2:
                print('Something is going wrong with monitoring array job.')
                print(job_status_dict)
                return -1
            # if job status do not add up, run scontrol once or twice more,
            # as it might be faulty.
            inconsistent_status += 1
            continue
        inconsistent_status = 0

        if n_running == 0:
            if n_failed > 0:
                print('At least one array job has failed.')
                print(job_status_dict)
                return -1
            elif n_completed == NUMBER_OF_JOBS:
                print('Array job done!')
                job_finished = True
                # NOTE(review): extra grace period kept from the original,
                # presumably to let output files settle -- confirm.
                time.sleep(30)
                break

    if not job_finished:
        print("Job took too long. Broke before being done.")
        return -1

    return 0


def run_slurm_bulk_job(slurm_submission_script, NUMBER_OF_JOBS):
    """
    Submit `slurm_submission_script` with sbatch and wait for the array job.

    Blocks until check_array_job_status() returns and passes its result on:
    0 when all `NUMBER_OF_JOBS` array tasks completed, -1 otherwise.
    (The status used to be computed and then dropped, so callers could not
    tell a failed job from a successful one.)

    Raises subprocess.CalledProcessError if sbatch exits with a non-zero code.
    """

    # run job
    sbatch_output = subprocess.check_output(['sbatch', slurm_submission_script],
                                            universal_newlines=True)
    # The job id is taken from the last whitespace-separated token of
    # sbatch's output.
    job_id = sbatch_output.split()[-1]

    print('Job submitted with ID %i' % int(job_id))

    return check_array_job_status(NUMBER_OF_JOBS, job_id)


def combineFiles(count, fileroot, interm, allpeaks, statsfile, rmsd_co):
    """
    Merge the per-job result files `<fileroot>.1` .. `<fileroot>.<count>`.

    Every non-blank input line is split on whitespace into at least 9 fields:
      - fields 0-4 are written (tab-separated, with a trailing '+' column)
        to the bad-quality file if field 5 is greater than `rmsd_co`,
        otherwise to the list of accepted peaks;
      - fields 3-8 are always written to `statsfile`.
    The accepted peaks are then sorted with `sort -k5gr` (general-numeric,
    descending, keyed from field 5) into `allpeaks`.

    A temporary file `<interm>/tmpfile` is used for the unsorted peaks and
    removed afterwards. The bad-quality peaks go to `badqualpeaks` in the
    parent directory of `interm`; that path is returned.

    Raises subprocess.CalledProcessError if sort fails.
    """

    tmpfile = os.path.join(interm, 'tmpfile')
    badqualpeaksfile = os.path.join(os.path.split(interm)[0], 'badqualpeaks')

    with open(tmpfile, 'w') as good, open(statsfile, 'w') as stats, open(badqualpeaksfile, 'w') as bad:
        for i in range(1, count + 1):
            with open('%s.%i' % (fileroot, i)) as chunk:
                for line in chunk:
                    t = line.strip().split()
                    if not t:
                        # tolerate blank lines in job output
                        continue
                    peak = '%s\t%s\t%s\t%s\t%s\t+\n' % (t[0], t[1], t[2], t[3], t[4])
                    if float(t[5]) > rmsd_co:
                        bad.write(peak)
                    else:
                        good.write(peak)
                    stats.write('%s\t%s\t%s\t%s\t%s\t%s\n' % (t[3], t[4], t[5], t[6], t[7], t[8]))

    try:
        # Argument list instead of a shell string, so paths containing
        # spaces or shell metacharacters are passed through unchanged.
        with open(allpeaks, 'w') as sorted_out:
            subprocess.check_call(['sort', '-k5gr', tmpfile], stdout=sorted_out)
    finally:
        os.remove(tmpfile)

    return badqualpeaksfile


def _write_top_peaks(allpeaks, outfile, toppeaks):
    """
    Copy the first `toppeaks` lines of `allpeaks` to `outfile`.

    Returns the height (5th whitespace-separated column) of the last line
    that was actually copied, or None if nothing was copied.
    """
    lastheight = None
    with open(outfile, 'w') as top, open(allpeaks) as ranked:
        for i, line in enumerate(ranked):
            if i >= toppeaks:
                break
            top.write(line)
            lastheight = float(line.strip().split()[4])
    return lastheight


def execute(cf):
    """
    Component entry point: fit peaks on a SLURM array job and summarise them.

    Steps:
      1. read inputs, outputs and parameters from `cf`;
      2. create the `intermediate` and `peak_plots` directories next to the
         log file;
      3. if FragmentLength is -1, derive the fragment length from the read
         files via findFraglen();
      4. split the coverage files into per-job lists, write the sbatch
         script and run the array job;
      5. merge the per-job results, write the top `topPeaks` peaks to
         `outfile`, and draw the summary plots;
      6. write the log file, tar+gzip `peak_plots`, and remove `peak_plots`
         and `intermediate`.

    Returns 0 on success and 1 if no fragment length can be determined or
    the SLURM array job is reported as unsuccessful.
    """

    indir = cf.get_input("in_dir") #covfiles directory

    logfile = cf.get_output("log_file")
    outfile = cf.get_output("outfile") #contains 1000 top peaks (by height). Peaks above rmsd_co were filtered out. sorted by height
    allpeaks = cf.get_output("allpeaks") #contains all peaks that made rmsd_co. sorted by height
    statsfile = cf.get_output("peakstats") #contains statistics of all found peaks. Ones that made rmsd_co and the ones that didn't

    interm = os.path.join(os.path.split(logfile)[0], 'intermediate')
    plotdir = os.path.join(os.path.split(logfile)[0], 'peak_plots')

    #plots:
    height_sigma_scatter = cf.get_output("height_sigma_scatter")
    height_rmsd_scatter = cf.get_output("height_rmsd_scatter")
    sigma_rmsd_scatter = cf.get_output("sigma_rmsd_scatter")
    height_revcum = cf.get_output("height_revcum")
    rmsd_hist = cf.get_output("rmsd_hist")
    sigma_hist = cf.get_output("sigma_hist")
    height_hist = cf.get_output("height_hist")

    #parameters:
    readfiles = cf.get_parameter("FGfiles_string", "string")
    fraglen = cf.get_parameter("FragmentLength", "int")
    fpj = cf.get_parameter("files_per_job", "int")
    order = cf.get_parameter("order", "int")
    toppeaks = cf.get_parameter("topPeaks", "int")
    width = cf.get_parameter("widthFactor", "float") #widthFactor * sigma: this is added to each side of mu to define a peak.
    rmsd_co = cf.get_parameter("RMSD_cutoff", "float") #RMSD cut-off
    queue_name = cf.get_parameter("queue_name", "string")
    environment = cf.get_parameter("environment", "string")

    instance_name = cf.get_metadata("instanceName")

    # queue_name is a comma-separated list of key=value pairs, e.g.
    # membycore=8G, runtime=6:00:00, n_nodes=1
    # Only the first two values (memory and run time limit) are used.
    qspecs = queue_name.split(',')
    mem_limit = qspecs[0].split('=')[1]
    time_limit = qspecs[1].split('=')[1]

    os.mkdir(interm)
    os.mkdir(plotdir)

    T1 = datetime.datetime.now()

    ##try to find fragment lengths with readfiles
    if fraglen == -1:
        if readfiles == '':
            print('\nError: No length cut-offs or read files to get maximum length cut-off are given.\n')
            return 1
        fraglen = findFraglen(readfiles)

    count, filesfileroot = createJobFiles(indir, interm, fpj)

    statsfileroot = os.path.join(interm, 'stats')

    slurm_submission_script = create_slurm_submission_script(count, filesfileroot, statsfileroot, plotdir, fraglen, order, width, instance_name, environment, mem_limit, time_limit)
    job_status = run_slurm_bulk_job(slurm_submission_script, count)
    # Any non-zero status means the array job failed or could not be
    # monitored; its result files cannot be trusted then. None (no status
    # reported) is treated like success.
    if job_status:
        print('\nError: SLURM array job did not finish successfully.\n')
        return 1

    combineFiles(count, statsfileroot, interm, allpeaks, statsfile, rmsd_co)

    # Height of the lowest peak that made it into outfile (None if there is
    # no such peak, e.g. when every peak exceeded the RMSD cut-off).
    lastheight = _write_top_peaks(allpeaks, outfile, toppeaks)

    #plot stuff
    # ndmin=2 keeps the array two-dimensional even if only one peak was called.
    a = loadtxt(statsfile, usecols=[1,2,4], ndmin=2)
    heights = a.T[0]
    quals = a.T[1]
    sigmas = a.T[2]

    figure()
    plot(log10(sorted(heights, reverse=True)), arange(1,len(heights)+1,1), '.', rasterized=True)
    if lastheight is not None:
        lh = log10(lastheight)
        plot([lh, lh], [0, len(heights)], label='lowest peak used for motif finding: %s' %lastheight)
        legend()
    xlabel('log10(heights)')
    ylabel('number of peaks with up to peak height')
    savefig(height_revcum)
    close()

    figure()
    hist(sigmas, 300, histtype='step')
    title('Peakwidth (sigma) histogram')
    savefig(sigma_hist)
    close()

    figure()
    hist(log10(heights), 300, histtype='step')
    title('Peakheight (log10-space) histogram')
    savefig(height_hist)
    close()

    figure()
    plot(sigmas, heights, '.', rasterized=True)
    xlabel('peak width (sigma)')
    ylabel('peak height')
    savefig(height_sigma_scatter)
    close()

    log10quals = log10(quals)

    figure()
    plot(sigmas, log10quals, '.', rasterized=True)
    xlabel('peak width (sigma)')
    ylabel('log10(peak quality)')
    savefig(sigma_rmsd_scatter)
    close()

    figure()
    plot(heights, log10quals, '.', rasterized=True)
    xlabel('peak height')
    ylabel('log10(peak quality)')
    savefig(height_rmsd_scatter)
    close()

    figure()
    # log10 of a zero RMSD is -inf; keep only finite values for the histogram
    hist(log10quals[isfinite(log10quals)], 300, histtype='step')
    title('log10(Peak quality (RMSD)) histogram')
    savefig(rmsd_hist)
    close()

    text = '\n'.join(['About %i input regions.' %(fpj*count),
                      'Found mean fragment length %.1f. Constrained sigma to %s-%s' %(fraglen, max(1, 0.414*fraglen - 11.5), 0.414*fraglen + 45.5)
                      ])

    with open(logfile, 'w') as l:
        l.write(text)

    # clean up:
    # tar and zip peak_plots
    pwd = os.getcwd()
    os.system('cd %s && tar -czf ../%s.tar.gz . && cd %s' %(plotdir, os.path.split(plotdir)[1], pwd))
    os.system('rm -r %s' %plotdir)

    #remove intermediate dir
    os.system('rm -r %s' %interm)

    T4 = datetime.datetime.now()
    # report the overall running time
    print('Running time: %s' %str(T4-T1))

    return 0

# Hand execute() to component_skeleton's main(); as a module-level statement
# this runs whenever the file is executed or imported.
component_skeleton.main.main(execute)
