import os
import sys
import re
import yaml
#import glob3
import numpy as np
import pandas as pd
from snakemake.utils import min_version
# Enforce a minimum Snakemake version: the workflow refuses to run when the
# installed Snakemake is older than 4.7.
min_version("4.7")

onstart:
    # Start-up banner. Every string ends in "\n" in addition to print()'s own
    # newline, so the banner lines are separated by blank lines.
    print("##########################################\n")
    print("# STARTING PIPELINE\n")
    print("##########################################\n")

onsuccess:
    # Closing banner; each entry keeps its trailing "\n" so that print()
    # leaves a blank line after it.
    for banner_line in (
        "##########################################\n",
        "# PIPELINE ENDED SUCCESSFULLY \n",
        "##########################################\n",
    ):
        print(banner_line)

# Prefix used in the rules' `message:` strings to mark the start of a rule
# execution in the log output.
ruleDisplayMessage = "\n\n####### START RULE EXECUTION ##########\n"

# NOTE(review): not referenced in the part of the file visible here.
threadsMax = 16

"""
IDR tool from:
https://github.com/nboley/idr#output-file-format

Code for generating pseudo-replicates for IDR pipeline from:
https://hbctraining.github.io/Intro-to-ChIPseq/lessons/07_handling-replicates-idr.html

Using relaxed p-value (0.1) for MACS2 before IDR.
Using two types of input: merged input files (all individuals at specific time-point) and sample-specific input
Loading module to run macs2 - MACS2/2.1.2.1-foss-2016b-Python-2.7.12
"""

################################################################################
######################## PATHS AND DATA ########################################
################################################################################


### Control (input) types to use: 'merged_input' and/or 'unique_input'
input_types = [
    "merged_input",
    "unique_input",
]


###

### FOLDERS ON TIER1 ###
# Everything below is built from the configured project directory and the
# sub-directory names listed under config["project_structure"].
PROJECT_DIR = config["global"]["projectdir"]
DATA_DIR = "/".join((PROJECT_DIR, config["project_structure"]["datadir"]))
EXEC_DIR = "/".join((PROJECT_DIR, config["project_structure"]["utils"]))
ANALYSIS_DIR = "/".join((PROJECT_DIR, config["project_structure"]["analysisdir"]))

# Alignments (read), peak calls (written) and log files of this workflow.
BAM_DIR = "/".join((DATA_DIR, "ChIPseq/Alignments/with_indels"))
PEAKS_DIR = "/".join((DATA_DIR, "ChIPseq/Peaks/with_indels"))
LOG_DIR = "/".join((PROJECT_DIR, "log/peaks_and_signal/with_indels/IDR"))

### FOLDERS ON SCRATCH and TMP ###
SCRATCH_DIR = "/".join((config["global"]["scratchdir"], "peaks_and_signal/IDR"))
TEMP_DIR = "/".join((SCRATCH_DIR, "tmp"))


### TOOLS ###
# Tools configured by path.
_tools_by_path = config["tools"]["by_path"]
SAMTOOLS = _tools_by_path["samtools"]
IDR = _tools_by_path["IDR"]
Rscript = _tools_by_path["Rscript"]["v4.0"]

# macs2 is made available with `module load` in the peak-calling rules, so
# the module name is read here instead of a path.
#MACS2 = config["tools"]["by_path"]["macs2"]
MACS2_module = config["tools"]["by_module"]["macs2"]

### Data and parameters ###

# Individuals
# The ID file is read line by line; an individual is identified by the first
# two "_"-separated fields of each ID. Blank lines are skipped so they cannot
# produce an empty individual name.
individuals_path = config['data']['samples']['individual_ids']
with open(individuals_path, 'r') as ids_file:
    sample_ids = [line.strip() for line in ids_file if line.strip()]
# De-duplicate, then sort: the iteration order of a set of strings changes
# between Python processes, which would make the order of expand()-generated
# file lists differ from run to run.
individuals = sorted({"_".join(x.split("_")[0:2]) for x in sample_ids})
print(individuals)

# TFs and time-points, written as "<antibody>/<time-point>" so that each entry
# fills the {ab}/{tp} part of the output paths.
conditions = [
    "mef2/68",
    "mef2/1012",
    "bin/68",
    "bin/1012",
    "ctcf/68",
    "zld/24",
    "twi/24",
]

