import os
import sys
import re
import yaml
#import glob3
import pandas as pd
import glob
from snakemake.utils import min_version
# Enforce a minimum Snakemake version
min_version("4.7")

onstart:
    print("##########################################\n")
    print("# STARTING PIPELINE\n")
    # Fixed mojibake ("‰") so this banner line matches the other two.
    print("##########################################\n")
#    print ("Running ChIP-seq pre-processing workflow for the following samples:\n " + ' \n '.join(map(str, SAMPLES)))

onsuccess:
    # Final banner printed once the whole workflow finishes without errors.
    print("##########################################\n")
    print("# PIPELINE ENDED SUCCESSFULLY \n")
    print("##########################################\n")

# Banner prefixed to each rule's `message` so rule starts stand out in the log.
ruleDisplayMessage = "\n\n####### START RULE EXECUTION ##########\n"

# we focus on stringent VCF set for the CHT (same as for AS counts)



################################################################################
######################## PROJECT STRUCTURE #####################################
################################################################################

### FOLDERS ON TIER1 ###

# All project paths hang off the configured top-level project directory.
_proj_structure = config["project_structure"]
PROJECT_DIR = config["global"]["projectdir"]
DATA_DIR = "/".join([PROJECT_DIR, _proj_structure["datadir"]])
ANALYSIS_DIR = "/".join([PROJECT_DIR, _proj_structure["analysisdir"]])
LOG_DIR = "/".join([PROJECT_DIR, _proj_structure["logdir"]])
EXEC_DIR = "/".join([PROJECT_DIR, _proj_structure["utils"]])

################################################################################
######################## FUNCTIONS #############################################
################################################################################

# Make project-local helper scripts importable (EXEC_DIR holds utils.py).
sys.path.append(EXEC_DIR)
import utils as utl

def get_individuals_list(tp):
    """Return the list of individual IDs for a given time-point.

    Parameters
    ----------
    tp : str
        Time-point: "24", "68" (these two share one individuals file)
        or "1012" (separate file; one line excluded, see data section).

    Returns
    -------
    list of str
        Individual IDs, one per line of the configured file.

    Raises
    ------
    KeyError
        If `tp` is not one of the supported time-points.
    """
    key_by_tp = {
        "24": "individual_ids",
        "68": "individual_ids",
        "1012": "individual_ids_1012h",
    }
    if tp not in key_by_tp:
        raise KeyError("Incorrect time-point provided")
    individuals_path = config['data']['samples'][key_by_tp[tp]]
    # Use a context manager so the file handle is closed (the original
    # left it open).
    with open(individuals_path, 'r') as fh:
        return fh.read().splitlines()

################################################################################
######################## READ DATA #############################################
################################################################################


# alternative way to get config
# config_path = "/g/furlong/project/68_F1_cisreg_ichip/config/config.yml"
# with open(config_path, 'r') as stream:
#     config = yaml.load(stream)

# Parameters for consensus peak set - used here to define directory
idr = config["pipelines"]["3_call_peaks"]["consensus_set"]["idr_thres"]
n_ind = config["pipelines"]["3_call_peaks"]["consensus_set"]["num_ind"]
# NOTE(review): `input` shadows the Python builtin (and Snakemake's rule
# keyword); it is referenced in rule prepare_target_regions below, so
# renaming it would require a coordinated change.
input = config["pipelines"]["3_call_peaks"]["consensus_set"]["input_type"]
consensus_type_dir = input + "_idr" + str(idr) + "_ind" + str(n_ind)

# width of regions around peaks to look for no_variants
region_width = config["pipelines"]["4_CHT"]["peaks_resize_width"]


# pipeline-specific directories
BAM_DIR = DATA_DIR + "/ChIPseq/Alignments/with_indels"
PEAKS_DIR = DATA_DIR + "/ChIPseq/Peaks/with_indels"
AS_COUNTS_DIR = DATA_DIR + "/ChIPseq/AS_counts/with_indels"
AS_REGION_COUNTS_DIR = AS_COUNTS_DIR + "/2_target_regions/by_peaks_width" + str(region_width) + "/" + consensus_type_dir # counts per peak, variants in resized window
#AS_REGION_COUNTS_DIR = AS_COUNTS_DIR + "/test/2_target_regions/by_peaks_width" + str(region_width) + "/" + consensus_type_dir # counts per peak, variants in resized window
CHT_DIR = DATA_DIR + "/ChIPseq/CHT/with_indels/by_peaks_width" + str(region_width) + "/" + consensus_type_dir



