import sys
import os

from src import speciestree as spt

#################################################################################################
#Runs in the conda env paralogy_map + path to SCORPiOs scripts should be appended to python path #
#Dependencies (see env file) = snakemake, ete3, matplotlib and seaborn                           #
#################################################################################################

#TODO relocate data

# Species descending from the duplication node `config["ancestor"]` in the species tree.
DUPLICATED_SPECIES = spt.get_species(config["species_tree"], config["ancestor"])

# Ancestor used to prune the species tree in the statistics figure (default: Neopterygii).
config.setdefault("prune_ancestor", "Neopterygii")

# Reference species used to identify and name paralogous segments.
REF_SPECIES = ["Oryzias.latipes", "Gasterosteus.aculeatus", "Tetraodon.nigroviridis", "Danio.rerio"]

# Explicit check instead of `assert`: asserts are removed when Python runs with -O,
# which would turn a missing argument into a confusing KeyError further down.
if "pre_dup_segments" not in config:
    raise ValueError('the pre_dup_segments arg is required')

# Path pattern of the pre-duplication segments; it is expected to contain the
# {ref_species} wildcard (convert_intervals reads `wildcards.ref_species`).
SEGMENTS = config["pre_dup_segments"]

OUT = f'out_{config["jobname"]}'

# Intermediate segment files: after assembly conversion, then after chromosome renaming.
SEGMENTS_CONV = SEGMENTS + '_conv'
SEGMENTS_FINAL = SEGMENTS + '_conv_final'

# Optional gene ID history file, only used when intervals must be converted.
GENE_ID_CONV = config.get('gene_id_conv')

# Gene file path pattern (contains the {ref_species} wildcard, see draw_paralogy_map).
GENES = config["genes"]

# Optional file mapping ancestral chromosome names to the published nomenclature.
RENAME_CHROMS = config.get("rename_chr")

# Start from pre-computed ancgenes if provided (otherwise they are extracted from the trees).
if "ancgenes" in config:
    ANCGENES = config["ancgenes"]  # use pre-computed ancgenes
else:
    ANCGENES = f'{OUT}/TGD_ancGenes.tsv'  # file generated by the pipeline

# Extra flag forwarded to color_reference_species.py when random_start is enabled.
RANDOM_START = config.get("random_start", False)
random_arg = "--random_start" if RANDOM_START else ''

rule all:
    """
    Final targets of the workflow: the annotation-coverage summary figures.

    Through rule `stats`, these also require one .svg paralogy map per duplicated
    species (genome colored by post-duplication chromosomes), so every map is drawn too.
    """
    input:
        f'{OUT}/figures/sptree_stats.pdf',
        f'{OUT}/figures/box_stats.svg',

# Only define this rule when no pre-computed ancgenes file is supplied. If it were always
# defined, config["forest"] would be required needlessly, and Snakemake could regenerate
# (i.e. overwrite) the user-provided ancgenes file whenever it looked older than its inputs.
if "ancgenes" not in config:
    rule extract_duplicated_ancGenes:
        """
        Extracts all post-duplication ancgenes in the input gene trees.
        """
        input: trees = config["forest"], sptree = config["species_tree"]
        output: ANCGENES
        shell:
            "python src/get_post_dup_ancgenes.py -t {input.trees} -d {config[ancestor]} "
            "-s {input.sptree} -o {output} --check_root"