# IDR thresholds
IDR_thres = [
    0.05,
    0.01,
]
# Disabled alternative kept for reference:
# IDR_thres_transf = [-125 * np.log2(x) for x in IDR_thres]
# idr_dict = dict(zip(IDR_thres, IDR_thres_transf))

################################################################################
######################## FUNCTIONS #############################################
################################################################################

# Add EXEC_DIR to the module search path before importing `utils`
# (presumably utils.py lives in EXEC_DIR -- confirm).
# NOTE(review): `utl` is not referenced in the part of the file visible here.
sys.path.append(EXEC_DIR)
import utils as utl

def get_input_file(tp, individual, input_type, dir = BAM_DIR):
    """Return the path of the control (input) BAM used for peak calling.

    Args:
        tp: Time-point; converted with str() before being put in the path.
        individual: Individual ID (only used for "unique_input").
        input_type: "unique_input" selects the individual's own replicate-1
            input BAM; "merged_input" selects the per-time-point BAM under
            "merged_samples".
        dir: Directory containing the BAM files (defaults to BAM_DIR).

    Returns:
        The BAM path as a string.

    Raises:
        KeyError: If input_type is neither of the two accepted values.
    """
    bam_suffix = ".filtered.rmdup_withUMI.bam"
    if input_type == "merged_input":
        return dir + "/merged_samples/input." + str(tp) + bam_suffix
    if input_type == "unique_input":
        return dir + "/input." + str(tp) + "." + individual + "_1" + bam_suffix
    raise KeyError("Incorrect input_type: 'merged_input' or 'unique_input' accepted")

################################################################################
######################## MAIN ##################################################
################################################################################

# Rules that are always executed on the host running Snakemake instead of
# being submitted as cluster jobs.
localrules:
    all,
    idr_significant_peaks,
    summarize_and_plot_idr_results_per_condition,
    aggregate_idr_results

# Default target. For every input type it requests, in this order: the IDR
# tables on true replicates, on pooled pseudo-replicates and on the self
# pseudo-replicates of each replicate, the thresholded peak sets, the
# per-condition summary tables and the summary aggregated over conditions.
rule all:
    input:
        expand(
            PEAKS_DIR + "/{input_type}/IDR/{ab_tp}/{individual}_narrowPeak_idr.txt",
            input_type = input_types, ab_tp = conditions, individual = individuals
        ),
        expand(
            PEAKS_DIR + "/{input_type}/IDR/{ab_tp}/{individual}_pseudoreplicates_narrowPeak_idr.txt",
            input_type = input_types, ab_tp = conditions, individual = individuals
        ),
        expand(
            PEAKS_DIR + "/{input_type}/IDR/{ab_tp}/{individual}_selfpseudoreplicates{rep}_narrowPeak_idr.txt",
            input_type = input_types, ab_tp = conditions, individual = individuals, rep = [1, 2]
        ),
        expand(
            PEAKS_DIR + "/{input_type}/Significant_peaks/{ab_tp}/{individual}_IDR_peaks_{idr_thres}.bed",
            input_type = input_types, ab_tp = conditions, individual = individuals, idr_thres = IDR_thres
        ),
        expand(
            ANALYSIS_DIR + "/ChIPseq/IDR/with_indels/{input_type}/{ab_tp}/idr_significant_peaks.csv",
            input_type = input_types, ab_tp = conditions
        ),
        expand(
            ANALYSIS_DIR + "/ChIPseq/IDR/with_indels/{input_type}/idr_significant_peaks.csv",
            input_type = input_types
        )



########################################################
########### IDR on true replicates #####################
########################################################

# Global wildcard constraints: {rep} is the replicate number (1 or 2) and
# {input_type} must be one of the entries of `input_types`.
wildcard_constraints:
    rep = "1|2",
    input_type = "|".join(input_types)


# Merge the replicate-1 input (control) BAMs of all individuals at one
# time-point into a single BAM and index it. This produces the control used
# when input_type is "merged_input".
# The threads reserved for the job are passed to `samtools merge` (-@), so the
# reservation is actually used.
rule merge_input:
    input:
        expand(BAM_DIR + "/input.{{tp}}.{individual}_1.filtered.rmdup_withUMI.bam", individual = individuals)
    output:
        BAM_DIR + "/merged_samples/input.{tp}.filtered.rmdup_withUMI.bam"
    threads: 4
    message: "{ruleDisplayMessage}Merging input files {input} ..."
    shell:
        """{SAMTOOLS} merge -@ {threads} {output} {input} &&
           {SAMTOOLS} index {output}
        """

