#!/usr/bin/env python

import component_skeleton.main
import os, sys
import subprocess
import re 
from math import ceil
import numpy as np
import matplotlib.pyplot as plt
import datetime, time

# Divisor used by execute() when deriving NUMBER_OF_MOTIFS_PER_JOB from the
# number of WMs.
# NOTE(review): presumably the number of cluster nodes the work should be
# spread over -- confirm.
NUMBER_OF_COMPUTATION_NODES = 240


def createScratchDirectory(outfile):
    """Create (if necessary) and return the "scratch" directory next to outfile.

    Args:
        outfile: path of an output file; the scratch directory is placed in
            the same directory under the name "scratch".

    Returns:
        The path of the scratch directory.

    Raises:
        OSError: if the directory does not exist and cannot be created.
    """
    scratch_dir = os.path.join(os.path.dirname(outfile), "scratch")
    # Create the directory through the os module instead of a shell `mkdir`,
    # so paths containing spaces or shell metacharacters are handled.
    try:
        os.makedirs(scratch_dir)
    except OSError:
        # An already existing directory is fine (re-runs, concurrent creation);
        # anything else is a real error and must not be hidden.
        if not os.path.isdir(scratch_dir):
            raise
    return scratch_dir


def listOfAllWMs(WMdirectories, scratchDir):
    """List every entry of the given WM directories.

    The full path of each directory entry is written, one per line, to the
    file "WMs" inside scratchDir and is also collected in a list. Falsy
    directory entries (e.g. an empty string) are skipped.

    Returns:
        A tuple (path of the listing file, list of WM paths).
    """
    listing_path = os.path.join(scratchDir, 'WMs')
    collected = []
    with open(listing_path, 'w') as listing:
        for directory in WMdirectories:
            if directory:
                for entry in os.listdir(u'%s' % directory):
                    wm_path = os.path.join(directory, entry)
                    listing.write('%s\n' % wm_path)
                    collected.append(wm_path)
    return listing_path, collected


def createSequencePool(InputSequences, DecoySequences, scratchDir, resFile):
    """Concatenate the input and decoy sequence files into one pool file.

    The pool is written to scratchDir/resFile: the content of InputSequences
    followed by the content of DecoySequences.

    Args:
        InputSequences: path of the file holding the real sequences.
        DecoySequences: path of the file holding the decoy sequences.
        scratchDir: directory in which the pool file is created.
        resFile: file name of the pool inside scratchDir.

    Returns:
        The path of the pool file.

    Raises:
        IOError/OSError: if an input file cannot be read or the pool cannot
            be written.
    """
    import shutil

    resFilename = os.path.join(scratchDir, resFile)
    # Copy the bytes directly instead of piping the output of a shell `cat`:
    # this works for paths with spaces or shell metacharacters, needs no child
    # process, and an unreadable input raises instead of silently producing a
    # truncated pool.
    with open(resFilename, 'wb') as outf:
        for seq_path in (InputSequences, DecoySequences):
            if not seq_path:
                # An empty path contributed nothing to the `cat` command line.
                continue
            with open(seq_path, 'rb') as inf:
                shutil.copyfileobj(inf, outf)
    return resFilename