rule convert_intervals:
    """
    Maps intervals of previous assembly used in input data to current assembly.
    Copies input file if no conversion is necessary.

    Conversion is applied only to reference species listed in config["genes_conv"]
    (species -> gene file of the previous assembly). It also needs config["gene_id_conv"]
    (gene ID history file) and config["gconv_format"].
    """
    input: s = SEGMENTS, g = GENES
    output: SEGMENTS_CONV
    params: id_history = GENE_ID_CONV
    run:
        # `genes_conv` is optional: without it, every species is simply copied.
        genes_conv = config.get("genes_conv", {})

        if wildcards.ref_species in genes_conv:

            if params.id_history is None:
                raise ValueError(
                    f"'gene_id_conv' must be set in config to convert intervals of "
                    f"{wildcards.ref_species}"
                )

            # Read by shell() below through the local scope.
            old_genes = genes_conv[wildcards.ref_species]

            # Per-job formatted copy of the ID history: several species may be converted
            # in parallel, and a single shared file would be truncated by one job while
            # another is still reading it.
            id_formatted = f"{output[0]}.id_history.formatted"

            try:
                # Strip the header text and empty lines from the ID history file.
                shell(
                    "sed 's/Old stable ID, New stable ID, Release, Mapping score//g' "
                    "{params.id_history} | grep -v '^[[:space:]]*$' > {id_formatted}"
                )

                # shell() (unlike os.system) raises if the command fails.
                shell(
                    "python src/convert_intervals.py -g {old_genes} {input.g} "
                    "-seg {input.s} -id {id_formatted} "
                    "-f {config[gconv_format]} {config[format]} -o {output[0]}"
                )
            finally:
                if os.path.exists(id_formatted):
                    os.remove(id_formatted)

        else:
            shell("cp {input.s} {output[0]}")

rule rename_anc_chr:
    """
    Renames ancestral chromosome in input data to be consistent with published Figures.
    Copies input file if no renaming is necessary.

    The renaming file (config["rename_chr"]) holds whitespace-separated "old new" pairs,
    one per line. The segments file is tab-separated with four columns:
    chromosome, start, stop, ancestral chromosome.
    """
    input: segments = SEGMENTS_CONV
    output: segments = SEGMENTS_FINAL
    params: rename = RENAME_CHROMS
    run:
        if params.rename is None:
            shell("cp {input.segments} {output.segments}")

        else:
            renaming = {}
            with open(params.rename, 'r') as infile:
                for line in infile:
                    fields = line.split()
                    if not fields:  # tolerate blank lines (e.g. a trailing empty line)
                        continue
                    renaming[fields[0]] = fields[1]

            with open(input.segments, 'r') as infile, open(output.segments, 'w') as out:
                for line in infile:
                    line = line.strip()
                    if not line:
                        continue
                    ch, start, stop, anc_chr = line.split('\t')
                    if anc_chr not in renaming:
                        raise ValueError(
                            f"ancestral chromosome {anc_chr!r} from {input.segments} "
                            f"has no entry in {params.rename}"
                        )
                    out.write('\t'.join([ch, start, stop, renaming[anc_chr]]) + '\n')


rule color_each_reference:
    """
    For a single reference species (the {ref_species} wildcard), identifies its
    paralogous duplicated segments from the pre-TGD segments and the paralogous
    genes recorded in the ancestral genes file (derived from the input gene trees).
    """
    input:
        segments = SEGMENTS_FINAL,
        genes = GENES,
        ancGenes = ANCGENES
    output:
        f'{OUT}/temp/{{ref_species}}_colors.txt'
    shell:
        "python src/color_reference_species.py "
        "-seg {input.segments} -g {input.genes} -ag {input.ancGenes} "
        "-o {output} -f {config[format]} {random_arg}"


def get_ref_maps():
    """Return the homogenized color-map paths, one per species in REF_SPECIES (same order)."""
    homogenized_pattern = OUT + '/temp/{ref_species}_colors_homogenized.txt'
    return expand(homogenized_pattern, ref_species=REF_SPECIES)


rule homogenize_references_ab:
    """
    Renames the paralogous segments of every reference species so that names agree
    across species (orthologous segments end up with the same name), using a guide
    species (config "guide", default Gasterosteus.aculeatus).

    NOTE(review): output paths are not passed to the script; it presumably derives
    <ref>_colors_homogenized.txt from its inputs — confirm.
    """
    input:
        ref_colors = expand(f'{OUT}/temp/{{ref_species}}_colors.txt', ref_species=REF_SPECIES),
        genes = expand(GENES, ref_species=REF_SPECIES),
        ancGenes = ANCGENES
    output:
        get_ref_maps()
    params:
        guide = config.get('guide', "Gasterosteus.aculeatus")
    shell:
        "python src/homogenize_refs_colors.py "
        "-i {input.ref_colors} -ag {input.ancGenes} -guide_sp {params.guide} "
        "-g {input.genes} -f {config[format]}"


