import pandas as pd


# Variables
# Table of top predicted variants (tab-separated, header row, first column
# used as the row index).
TOP_VARIANTS = pd.read_csv(
    "top_predicted_variants_Basenji_AI.txt",
    sep="\t",
    header=0,
    index_col=0,
)
# variant_ID is colon-separated: field 0 is stored as the chromosome and
# field 1 as the position (kept as a string, it is only used in file names
# and shell commands).
TOP_VARIANTS["chr"] = [variant_id.split(":")[0] for variant_id in TOP_VARIANTS["variant_ID"]]
TOP_VARIANTS["pos"] = [variant_id.split(":")[1] for variant_id in TOP_VARIANTS["variant_ID"]]
# One (chromosome, position, condition) triple per table row, in row order.
TOP_VARIANTS_TUPLE = list(
    zip(TOP_VARIANTS["chr"], TOP_VARIANTS["pos"], TOP_VARIANTS["condition"])
)

# Annotation
# Genome FASTA handed to the saturation-score script through its -f option.
GENOME_FASTA = "/g/furlong/genome/D.melanogaster/Dm6/fasta/dm6.UCSC.noMask.fa"

# Paths
# Root directory under which all score tables and plots are written.
PREDICTIONS_DIR = "/g/furlong/project/103_Basenji/Mattia/analysis/saturation_scores"

# Tools
# Script run with `python` to compute the saturation scores for a BED region.
SATURATION_SCORES_PY = "/g/furlong/forneris/software/basenji/bin/basenji_sat_bed.py"
# Script run with `python` on the resulting scores.h5 to produce the per-variant tables.
COMPUTE_DIFF_SCORES_PY = "/g/furlong/project/103_Basenji/Mattia/src/py/read_sat_mutagenesis.py"
# R script run with `Rscript` to render the HTML saturation-mutagenesis plot.
RENDER_SATURATION_MUTAGENESIS = "/g/furlong/project/103_Basenji/Mattia/Rmd/plot_saturation_mutagenesis/render_saturation_mutagenesis.R"

# Target
# Final file paths, one per (chromosome, position, condition) triple in
# TOP_VARIANTS_TUPLE.
# Score archives produced in the "tables" directory.
SAT_SCORES_OUT = [
    "".join([
        PREDICTIONS_DIR, "/tables/saturation_score_center_", chrom, "_", pos,
        "_condition_", cond, "_model.h5",
    ])
    for chrom, pos, cond in TOP_VARIANTS_TUPLE
]
# Score archives produced in the "tables_size500" directory.
SAT_SCORES_SIZE500_OUT = [
    "".join([
        PREDICTIONS_DIR, "/tables_size500/saturation_score_center_", chrom, "_", pos,
        "_condition_", cond, "_model.h5",
    ])
    for chrom, pos, cond in TOP_VARIANTS_TUPLE
]
# Rendered HTML plots in the "plots" directory.
PLOT_SATURATION_OUT = [
    "".join([
        PREDICTIONS_DIR, "/plots/plot_saturation_score_center_", chrom, "_", pos,
        "_condition_", cond, ".html",
    ])
    for chrom, pos, cond in TOP_VARIANTS_TUPLE
]


rule all:
    """Default target: request every score archive and every HTML plot."""
    input:
        SAT_SCORES_OUT,
        PLOT_SATURATION_OUT  # SAT_SCORES_SIZE500_OUT is currently disabled