def create_slurm_submission_script(TrainingPool, TestPool, WMs, genome, motevo_path, environment, mem_limit, time_limit, scratch_dir, job_name, NUMBER_OF_MOTIFS_PER_JOB=1, NUMBER_OF_JOBS=1):
    """Write the sbatch script of the enrichment-score array job.

    The script is saved as "command.sh" in scratch_dir. Every array task
    picks line $SLURM_ARRAY_TASK_ID of the WMs file and passes it to
    calculate_enrichment_score.py, which is looked up in the directory of
    this file. Job stderr/stdout go to "stderr"/"stdout" in scratch_dir.
    NUMBER_OF_MOTIFS_PER_JOB is accepted but not used by the script.

    Returns:
        The path of the written script.
    """
    component_dir = os.path.dirname(os.path.realpath(__file__))
    scorer = os.path.join(component_dir, 'calculate_enrichment_score.py')
    err_log = os.path.join(scratch_dir, 'stderr')
    out_log = os.path.join(scratch_dir, 'stdout')

    # Scheduler directives.
    sbatch_directives = [
        '#!/bin/bash',
        '#SBATCH -n 1',
        '#SBATCH --mem=%s' % mem_limit,
        '#SBATCH --time=%s' % time_limit,
        '#SBATCH -e %s' % err_log,
        '#SBATCH -o %s' % out_log,
        '#SBATCH --job-name=%s' % job_name,
        '#SBATCH --array=1-%i' % NUMBER_OF_JOBS,
        '#SBATCH --workdir=%s' % component_dir,
    ]
    # Environment set-up and selection of the WM line for this array task.
    task_setup = [
        'source %s' % environment,
        'WMFILE=%s' % WMs,
        'WM=$(sed -n -e "$SLURM_ARRAY_TASK_ID p" $WMFILE)',
    ]
    # The scoring command, spread over continuation lines.
    scorer_call = [
        'python {prog} -w "$WM" \\'.format(prog=scorer),
        "   -t '{trainseq}' \\".format(trainseq=TrainingPool),
        "   -s '{testseq}' \\".format(testseq=TestPool),
        "   -o '{scratch}' -g '{genome}' \\".format(scratch=scratch_dir, genome=genome),
        "   -m '{motevo}'".format(motevo=motevo_path),
    ]
    script_text = '\n'.join(sbatch_directives + task_setup + scorer_call)

    script_path = os.path.join(scratch_dir, 'command.sh')
    with open(script_path, 'w') as script:
        script.write(script_text)

    return script_path

def _expand_array_task_ids(spec):
    """Expand an ArrayTaskId value ('7', '1-5', '3,8-10', '1-9:2', '1-5%2') into a list of ints."""
    task_ids = []
    # A trailing '%N' limits how many tasks run at once; it is not a task id.
    for part in spec.split('%')[0].split(','):
        span, _, step = part.partition(':')
        bounds = span.split('-')
        try:
            task_ids.extend(range(int(bounds[0]), int(bounds[-1]) + 1,
                                  int(step) if step else 1))
        except ValueError:
            # Unparsable piece: skip it, the caller's consistency check will
            # notice tasks whose state stays unknown.
            continue
    return task_ids


def _parse_array_task_states(scontrol_output):
    """Map array task id -> JobState for every record of one-line-per-job scontrol output."""
    states = {}
    for record in scontrol_output.strip().split('\n'):
        task_ids = []
        state = None
        for field in record.split():
            # Match on the full 'Key=' prefix and slice it off. str.strip()
            # would remove a *set of characters* from both ends instead.
            if field.startswith('ArrayTaskId='):
                task_ids = _expand_array_task_ids(field[len('ArrayTaskId='):])
            elif field.startswith('JobState='):
                state = field[len('JobState='):]
        if state is None:
            # Empty output or an error line: nothing to record.
            continue
        for task_id in task_ids:
            states[task_id] = state
    return states