rule consensus_color_ancGene:
    """
    Gives each ancestral gene a post-duplication ancestral chromosome, decided by a
    majority vote over its descendant genes in the 4 reference species.
    """
    input:
        ref_colors = get_ref_maps(),
        genes = expand(GENES, ref_species=REF_SPECIES),
        ancGenes = ANCGENES
    output:
        f'{OUT}/comparative_atlas.tsv'
    shell:
        "python src/color_ancgenes.py "
        "-ref {input.ref_colors} -g {input.genes} -ag {input.ancGenes} "
        "-o {output} -f {config[format]}"


rule draw_paralogy_map:
    """
    Draws one duplicated species' genome colored by ancestral post-duplication
    chromosome, using the ancestral-gene annotation (comparative atlas) built above,
    and writes that species' coverage statistics.

    config "draw" (default '') is appended verbatim to the command; per the workflow
    design it can disable drawing so that only the statistics are produced.
    """
    input:
        colors = f'{OUT}/comparative_atlas.tsv',
        genes = GENES.replace('{ref_species}', '{dup_species}')
    output:
        plot = report(f'{OUT}/figures/{{dup_species}}_ParalogyMap.svg', category="Comparative atlas"),
        stats = temp(f'{OUT}/{{dup_species}}_out_stats.txt')
    params:
        draw = config.get('draw', '')
    shell:
        "python src/plot_paralogy_map.py "
        "-c {input.colors} -g {input.genes} -o {output.plot} -s {wildcards.dup_species} "
        "-f {config[format]} -os {output.stats} {params.draw}"


rule stats:
    """
    Concatenates the per-species coverage statistics into a single file.
    The per-species paralogy maps are also listed as inputs so that all of them are drawn.
    """
    input:
        st = expand(OUT + '/{dup_species}_out_stats.txt', dup_species=DUPLICATED_SPECIES),
        fig = expand(OUT + '/figures/{dup_species}_ParalogyMap.svg', dup_species=DUPLICATED_SPECIES)
    output:
        f'{OUT}/out_stats.txt'
    shell:
        "cat {input.st} > {output}"


# Two variants of the same plotting step; only one is defined for a given config.
if "comparisons" in config:
    rule plot_annotation_statistics1:
        """
        Plots annotation coverage (boxplots and species tree), also drawing the
        additional comparisons and labels given in config.
        """
        input:
            stats = f'{OUT}/out_stats.txt',
        output:
            boxplots = report(f'{OUT}/figures/box_stats.svg', category="Annotation statistics"),
            sptree = report(f'{OUT}/figures/sptree_stats.pdf', category="Annotation statistics")
        shell:
            "python src/draw_species_tree_stats.py "
            "-i {input.stats} {config[comparisons]} -l {config[labels]} "
            "-s {config[species_tree]} -ob {output.boxplots} -os {output.sptree} "
            "-a {config[prune_ancestor]} -da {config[ancestor]}"

else:
    rule plot_annotation_statistics2:
        """
        Plots annotation coverage (boxplots and species tree) without extra comparisons.
        """
        input:
            stats = f'{OUT}/out_stats.txt',
        output:
            boxplots = report(f'{OUT}/figures/box_stats.svg', category="Annotation statistics"),
            sptree = report(f'{OUT}/figures/sptree_stats.pdf', category="Annotation statistics")
        shell:
            "python src/draw_species_tree_stats.py "
            "-i {input.stats} "
            "-s {config[species_tree]} -ob {output.boxplots} -os {output.sptree} "
            "-a {config[prune_ancestor]} -da {config[ancestor]}"