### SOFTWARE ###
WASP = config["tools"]["by_path"]["wasp"] # original version of WASP
WASP_indels = config["tools"]["by_path"]["wasp_indels"] # WASP build with indel support

### DATA ###
# TFs and time-points ("antibody/time-point" pairs used as wildcards)
ab_tp_list  = ["mef2/68", "bin/68", "ctcf/68", "zld/24", "twi/24"]

# Individuals - full list (same for 6-8h and 2-4h)
individuals = get_individuals_list("68")
# individuals_path = config['data']['samples']['individual_ids']
# individuals = open(individuals_path, 'r').read().splitlines()

# separately run for 10-12h (exclude one line)
ab_tp_list_1012 = ["mef2/1012", "bin/1012"]
individuals_1012 = get_individuals_list("1012")
# individuals_path_1012 = config['data']['samples']['individual_ids_1012h']
# individuals_1012 = open(individuals_path_1012, 'r').read().splitlines()

# Sanity check: report cohort sizes at workflow parse time.
print(len(individuals))
print(len(individuals_1012))

# Genome
chrom_sizes = config["data"]["genome"]["dm6"]["chrom_sizes"]
fasta_dir = config["data"]["genome"]["dm6"]["fasta_by_chr"]
chrom_list = config["pipelines"]["4_CHT"]["chrom_list"].split(" ")
# Per-chromosome FASTA paths (used by rule fasta2h5). Loop variable renamed
# from `chr`, which shadowed the Python builtin chr().
chr_fasta = [fasta_dir + "/" + chrom + ".fasta" for chrom in chrom_list] # to select chromosomes from fasta


# Haplotypes (phased VCF)
#config["data"]["phased_VCF"]["stringent_with_gl"]



################################################################################
######################## MAIN ##################################################
################################################################################

# Lightweight bookkeeping/plotting rules run locally rather than being
# submitted to the cluster.
localrules: all, make_cht_input_files, make_adj_in_out_files, qqplot

rule all:
    # Only the final 10-12h QQ plots are currently requested; the commented
    # expand() lines are earlier pipeline stages and the 2-4h/6-8h targets,
    # kept here so they can be re-enabled for partial reruns.
    input:
        # expand(AS_REGION_COUNTS_DIR + "/{ab_tp}/input_regions.txt", ab_tp = ab_tp_list),
        # expand(AS_REGION_COUNTS_DIR + "/{ab_tp}/target_regions.txt.gz", ab_tp = ab_tp_list),
        # expand(AS_REGION_COUNTS_DIR + "/{ab_tp}/raw_hap_read_counts_hetp_adj/haplotype_read_counts.{individual}.txt.gz", ab_tp = ab_tp_list, individual = individuals),
        # expand(AS_REGION_COUNTS_DIR + "/{ab_tp}/adjust_hap_read_counts_hetp_adj/haplotype_read_counts.{individual}.txt.gz", ab_tp = ab_tp_list, individual = individuals),
        # expand(CHT_DIR + "/{ab_tp}/cht_results.txt", ab_tp = ab_tp_list),
        # expand(CHT_DIR + "/{ab_tp}/qqplot.png", ab_tp = ab_tp_list)
        # expand(AS_REGION_COUNTS_DIR + "/{ab_tp}/adjust_hap_read_counts_hetp_adj/haplotype_read_counts.{individual}.txt.gz", ab_tp = ab_tp_list_1012, individual = individuals_1012),
        # expand(CHT_DIR + "/{ab_tp}/cht_results.txt", ab_tp = ab_tp_list_1012),
        expand(CHT_DIR + "/{ab_tp}/qqplot.png", ab_tp = ab_tp_list_1012)


######### AS counts in target regions #######################

rule prepare_target_regions:
    """Resize consensus peaks to the configured width and write the
    input-regions file (implemented in prepare_target_regions.R)."""
    input: expand(PEAKS_DIR + "/{input_type}/Consensus_peaksets/{{ab}}/{{tp}}/consensus_peaks_idr{idr_thres}_ind{num_ind}.bed", input_type = input, idr_thres = idr, num_ind = n_ind)
    output: AS_REGION_COUNTS_DIR + "/{ab}/{tp}/input_regions.txt"
    params:
        width = region_width
    message: "{ruleDisplayMessage}Prepare target regions for {wildcards.ab} at {wildcards.tp}"
    shell:
        """
        Rscript {EXEC_DIR}/prepare_target_regions.R -p {input} -o {output} -w {params.width}
        """