def check_array_job_status(NUMBER_OF_JOBS, job_id, poll_interval=30, max_polls=1200):
    """Poll scontrol until all tasks of a Slurm array job have finished.

    `scontrol show jobid -d -o <job_id>` prints one line per array task (or
    per group of tasks sharing a state), made of Key=Value fields such as
        JobId=55226883 ArrayJobId=55226883 ArrayTaskId=1 ... JobState=FAILED ...
    The last state seen for each task is remembered between polls.

    Args:
        NUMBER_OF_JOBS: number of tasks in the array (task ids 1..NUMBER_OF_JOBS).
        job_id: id of the array job, as printed by sbatch.
        poll_interval: seconds to sleep before each poll, and once more after
            the job is seen complete.
        max_polls: maximum number of polls before giving up.

    Returns:
        0 if every task completed; -1 if a task failed while nothing was
        running, if the reported states stayed inconsistent for more than
        three consecutive polls, or if the polling limit was reached.
    """
    # One entry per array task holding its last known state.
    job_status_dict = dict((task_id, None) for task_id in range(1, NUMBER_OF_JOBS + 1))
    # Terminal states that all mean "this task did not produce a result".
    failure_states = ('FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL', 'OUT_OF_MEMORY')

    # Must be initialised here: the check after the loop reads it even when
    # the loop ran out of polls without ever seeing the job complete.
    job_finished = False
    inconsistent_status = 0
    for _ in range(max_polls):
        time.sleep(poll_interval)
        p = subprocess.Popen(['scontrol', 'show', 'jobid', '-d', '-o', job_id],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             universal_newlines=True)
        job_state, _err = p.communicate()
        job_status_dict.update(_parse_array_task_states(job_state))

        states = list(job_status_dict.values())
        n_pending = states.count('PENDING')
        n_running = states.count('RUNNING')
        n_failed = sum(states.count(state) for state in failure_states)
        n_completed = states.count('COMPLETED') + states.count('COMPLETING')

        if n_pending + n_running + n_failed + n_completed != NUMBER_OF_JOBS:
            if inconsistent_status > 2:
                print('Something is going wrong with monitoring array job.')
                print(job_status_dict)
                return -1
            # If the states do not add up, poll again a few times: the
            # scontrol answer might be transiently incomplete.
            inconsistent_status += 1
            continue
        inconsistent_status = 0

        if n_running == 0:
            if n_failed > 0:
                print('At least one array job has failed.')
                print(job_status_dict)
                return -1
            elif n_completed == NUMBER_OF_JOBS:
                print('Array job done!')
                job_finished = True
                time.sleep(poll_interval)
                break

    if not job_finished:
        print("Job took too long. Broke before being done.")
        return -1
    return 0

def run_slurm_bulk_job(slurm_submission_script, NUMBER_OF_JOBS):
    """Submit the array job with sbatch and wait until it has finished.

    Args:
        slurm_submission_script: path of the sbatch script to submit.
        NUMBER_OF_JOBS: number of tasks in the array job.

    Returns:
        The status reported by check_array_job_status (0 on success, -1 on
        failure or timeout).

    Raises:
        subprocess.CalledProcessError: if sbatch exits with a non-zero status.
    """
    sbatch_output = subprocess.check_output(['sbatch', slurm_submission_script],
                                            universal_newlines=True)
    # The job id is taken from the last whitespace-separated token of the
    # sbatch output.
    job_id = sbatch_output.split()[-1]

    print('Job submitted with ID %i' % int(job_id))

    # Hand the monitoring result back to the caller instead of dropping it.
    return check_array_job_status(NUMBER_OF_JOBS, job_id)


def concatenateResults(scratchDir, resFilename, col):
    """Merge the per-WM "*.results" files into one sorted result file.

    The first line of every "*.results" file in scratchDir is collected, the
    lines are sorted in descending general-numeric order on column `col`
    (1-based, as understood by `sort -k`) and written to resFilename below a
    header line. The "*.results" files and the intermediate unsorted file are
    deleted afterwards.

    Args:
        scratchDir: directory holding the "*.results" files.
        resFilename: path of the sorted result file to write.
        col: 1-based column number to sort on.

    Returns:
        The sorted result rows (header excluded), each split on whitespace.

    Raises:
        subprocess.CalledProcessError: if the external `sort` fails.
    """
    files = [os.path.join(scratchDir, f) for f in os.listdir(scratchDir) if f.endswith('.results')]
    resFileUnsorted = resFilename + '.unsorted'
    with open(resFileUnsorted, 'w') as outf:
        for a_file in files:
            with open(a_file) as inf:
                first_line = inf.readline()
            if not first_line.strip():
                # Empty result file: nothing to rank.
                continue
            if not first_line.endswith('\n'):
                # Keep one record per line even when the file lacks a final newline.
                first_line += '\n'
            outf.write(first_line)

    with open(resFilename, 'w') as outf:
        ## adding a header line that can be used by anduril
        outf.write('WM_path\tenrichment_score\tstdev\tLL_ratio\tbeta\tbg_prior\tmean_bg_sites\n')
        # The header must be on disk before the child process appends to the
        # same file descriptor.
        outf.flush()
        ## sorting according to the enrichment score column
        # `sort` is run without a shell so paths need no quoting.
        subprocess.check_call(['sort', '-gr', '-k', '%d' % col, resFileUnsorted], stdout=outf)

    with open(resFilename) as inf:
        next(inf, None)  # skip the header line
        sortedWMs = [line.split() for line in inf]

    for a_file in files:
        os.remove(a_file)
    os.remove(resFileUnsorted)

    return sortedWMs