rule compute_saturation_scores:
    """
    Compute saturation-mutagenesis scores around one variant for one condition.

    Wildcards: chrom, pos (variant position) and cond (condition label).
    The rule runs the Basenji saturation script on a 1-bp BED interval ending
    at `pos`, then post-processes the resulting scores.h5 for the sample whose
    condition (column 4 of `model_samples`) equals `cond`. The post-processing
    script receives `out_root` as output prefix and is expected to write the
    four text tables declared below; scores.h5 itself is kept as `model`.
    """
    input:
        json = "/g/furlong/project/103_Basenji/models/params_drosophila_l131k.json",
        model = "/g/furlong/project/103_Basenji/models/drosophila_l131k_augmented/model_best.h5",
        # Not referenced in the shell command; presumably declared so that a
        # change in the variant table retriggers the rule.
        regions = "/g/furlong/project/103_Basenji/Mattia/src/saturation_scores/top_predicted_variants_Basenji_AI.txt",
        model_samples = "/g/furlong/project/103_Basenji/models/drosophila_l131k_augmented/acc.txt"
    output:
        diff = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_NormalizedDifferenceScores.txt",
        importance_pos = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_ImportanceScoresPos.txt",
        importance_neg = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_ImportanceScoresNeg.txt",
        ref_seq = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_ReferenceSequence.txt",
        model = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_model.h5"
    params:
        fa = GENOME_FASTA,
        size = 150,
        stat = "sum",
        # Common prefix of the text tables, passed to the post-processing script.
        out_root = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}",
        cond = "{cond}",
        chrom = "{chrom}",
        pos = "{pos}"
    resources:
        cpu = 1,
        memPerCpu = 4000,
        time = 360
    conda: "basenji_env.yml"
    shell:
        """
        # Resolve the sample first so that an unknown condition fails before
        # the expensive saturation run instead of after it.
        SAMPLE_ID=$(awk '$4 == "{params.cond}"' {input.model_samples} | cut -f1)
        if [ -z "$SAMPLE_ID" ]; then
            echo "No sample with condition '{params.cond}' found in {input.model_samples}" >&2
            exit 1
        fi

        # 1-bp interval ending at the variant position, tab-delimited.
        BED_REGION=$(awk -v c={params.chrom} -v p={params.pos} 'BEGIN {{print c "\\t" p-1 "\\t" p}}')

        # The variable is quoted so the tabs survive word splitting and the
        # script receives a well-formed BED line.
        python {SATURATION_SCORES_PY} -f {params.fa} -l {params.size} \
        --stats {params.stat} -o {output.diff}.temp/ \
        {input.json} {input.model} <(printf '%s\\n' "$BED_REGION")

        python {COMPUTE_DIFF_SCORES_PY} -i {output.diff}.temp/scores.h5 \
        -s $SAMPLE_ID -o {params.out_root}

        cp {output.diff}.temp/scores.h5 {output.model}

        rm -fr {output.diff}.temp*
        """