rule get_target_regions:
    """Extract peak regions (and corresponding variants) that have sufficient read depth
    and count of allele specific reads."""
    input:
        # Per-individual variant AS counts and filtered BAMs; the individual
        # set depends on the time-point wildcard.
        variants = lambda wildcards: expand(AS_COUNTS_DIR + "/1_as_counts_by_variant/{{ab}}.{{tp}}.{individual}.variant_counts.txt.gz", individual = get_individuals_list(wildcards.tp)),
        bams = lambda wildcards: expand(BAM_DIR + "/{{ab}}.{{tp}}.{individual}.filtered.rmdup_withUMI.bam", individual = get_individuals_list(wildcards.tp)),
        regions = rules.prepare_target_regions.output
    output: AS_REGION_COUNTS_DIR + "/{ab}/{tp}/target_regions.txt.gz"
    params:
        # Filtering thresholds passed straight through to the WASP script.
        min_read_count = config["pipelines"]["4_CHT"]["min_read_count"],
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"],
        min_het_count = config["pipelines"]["4_CHT"]["min_het_count"],
        min_minor_allele_count = config["pipelines"]["4_CHT"]["min_minor_allele_count"]
    message: "{ruleDisplayMessage}Get target regions for {wildcards.ab} at {wildcards.tp}"
    shell:
        """python {WASP_indels}/CHT_input/get_target_regions.py \
            --variants {input.variants} \
            --bams {input.bams} \
            --regions {input.regions} \
            --min_total_reads {params.min_read_count} \
            --min_as_reads {params.min_as_count} \
            --min_het {params.min_het_count} \
            --min_minor {params.min_minor_allele_count} \
            --outfile {output}"""


rule extract_haplotype_read_counts:
    """Extract haplotype read counts for target regions for each individual.

    Runs WASP's get_region_data.py with heterozygote-probability adjustment
    based on AS counts (--adjust_hetp as_counts)."""
    input:
        variants = AS_COUNTS_DIR + "/1_as_counts_by_variant/{ab}.{tp}.{individual}.variant_counts.txt.gz",
        bam = BAM_DIR + "/{ab}.{tp}.{individual}.filtered.rmdup_withUMI.bam",
        target_regions = rules.get_target_regions.output
    output: AS_REGION_COUNTS_DIR + "/{ab}/{tp}/raw_hap_read_counts_hetp_adj/haplotype_read_counts.{individual}.txt.gz"
    message: "{ruleDisplayMessage}Get haplotype counts for {wildcards.individual}: {wildcards.ab} at {wildcards.tp}"
    shell:
        """python {WASP_indels}/CHT_input/get_region_data.py \
            --adjust_hetp as_counts \
            --variants {input.variants} \
            --bam {input.bam} \
            --regions {input.target_regions} \
            --outfile {output}"""



######## adjust read counts ##############


rule fasta2h5:
    """Create HDF5 file from genome FASTA files"""
    input:
        # Per-chromosome FASTAs (restricted to chrom_list) plus the
        # chromosome-sizes table required by WASP's fasta2h5.
        fasta = chr_fasta,
        chrom = chrom_sizes
    output: DATA_DIR + "/Genome/fasta_h5/seq.h5"
    shell:
        """{WASP}/snp2h5/fasta2h5 --chrom {input.chrom} --seq {output} {input.fasta}
        """


rule make_adj_in_out_files:
    """makes input / output files containing lists of adjusted read count files"""
    input:
        lambda wildcards: expand(AS_REGION_COUNTS_DIR + "/{{ab}}/{{tp}}/raw_hap_read_counts_hetp_adj/haplotype_read_counts.{individual}.txt.gz", individual = get_individuals_list(wildcards.tp))
    output:
        in_file = AS_REGION_COUNTS_DIR + "/{ab}/{tp}/adjust_hap_read_counts_hetp_adj/input.txt",
        out_file = AS_REGION_COUNTS_DIR + "/{ab}/{tp}/adjust_hap_read_counts_hetp_adj/output.txt"
    message: "{ruleDisplayMessage}Prepare file names: {wildcards.ab} at {wildcards.tp}"
    # The output list is derived from the input list by a sed directory-name
    # swap, so raw and adjusted files stay pairwise aligned.
    # NOTE(review): the sleep presumably lets a shared filesystem catch up
    # before downstream rules stat the files - TODO confirm.
    shell:
        "ls {input} > {output.in_file};"
        "ls {input} | sed 's/raw_hap_read_counts/adjust_hap_read_counts/' > {output.out_file};"
        "sleep 10;"