def cleaningUpTmpFiles(scratchDir):
    """Remove the scratch directory and everything below it (best effort).

    Like the `rm -fr` call it replaces, a missing path or a failing removal is
    not reported. No shell is involved, so any character is allowed in the
    path.

    Args:
        scratchDir: path of the directory to remove.

    Returns:
        Always 0.
    """
    import shutil

    if os.path.islink(scratchDir) or os.path.isfile(scratchDir):
        # rmtree refuses symlinks and plain files; remove the entry itself.
        try:
            os.remove(scratchDir)
        except OSError:
            # Deliberately ignored: clean-up stays best effort.
            pass
    else:
        shutil.rmtree(scratchDir, ignore_errors=True)
    return 0


def createWMcombinedFile(topWM, restWMs, scratchDir):
    """Write the WM listing for a combined-motif round.

    The file "WMs" in scratchDir receives one line per entry of restWMs,
    consisting of topWM, a space and the first element of that entry.

    Returns:
        The path of the written file.
    """
    combined_path = os.path.join(scratchDir, 'WMs')
    with open(combined_path, 'w') as listing:
        listing.writelines('%s %s\n' % (topWM, candidate[0]) for candidate in restWMs)
    return combined_path


def findTopWMinWMs(WMs, topWM):
    """Return the index of the row in WMs that belongs to the newest top WM.

    Args:
        WMs: list of rows whose first element is a WM path.
        topWM: space-separated WM paths; only the last one is looked up.

    Returns:
        The index of the first row whose first element equals the last path
        in topWM, or -1 if there is none.
    """
    query = topWM.split(' ')[-1]
    for position, row in enumerate(WMs):
        if row[0] == query:
            return position
    return -1
    