rule compute_saturation_scores_500size:
    """
    Same computation as the 150-bp saturation rule, with size 500.

    Wildcards: chrom, pos (variant position) and cond (condition label).
    Results go to the "tables_size500" directory. The rule runs the Basenji
    saturation script on a 1-bp BED interval ending at `pos`, then
    post-processes the resulting scores.h5 for the sample whose condition
    (column 4 of `model_samples`) equals `cond`. The post-processing script
    receives `out_root` as output prefix and is expected to write the four
    text tables declared below; scores.h5 itself is kept as `model`.
    """
    input:
        json = "/g/furlong/project/103_Basenji/models/params_drosophila_l131k.json",
        model = "/g/furlong/project/103_Basenji/models/drosophila_l131k_augmented/model_best.h5",
        # Not referenced in the shell command; presumably declared so that a
        # change in the variant table retriggers the rule.
        regions = "/g/furlong/project/103_Basenji/Mattia/src/saturation_scores/top_predicted_variants_Basenji_AI.txt",
        model_samples = "/g/furlong/project/103_Basenji/models/drosophila_l131k_augmented/acc.txt"
    output:
        diff = PREDICTIONS_DIR + "/tables_size500/saturation_score_center_{chrom}_{pos}_condition_{cond}_NormalizedDifferenceScores.txt",
        importance_pos = PREDICTIONS_DIR + "/tables_size500/saturation_score_center_{chrom}_{pos}_condition_{cond}_ImportanceScoresPos.txt",
        importance_neg = PREDICTIONS_DIR + "/tables_size500/saturation_score_center_{chrom}_{pos}_condition_{cond}_ImportanceScoresNeg.txt",
        ref_seq = PREDICTIONS_DIR + "/tables_size500/saturation_score_center_{chrom}_{pos}_condition_{cond}_ReferenceSequence.txt",
        model = PREDICTIONS_DIR + "/tables_size500/saturation_score_center_{chrom}_{pos}_condition_{cond}_model.h5"
    params:
        fa = GENOME_FASTA,
        size = 500,
        stat = "sum",
        # Common prefix of the text tables, passed to the post-processing script.
        out_root = PREDICTIONS_DIR + "/tables_size500/saturation_score_center_{chrom}_{pos}_condition_{cond}",
        cond = "{cond}",
        chrom = "{chrom}",
        pos = "{pos}"
    resources:
        cpu = 1,
        memPerCpu = 4000,
        time = 1800
    conda: "basenji_env.yml"
    shell:
        """
        # Resolve the sample first so that an unknown condition fails before
        # the expensive saturation run instead of after it.
        SAMPLE_ID=$(awk '$4 == "{params.cond}"' {input.model_samples} | cut -f1)
        if [ -z "$SAMPLE_ID" ]; then
            echo "No sample with condition '{params.cond}' found in {input.model_samples}" >&2
            exit 1
        fi

        # 1-bp interval ending at the variant position, tab-delimited.
        BED_REGION=$(awk -v c={params.chrom} -v p={params.pos} 'BEGIN {{print c "\\t" p-1 "\\t" p}}')

        # The variable is quoted so the tabs survive word splitting and the
        # script receives a well-formed BED line.
        python {SATURATION_SCORES_PY} -f {params.fa} -l {params.size} \
        --stats {params.stat} -o {output.diff}.temp/ \
        {input.json} {input.model} <(printf '%s\\n' "$BED_REGION")

        python {COMPUTE_DIFF_SCORES_PY} -i {output.diff}.temp/scores.h5 \
        -s $SAMPLE_ID -o {params.out_root}

        cp {output.diff}.temp/scores.h5 {output.model}

        rm -fr {output.diff}.temp*
        """




# To run this rule, first activate the R environment (`conda activate R-4.2.2`) and run the workflow from within it.

rule render_saturation_scores_plot:
    """
    Render the HTML saturation-mutagenesis plot for one variant and condition.

    Takes the four text tables from the "tables" directory and calls the R
    rendering script, which receives an images folder (-i), a plot title (-t)
    and the HTML output path (-o).
    """
    input:
        diff = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_NormalizedDifferenceScores.txt",
        importance_pos = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_ImportanceScoresPos.txt",
        importance_neg = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_ImportanceScoresNeg.txt",
        ref_seq = PREDICTIONS_DIR + "/tables/saturation_score_center_{chrom}_{pos}_condition_{cond}_ReferenceSequence.txt"
    output:
        PREDICTIONS_DIR + "/plots/plot_saturation_score_center_{chrom}_{pos}_condition_{cond}.html"
    params:
        # Same path as the HTML output without the .html extension.
        images_folder = PREDICTIONS_DIR + "/plots/plot_saturation_score_center_{chrom}_{pos}_condition_{cond}",
        title = "{cond}:{chrom}_{pos}"
    resources:
        cpu = 1,
        memPerCpu = 1000,
        time = 15
    conda: "R-4.2.2.yml"
    shell:
        # Every data-derived value is shell-quoted (:q) so that a condition
        # label containing spaces or shell metacharacters still reaches the
        # script as a single argument.
        """
        Rscript {RENDER_SATURATION_MUTAGENESIS} \
        -d {input.diff:q} -p {input.importance_pos:q} \
        -n {input.importance_neg:q} -r {input.ref_seq:q} \
        -i {params.images_folder:q} -t {params.title:q} \
        -o {output:q}
        """