# --skip 30 - also doesn't adjust all regions (writes subset of regions to file)
rule adjust_read_counts:
    """Adjust total read depths with WASP's update_total_depth.py.

    Applies to the 2-4h / 6-8h individual set only (see
    wildcard_constraints); 10-12h is handled by adjust_read_counts_1012h
    because `output:` cannot take a per-wildcard function."""
    input:
        in_file = rules.make_adj_in_out_files.output.in_file,
        out_file = rules.make_adj_in_out_files.output.out_file,
        seq = rules.fasta2h5.output
    output:
        expand(AS_REGION_COUNTS_DIR + "/{{ab}}/{{tp}}/adjust_hap_read_counts_hetp_adj/haplotype_read_counts.{individual}.txt.gz", individual = individuals)
    wildcard_constraints:
        tp = "24|68"
    threads: 16
    message: "{ruleDisplayMessage}Correct haplotype counts: {wildcards.ab} at {wildcards.tp}"
    shell:
        """python -u {WASP}/CHT/update_total_depth.py  \
           --seq {input.seq}  {input.in_file} {input.out_file}"""

# have to do this trick because output doesn't accept functions - and we need to set different set of individuals at 10-12h
rule adjust_read_counts_1012h:
    """Same as adjust_read_counts but for the reduced 10-12h individual set."""
    input:
        in_file = rules.make_adj_in_out_files.output.in_file,
        out_file = rules.make_adj_in_out_files.output.out_file,
        seq = rules.fasta2h5.output
    output:
        expand(AS_REGION_COUNTS_DIR + "/{{ab}}/{{tp}}/adjust_hap_read_counts_hetp_adj/haplotype_read_counts.{individual}.txt.gz", individual = individuals_1012)
    wildcard_constraints:
        tp = "1012"
    threads: 16
    message: "{ruleDisplayMessage}Correct haplotype counts: {wildcards.ab} at {wildcards.tp}"
    shell:
        """python -u {WASP}/CHT/update_total_depth.py  \
           --seq {input.seq}  {input.in_file} {input.out_file}"""


# ######## CHT #############################

rule make_cht_input_files:
    """make the input file containing list of read count files
    that is used for running the combined test"""
    # Picks the adjusted-count set matching the time-point wildcard.
    input: lambda wildcards: rules.adjust_read_counts_1012h.output if wildcards.tp == "1012" else rules.adjust_read_counts.output
    output: CHT_DIR + "/{ab}/{tp}/cht_input_files.txt"
    message: "{ruleDisplayMessage}Make list of CHT input file: {wildcards.ab} at {wildcards.tp}"
    # NOTE(review): sleep presumably waits out shared-filesystem latency - TODO confirm.
    shell: """ls {input} > {output} &&
              sleep 10
           """

# --sample 2000 --seed 1234 - previous parameters, using --skip instead
#   --skip 50 \
rule fit_bnb_coef:
    """estimate dispersion parameters for beta-negative binomial
    part of combined test"""
    input: rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/bnb_coef.txt"
    params:
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"]
    message: "{ruleDisplayMessage}Fit BNB coefficients: {wildcards.ab} at {wildcards.tp}"
    shell:
        """python {WASP}/CHT/fit_bnb_coefficients.py \
           --min_as_counts {params.min_as_count} \
           --skip 30 \
           {input} {output}
        """


rule fit_as_coef:
    """estimate dispersion parameters for allele-specific
    part of combined test, only autosomes"""
    input: rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/as_coef.txt"
    message: "{ruleDisplayMessage}Fit AS coefficients: {wildcards.ab} at {wildcards.tp}"
    # NOTE(review): `-r --chrom_list chrX chrY chrM` presumably excludes the
    # sex/mitochondrial chromosomes from the fit (project-modified script
    # fit_as_coefficients_mod.py) - verify against that script.
    shell:
        """python {WASP}/CHT/fit_as_coefficients_mod.py \
           -r --chrom_list chrX chrY chrM \
           -i {input} -o {output}
        """