def combinedMotifs(trainingPool, testPool, WMs, job_name, scratch_dir, GENOME, NUMBER_OF_MOTIFS_PER_JOB, NUMBER_OF_JOBS, motevo_path, environment, mem_limit, time_limit):
    """Greedily extend the best single WM into a set of complementary WMs.

    Starting from the first row of WMs, every round scores the current set
    combined with each remaining WM (one array task per remaining WM), writes
    the ranked results to "EnrichmentScores_<round>" next to scratch_dir and
    adds the best partner to the set. The search stops when the LL_ratio of
    the best combination exceeds the previous one by less than log(10), when
    no candidate is left, or once the set holds 10 WMs.

    Args:
        trainingPool, testPool: sequence pool files handed to the array job.
        WMs: result rows sorted by enrichment score; column 0 is the WM path,
            column 1 the enrichment score and column 3 the LL_ratio. The list
            is not modified.
        job_name, scratch_dir, GENOME, NUMBER_OF_MOTIFS_PER_JOB, motevo_path,
            environment, mem_limit, time_limit: passed through to
            create_slurm_submission_script.
        NUMBER_OF_JOBS: kept for interface compatibility; the array size of
            each round is the number of remaining candidates.

    Returns:
        A tuple (space-separated paths of the selected WMs, enrichment score
        after each round, LL_ratio after each round).

    Raises:
        Exception: if a round yields no results, or if the selected WM cannot
            be found among the remaining candidates.
    """
    # Work on a copy so the caller's list is left untouched.
    WMs = list(WMs)
    index = 1
    convergence_criterion = np.log(10.)
    print('Convergence criterion for finding complementary motifs is %f (at least 10 fold increase in log-likelihood)' % convergence_criterion)
    topWM = WMs[0][0]
    top_LL_FirstRound = float(WMs[0][3])
    enrichmentScoresEachRound = [float(WMs[0][1])]
    LL_RatioEachRound = [top_LL_FirstRound]
    del WMs[0]
    # Stop as soon as there is nothing left to combine with.
    while WMs:
        WmFile = createWMcombinedFile(topWM, WMs, scratch_dir)

        # The WM file has exactly one line per remaining candidate, so the
        # array must have that many tasks; a larger array would give the
        # surplus tasks an empty WM line.
        # NOTE(review): presumably calculate_enrichment_score.py fails on an
        # empty -w argument -- confirm.
        n_jobs = len(WMs)
        slurm_submission_script = create_slurm_submission_script(trainingPool, testPool, WmFile, GENOME, motevo_path, environment, mem_limit, time_limit, scratch_dir, job_name, NUMBER_OF_MOTIFS_PER_JOB, n_jobs)
        run_slurm_bulk_job(slurm_submission_script, n_jobs)

        ## make the last result file, sorted by the average enrichment score
        # A row of this round starts with index+1 WM paths, followed by the
        # enrichment score, stdev and LL_ratio columns.
        outfile = os.path.join(os.path.dirname(scratch_dir), 'EnrichmentScores_%d' % (index+1))
        sortedWMs = concatenateResults(scratch_dir, outfile, index+2)
        if not sortedWMs:
            raise Exception('No result was produced for combined-motif round %d' % (index+1))
        topEnrichmentScoreSecondRound = float(sortedWMs[0][index+1])
        top_LL_SecondRound = float(sortedWMs[0][index+3])
        # Converged: adding the best partner improves LL_ratio by less than log(10).
        if (top_LL_SecondRound - top_LL_FirstRound) < convergence_criterion:
            break
        topWM = ' '.join(sortedWMs[0][:(index+1)])
        removeIndex = findTopWMinWMs(WMs, topWM)
        if removeIndex == -1:
            raise Exception('Selected WM not found among the remaining candidates: %s' % topWM)
        del WMs[removeIndex]
        enrichmentScoresEachRound.append(topEnrichmentScoreSecondRound)
        LL_RatioEachRound.append(top_LL_SecondRound)
        top_LL_FirstRound = top_LL_SecondRound
        index += 1
        if index > 9:  # the set now holds 10 WMs, which is the maximum size
            break
    return topWM, enrichmentScoresEachRound, LL_RatioEachRound
        