# Call peaks with macs2 on one true replicate against the control selected by
# {input_type}. The parameter set comes from the config entry
# "macs2_relaxed4IDR"; macs2's stderr is written to the rule's log file. The
# narrowPeak file is declared temp().
rule macs2_callpeak:
    input:
        test = BAM_DIR + "/{ab}.{tp}.{individual}_{rep}.filtered.rmdup_withUMI.bam",
        control = lambda wc: get_input_file(wc.tp, wc.individual, wc.input_type)
    output:
        peaks = temp(PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_{rep}_peaks.narrowPeak")
    log:
        LOG_DIR + "/{input_type}/{ab}.{tp}.{individual}_{rep}.macs2.log"
    params:
        base_name = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_{rep}",
        macs2_params = config["pipelines"]["4_call_peaks"]["macs2_relaxed4IDR"]
    threads: 1
    priority: 1
    message:
        "{ruleDisplayMessage}Running macs2 analysis for {wildcards.ab} at {wildcards.tp}; {wildcards.individual} ..."
    shell:
        """
        module load {MACS2_module};
        macs2 callpeak -t  {input.test} -c {input.control} --name {params.base_name} {params.macs2_params} 2> {log}
        """


# Run IDR on the macs2 peak calls of the two true replicates.
# Both peak files are first sorted by column 8 (numeric, descending; the
# -log10 p-value column of the narrowPeak format) into temporary files next
# to the inputs; these are removed again once IDR has finished.
rule idr:
    input:
        peaks_1 = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_1_peaks.narrowPeak",
        peaks_2 = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_2_peaks.narrowPeak"
    output: PEAKS_DIR + "/{input_type}/IDR/{ab}/{tp}/{individual}_narrowPeak_idr.txt"
    log: LOG_DIR + "/{input_type}/{ab}.{tp}.{individual}.idr.log"
    threads: 1
    message: "{ruleDisplayMessage}IDR on true replicates for {wildcards.ab} at {wildcards.tp}; {wildcards.individual} ..."
    shell:
        """
        sort -k8,8nr {input.peaks_1} > {input.peaks_1}.temp.sorted.narrowPeak;
        sort -k8,8nr {input.peaks_2} > {input.peaks_2}.temp.sorted.narrowPeak;
        {IDR} --samples  {input.peaks_1}.temp.sorted.narrowPeak {input.peaks_2}.temp.sorted.narrowPeak \
              --input-file-type narrowPeak \
              --rank p.value \
              --output-file {output} \
              --plot \
              --log-output-file {log};
        rm {input.peaks_1}.temp.sorted.narrowPeak {input.peaks_2}.temp.sorted.narrowPeak
        """


# Keep the peaks that pass an IDR threshold.
# The IDR output stores the global IDR value in column 12 on a -log10 scale,
# so the threshold {idr_thres} is converted to that scale with 'bc -l'
# (-ln(x)/ln(10)) and peaks with column 12 >= the converted value are kept.
# The first five columns are written out, sorted by columns 1 and 2.
rule idr_significant_peaks:
    input: rules.idr.output
    output: PEAKS_DIR + "/{input_type}/Significant_peaks/{ab}/{tp}/{individual}_IDR_peaks_{idr_thres}.bed"
    message: "{ruleDisplayMessage}Significant peaks at {wildcards.idr_thres} for {wildcards.ab} at {wildcards.tp}; {wildcards.individual} ..."
    threads: 1
    shell:
        """
        idr_transf=$(bc -l <<< "-l({wildcards.idr_thres})/l(10)");
        awk -v thres="$idr_transf" '$12 >= thres' {input} | cut -f1-5 | sort -k1,1 -k2,2g > {output}
        """


########################################################
########### Self pseudo-replicates #####################
########################################################

# Generate self pseudo-replicates: shuffle the reads of one replicate and
# split them into two disjoint halves that share the original header.
# The first half gets ceil(n/2) reads and the second half the remaining reads
# (tail -n +K starts at line K), so no read ends up in both files when the
# read count is odd. The read count is taken from the shuffled SAM file, which
# avoids decoding the BAM a second time.
# NOTE(review): `shuf` is unseeded, so the split differs between runs.
rule split_bam_files:
    input: BAM_DIR + "/{ab}.{tp}.{individual}_{rep}.filtered.rmdup_withUMI.bam"
    output:
        shuf_bam_1 = SCRATCH_DIR + "/{input_type}/pseudoreplicates/{ab}/{tp}/{individual}_{rep}_selfpseudorep1.filtered.rmdup_withUMI.bam",
        shuf_bam_2 = SCRATCH_DIR + "/{input_type}/pseudoreplicates/{ab}/{tp}/{individual}_{rep}_selfpseudorep2.filtered.rmdup_withUMI.bam"
    threads: 1
    shell:
        """
        {SAMTOOLS} view -H {input} > {output.shuf_bam_1}.temp_header.sam;
        {SAMTOOLS} view {input} | shuf - > {output.shuf_bam_1}.temp_shuff.sam;

        nlines=$(wc -l < {output.shuf_bam_1}.temp_shuff.sam);
        half=$(( (nlines + 1) / 2 ));

        cat {output.shuf_bam_1}.temp_header.sam <(head -n $half {output.shuf_bam_1}.temp_shuff.sam) | {SAMTOOLS} view -b > {output.shuf_bam_1};
        cat {output.shuf_bam_1}.temp_header.sam <(tail -n +$(( half + 1 )) {output.shuf_bam_1}.temp_shuff.sam) | {SAMTOOLS} view -b > {output.shuf_bam_2};

        rm {output.shuf_bam_1}.temp_header.sam {output.shuf_bam_1}.temp_shuff.sam
        """

# Call peaks with macs2 on both self pseudo-replicates of one replicate,
# against the control selected by {input_type}. The two macs2 runs share one
# log file: the first run overwrites it (2>), the second appends (2>>).
rule macs2_callpeak_selfpseudoreplicates:
    input:
        test1 = rules.split_bam_files.output.shuf_bam_1,
        test2 = rules.split_bam_files.output.shuf_bam_2,
        control = lambda wc: get_input_file(wc.tp, wc.individual, wc.input_type)
    output:
        peaks1 = temp(PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_{rep}_selfpseudorep1_peaks.narrowPeak"),
        peaks2 = temp(PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_{rep}_selfpseudorep2_peaks.narrowPeak")
    log:
        LOG_DIR + "/{input_type}/{ab}.{tp}.{individual}_{rep}_selfpseudorep.macs2.log"
    params:
        base_name1 = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_{rep}_selfpseudorep1",
        base_name2 = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_{rep}_selfpseudorep2",
        macs2_params = config["pipelines"]["4_call_peaks"]["macs2_relaxed4IDR"]
    threads: 1
    message:
        "{ruleDisplayMessage}Running macs2 analysis for {wildcards.ab} at {wildcards.tp}; {wildcards.individual}; {wildcards.rep} (self pseudo-replicates)..."
    shell:
        """
        module load {MACS2_module};
        macs2 callpeak -t  {input.test1} -c {input.control} --name {params.base_name1} {params.macs2_params} 2> {log};
        macs2 callpeak -t  {input.test2} -c {input.control} --name {params.base_name2} {params.macs2_params} 2>> {log}
        """



# Run IDR on the macs2 peak calls of the two self pseudo-replicates of one
# replicate. Both peak files are first sorted by column 8 (numeric,
# descending; the -log10 p-value column of the narrowPeak format) into
# temporary files next to the inputs; these are removed again once IDR has
# finished.
rule idr_selfpseudoreplicates:
    input:
        peaks_1 = rules.macs2_callpeak_selfpseudoreplicates.output.peaks1,
        peaks_2 = rules.macs2_callpeak_selfpseudoreplicates.output.peaks2
    output: PEAKS_DIR + "/{input_type}/IDR/{ab}/{tp}/{individual}_selfpseudoreplicates{rep}_narrowPeak_idr.txt"
    log: LOG_DIR + "/{input_type}/{ab}.{tp}.{individual}_{rep}.selfpseudoreplicates.idr.log"
    threads: 1
    message: "{ruleDisplayMessage}IDR on self-pseudoreplicates for {wildcards.ab} at {wildcards.tp}; {wildcards.individual}, {wildcards.rep} ..."
    shell:
        """
        sort -k8,8nr {input.peaks_1} > {input.peaks_1}.temp.sorted.narrowPeak;
        sort -k8,8nr {input.peaks_2} > {input.peaks_2}.temp.sorted.narrowPeak;
        {IDR} --samples  {input.peaks_1}.temp.sorted.narrowPeak {input.peaks_2}.temp.sorted.narrowPeak \
              --input-file-type narrowPeak \
              --rank p.value \
              --output-file {output} \
              --plot \
              --log-output-file {log};
        rm {input.peaks_1}.temp.sorted.narrowPeak {input.peaks_2}.temp.sorted.narrowPeak
        """

########################################################
########### Pooled pseudo-replicates ###################
########################################################


# Generate pooled pseudo-replicates: merge both true replicates, shuffle the
# pooled reads and split them into two disjoint halves that share the merged
# header.
# - `samtools merge -f` overwrites a temporary merged BAM left behind by an
#   interrupted earlier run instead of failing on it.
# - The first half gets ceil(n/2) reads and the second half the remaining
#   reads (tail -n +K starts at line K), so no read ends up in both files
#   when the read count is odd.
# NOTE(review): `shuf` is unseeded, so the split differs between runs.
rule pool_and_split_bam_files:
    input:
        bam_1 = BAM_DIR + "/{ab}.{tp}.{individual}_1.filtered.rmdup_withUMI.bam",
        bam_2 = BAM_DIR + "/{ab}.{tp}.{individual}_2.filtered.rmdup_withUMI.bam"
    output:
        shuf_bam_1 = SCRATCH_DIR + "/{input_type}/pseudoreplicates/{ab}/{tp}/{individual}_pseudorep1.filtered.rmdup_withUMI.bam",
        shuf_bam_2 = SCRATCH_DIR + "/{input_type}/pseudoreplicates/{ab}/{tp}/{individual}_pseudorep2.filtered.rmdup_withUMI.bam"
    shell:
        """
        {SAMTOOLS} merge -f {output.shuf_bam_1}.temp_merged.bam {input.bam_1} {input.bam_2};
        {SAMTOOLS} view -H {output.shuf_bam_1}.temp_merged.bam > {output.shuf_bam_1}.temp_header.sam;
        {SAMTOOLS} view {output.shuf_bam_1}.temp_merged.bam | shuf - > {output.shuf_bam_1}.temp_shuff.sam;

        nlines=$(wc -l < {output.shuf_bam_1}.temp_shuff.sam);
        half=$(( (nlines + 1) / 2 ));

        cat {output.shuf_bam_1}.temp_header.sam <(head -n $half {output.shuf_bam_1}.temp_shuff.sam) | {SAMTOOLS} view -b > {output.shuf_bam_1};
        cat {output.shuf_bam_1}.temp_header.sam <(tail -n +$(( half + 1 )) {output.shuf_bam_1}.temp_shuff.sam) | {SAMTOOLS} view -b > {output.shuf_bam_2};

        rm {output.shuf_bam_1}.temp_merged.bam {output.shuf_bam_1}.temp_header.sam {output.shuf_bam_1}.temp_shuff.sam
        """


# Call peaks with macs2 on one pooled pseudo-replicate against the control
# selected by {input_type}. The parameter set comes from the config entry
# "macs2_relaxed4IDR"; macs2's stderr is written to the rule's log file.
# The message names the pseudo-replicate so that these jobs can be told apart
# from the true-replicate macs2 jobs in the log.
rule macs2_callpeak_pseudoreplicates:
    input:
        test = SCRATCH_DIR + "/{input_type}/pseudoreplicates/{ab}/{tp}/{individual}_pseudorep{rep}.filtered.rmdup_withUMI.bam",
        control = lambda wildcards: get_input_file(wildcards.tp, wildcards.individual, wildcards.input_type)
    output:
        peaks = temp(PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_pseudorep{rep}_peaks.narrowPeak")
    params:
        base_name = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_pseudorep{rep}",
        macs2_params = config["pipelines"]["4_call_peaks"]["macs2_relaxed4IDR"]
    priority: 1
    threads: 1
    message: "{ruleDisplayMessage}Running macs2 analysis for {wildcards.ab} at {wildcards.tp}; {wildcards.individual}; pooled pseudo-replicate {wildcards.rep} ..."
    log: LOG_DIR + "/{input_type}/{ab}.{tp}.{individual}_pseudorep{rep}.macs2.log"
    shell:
        """
        module load {MACS2_module};
        macs2 callpeak -t  {input.test} -c {input.control} --name {params.base_name} {params.macs2_params} 2> {log}
        """


# Run IDR on the macs2 peak calls of the two pooled pseudo-replicates.
# Both peak files are first sorted by column 8 (numeric, descending; the
# -log10 p-value column of the narrowPeak format) into temporary files next
# to the inputs; these are removed again once IDR has finished.
rule idr_pseudoreplicates:
    input:
        peaks_1 = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_pseudorep1_peaks.narrowPeak",
        peaks_2 = PEAKS_DIR + "/{input_type}/macs2/{ab}/{tp}/{individual}_pseudorep2_peaks.narrowPeak"
    output: PEAKS_DIR + "/{input_type}/IDR/{ab}/{tp}/{individual}_pseudoreplicates_narrowPeak_idr.txt"
    log: LOG_DIR + "/{input_type}/{ab}.{tp}.{individual}.pseudoreplicates.idr.log"
    threads: 1
    message: "{ruleDisplayMessage}IDR on pseudoreplicates for {wildcards.ab} at {wildcards.tp}; {wildcards.individual} ..."
    shell:
        """
        sort -k8,8nr {input.peaks_1} > {input.peaks_1}.temp.sorted.narrowPeak;
        sort -k8,8nr {input.peaks_2} > {input.peaks_2}.temp.sorted.narrowPeak;
        {IDR} --samples  {input.peaks_1}.temp.sorted.narrowPeak {input.peaks_2}.temp.sorted.narrowPeak \
              --input-file-type narrowPeak \
              --rank p.value \
              --output-file {output} \
              --plot \
              --log-output-file {log};
        rm {input.peaks_1}.temp.sorted.narrowPeak {input.peaks_2}.temp.sorted.narrowPeak
        """


########################################################
########### Summarize IDR results ######################
########################################################


# Summarize the IDR results of one condition ({ab} at {tp}) over all
# individuals. The R script receives the IDR directory, the output table and
# plot paths, the space-separated IDR thresholds and the condition; the IDR
# tables themselves are listed as inputs so that they exist before it runs.
rule summarize_and_plot_idr_results_per_condition:
    input:
        true_rep = expand(
            PEAKS_DIR + "/{{input_type}}/IDR/{{ab}}/{{tp}}/{individual}_narrowPeak_idr.txt",
            individual = individuals
        ),
        self_pseudorep1 = expand(
            PEAKS_DIR + "/{{input_type}}/IDR/{{ab}}/{{tp}}/{individual}_selfpseudoreplicates1_narrowPeak_idr.txt",
            individual = individuals
        ),
        self_pseudorep2 = expand(
            PEAKS_DIR + "/{{input_type}}/IDR/{{ab}}/{{tp}}/{individual}_selfpseudoreplicates2_narrowPeak_idr.txt",
            individual = individuals
        ),
        pooled_pseudorep = expand(
            PEAKS_DIR + "/{{input_type}}/IDR/{{ab}}/{{tp}}/{individual}_pseudoreplicates_narrowPeak_idr.txt",
            individual = individuals
        )
    output:
        table = ANALYSIS_DIR + "/ChIPseq/IDR/with_indels/{input_type}/{ab}/{tp}/idr_significant_peaks.csv",
        plot = ANALYSIS_DIR + "/ChIPseq/IDR/with_indels/{input_type}/{ab}/{tp}/idr_significant_peaks.pdf"
    params:
        idr_dir = PEAKS_DIR + "/{input_type}/IDR/{ab}/{tp}",
        idr_thres = " ".join(map(str, IDR_thres))
    message:
        "{ruleDisplayMessage}Summarize IDR results for {wildcards.ab} at {wildcards.tp}..."
    shell:
        """
        {Rscript}  {EXEC_DIR}/summarize_and_plot_idr_results_per_condition.R -i {params.idr_dir} \
        -o {output.table} -p {output.plot} --idr_thres '{params.idr_thres}' --ab {wildcards.ab} --tp {wildcards.tp}
        """


# Concatenate the per-condition summary tables of one input type.
# The first line of every table is taken and adjacent duplicates are collapsed
# by `uniq` to form the header; then everything from line 2 onwards of each
# table is appended.
rule aggregate_idr_results:
    input:
        expand(
            ANALYSIS_DIR + "/ChIPseq/IDR/with_indels/{{input_type}}/{ab_tp}/idr_significant_peaks.csv",
            ab_tp = conditions
        )
    output:
        ANALYSIS_DIR + "/ChIPseq/IDR/with_indels/{input_type}/idr_significant_peaks.csv"
    shell:
        """
        head -1 --quiet {input} | uniq > {output};
        tail -n +2 --quiet {input} >> {output}
        """