rule combined_test:
    """Run the full combined haplotype test (BNB + AS components)."""
    input:
        as_coef = rules.fit_as_coef.output,
        bnb_coef = rules.fit_bnb_coef.output,
        cht_input = rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/cht_results.txt"
    message: "{ruleDisplayMessage}Run CHT: {wildcards.ab} at {wildcards.tp}"
    params:
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"]
    shell:
        """python {WASP}/CHT/combined_test.py \
            --min_as_counts {params.min_as_count} \
            --bnb_disp {input.bnb_coef} \
            --as_disp {input.as_coef} \
            {input.cht_input} {output}
        """


rule combined_test_permuted:
    """Run the combined test on permuted genotypes (--shuffle) as a
    null / calibration run for the QQ plots."""
    input:
        as_coef = rules.fit_as_coef.output,
        bnb_coef = rules.fit_bnb_coef.output,
        cht_input = rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/cht_results_permuted.txt"
    params:
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"]
    shell:
        """python {WASP}/CHT/combined_test.py --shuffle \
            --min_as_counts {params.min_as_count} \
            --bnb_disp {input.bnb_coef} \
            --as_disp {input.as_coef} \
            {input.cht_input} {output}
        """

rule as_test:
    """run just the allele-specific part of the combined test"""
    input:
        as_coef = rules.fit_as_coef.output,
        cht_input = rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/cht_results_as.txt"
    params:
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"]
    shell:
        """python {WASP}/CHT/combined_test.py \
            --min_as_counts {params.min_as_count} \
            --as_only --as_disp {input.as_coef} \
            {input.cht_input} {output}
        """


rule bnb_test:
    """run just the beta-negative-binomial part of the combined test"""
    input:
        bnb_coef = rules.fit_bnb_coef.output,
        cht_input = rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/cht_results_bnb.txt"
    params:
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"]
    shell:
        """python {WASP}/CHT/combined_test.py \
            --min_as_counts {params.min_as_count} \
            --bnb_only --bnb_disp {input.bnb_coef} \
            {input.cht_input} {output}
        """

rule as_test_permuted:
    """run just the allele-specific part of the combined test
    on permuted genotypes

    Fixes: use the same `python` interpreter as every sibling test rule
    (was the odd one out with `python3.7`), and drop the bnb_coef input,
    which --as_only never reads (mirrors the non-permuted as_test rule)."""
    input:
        as_coef = rules.fit_as_coef.output,
        cht_input = rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/cht_results_as_permuted.txt"
    params:
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"]
    shell:
        """python {WASP}/CHT/combined_test.py --shuffle \
            --min_as_counts {params.min_as_count} \
            --as_only --as_disp {input.as_coef} \
            {input.cht_input} {output}
        """

rule bnb_test_permuted:
    """run just the beta-negative-binomial part of the combined test with
    permuted genotypes
    """
    input:
        bnb_coef = rules.fit_bnb_coef.output,
        cht_input = rules.make_cht_input_files.output
    output: CHT_DIR + "/{ab}/{tp}/cht_results_bnb_permuted.txt"
    params:
        min_as_count = config["pipelines"]["4_CHT"]["min_as_count"]
    shell:
        """python {WASP}/CHT/combined_test.py --shuffle \
            --min_as_counts {params.min_as_count} \
            --bnb_only --bnb_disp {input.bnb_coef} \
            {input.cht_input} {output}
        """


rule qqplot:
    """Draw QQ plots comparing observed vs permuted p-values for the
    combined, AS-only and BNB-only tests."""
    input:
        rules.combined_test.output,
        rules.as_test.output,
        rules.bnb_test.output,
        rules.combined_test_permuted.output,
        rules.as_test_permuted.output,
        rules.bnb_test_permuted.output
    output:
        CHT_DIR + "/{ab}/{tp}/qqplot.png"
    # NOTE(review): R interpreter version is pinned (Rscript-3.5.0);
    # sleep presumably waits out shared-filesystem latency - TODO confirm.
    shell:
        """Rscript-3.5.0 --vanilla {WASP}/CHT/qqplot_cairo.R {output} {input};
           sleep 10
        """