def execute(cf):
    """
    It receives two sets of motifs: the de novo motifs, and the database motifs.
    For each of these motifs an array-job task runs calculate_enrichment_score.py
    (according to the original notes, this fits the background prior and beta
    and uses the fitted parameters to calculate the mean and standard deviation
    of the enrichment score).
    The final result is a file in which each line holds the values for one motif
    (columns: WM_path, enrichment_score, stdev, LL_ratio, beta, bg_prior,
    mean_bg_sites), sorted by enrichment score. A second file receives the
    header plus the top_wms best motifs.
    If CombinedMotifs is set, the result file is replaced by the greedily
    selected set of complementary motifs (WM_path, enrichment_score, LL_ratio)
    and a plot of the enrichment score per round is saved next to it as
    "<result file>.png".

    cf is the component configuration object providing get_input,
    get_parameter and get_output. Returns 0.
    """
    TrainingInputSequences = cf.get_input("TrainingSequences")
    TrainingDecoySequences = cf.get_input("TrainingDecoySequences")
    TestSequences = cf.get_input("TestSequences")
    TestDecoySequences = cf.get_input("TestDecoySequences")
    DenovoWMs = cf.get_input("DenovoWMs")

    DatabaseWMs = cf.get_parameter("DatabaseWMs")
    GENOME = cf.get_parameter('genome', 'string')
    CombinedMotifs = cf.get_parameter('CombinedMotifs', 'boolean')
    top_wms = cf.get_parameter('top_wms', 'int')
    queue_type = cf.get_parameter('queue_type', 'string')
    motevo_path = cf.get_parameter('motevo_path', 'string')
    environment = cf.get_parameter('environment', 'string')

    outfile = cf.get_output("EnrichmentScores")
    # This file contains the top WMs of the single WM run. The number of top
    # WMs is controlled by the top_wms parameter.
    outfile_tops = cf.get_output("EnrichmentScores_tops")

    T1 = datetime.datetime.now()

    # queue_type looks like "membycore=8G, runtime=6:00:00, n_nodes=1"; the
    # values are read by position. The third entry is not needed here.
    qspecs = queue_type.split(',')
    mem_limit = qspecs[0].split('=')[1].strip()
    time_limit = qspecs[1].split('=')[1].strip()

    print("Calculating Enrichment Scores for: ")
    print(DenovoWMs)
    print(DatabaseWMs)
    ## The scratch directory serves as a temporary space for holding files
    scratchDir = createScratchDirectory(outfile)
    ## create a file that lists all the input WMs
    WmFile, WMs = listOfAllWMs([DenovoWMs, DatabaseWMs], scratchDir)
    print("Enrichment Scores: There are in total %d WMs" % len(WMs))

    ## create the training and test pools that contain both real and decoy (shuffled) sequences
    trainingPool = createSequencePool(TrainingInputSequences, TrainingDecoySequences, scratchDir, 'trainingPool')
    testPool = createSequencePool(TestSequences, TestDecoySequences, scratchDir, 'testPool')

    ## the array job runs the fitting and enrichment score program for every motif
    job_name = os.path.basename(os.path.dirname(outfile))

    # Divide as floats: with integer division the quotient is already rounded
    # down before ceil() sees it. At least 2 motifs per job.
    NUMBER_OF_MOTIFS_PER_JOB = max(int(ceil(len(WMs) / float(NUMBER_OF_COMPUTATION_NODES))), 2)
    NUMBER_OF_JOBS = len(WMs)

    slurm_submission_script = create_slurm_submission_script(trainingPool, testPool, WmFile, GENOME, motevo_path, environment, mem_limit, time_limit, scratchDir, job_name, NUMBER_OF_MOTIFS_PER_JOB, NUMBER_OF_JOBS)
    run_slurm_bulk_job(slurm_submission_script, NUMBER_OF_JOBS)

    ## make the last result file, sorted by the average enrichment score
    sortedWMs = concatenateResults(scratchDir, outfile, 2)
    # Copy the header line plus the first top_wms result lines.
    with open(outfile_tops, 'w') as outf_tops:
        with open(outfile) as outf:
            for line_number, line in enumerate(outf):
                if line_number > top_wms:
                    break
                outf_tops.write(line)

    # The scratch directory is not removed here: the combined-motif search
    # below still uses the sequence pools stored in it.
    if CombinedMotifs:
        topWM, enrichmentScoresEachRound, LL_ratioEachRound = combinedMotifs(trainingPool, testPool, sortedWMs, job_name, scratchDir, GENOME, NUMBER_OF_MOTIFS_PER_JOB, NUMBER_OF_JOBS, motevo_path, environment, mem_limit, time_limit)
        topWM_list = topWM.split()
        with open(outfile, 'w') as outf:
            outf.write('WM_path\tenrichment_score\tLL_ratio\n')
            for wm, score, ll_ratio in zip(topWM_list, enrichmentScoresEachRound, LL_ratioEachRound):
                outf.write(wm + '\t' + str(score) + '\t' + str(ll_ratio) + '\n')

        # Enrichment score after each round; position 0 is a reference point
        # at 1 with an empty tick label.
        positions = range(len(topWM_list)+1)
        scores = [1] + enrichmentScoresEachRound
        plt.plot(positions, scores, 'r-')
        plt.plot(positions, scores, 'ko')
        plt.ylabel("Enrichment Score")
        plt.xticks(positions, [''] + [os.path.split(n)[1] for n in topWM_list], rotation=90)
        plt.tight_layout()
        plt.savefig(outfile + '.png')
        plt.close()

    T2 = datetime.datetime.now()
    print('Running time: %s' % str(T2-T1))

    return 0


# Hand execute to the component_skeleton runner. There is no __main__ guard,
# so this call also runs whenever the module is imported.
# NOTE(review): presumably the runner builds the cf object and calls
# execute(cf) -- confirm against component_skeleton.
component_skeleton.main.main(execute)
