import pandas as pd
from os.path import join
import yaml


# Sample sheet: a tab-separated table whose "RUN" column lists the sample
# ids; every value of that column becomes one entry of `sras`.
df = pd.read_csv(config["sra_file"], sep="\t")
sras = df["RUN"].to_list()

# Mandatory settings -- a missing key raises KeyError right here.
paired = config["paired"]  # truthy: samples use <id>_1.fastq / <id>_2.fastq
out_path = config["main_path"]

# Directory layout: defaults live below `out_path`; the config keys
# "fastq_path" and "sal_out" override the respective default.
fastq_path = config.get("fastq_path", join(out_path, "fastq"))
fastqc_path = join(out_path, "fastqc")
sal_dir = config.get("sal_out", join(out_path, "sal_out"))

# Optional run parameters with their defaults. Further down they are used
# to build the names of the result directories.
tf = config.get("tf", 16)  # thinning factor
post_type = config.get("posterior_type", "gibbs")
npost = config.get("npost", 100)
gc_bias = config.get("gc_bias", True)

# Salmon results of this run are namespaced by the run parameters:
#   <sal_dir>/mode_gcbias=<gc_bias>/posttype=<post_type>_npost=<npost>_tf=<tf>
# "{sample}" is left in place as a Snakemake wildcard.
sal_dir = join(
    sal_dir,
    "mode_gcbias=" + str(gc_bias),
    "posttype={}_npost={}_tf={}".format(post_type, npost, tf),
)
sal_quant_dir = join(sal_dir, "{sample}")
sal_quant_file = join(sal_quant_dir, "quant.sf")

# The same layout below config["orig_path"]/sal_out: the already existing
# quantifications that this workflow copies from.
orig_sal_dir = join(
    config["orig_path"],
    "sal_out",
    "mode_gcbias=" + str(gc_bias),
    "posttype={}_npost={}_tf={}".format(post_type, npost, tf),
)
orig_quant_dir = join(orig_sal_dir, "{sample}")
orig_quant_file = join(orig_quant_dir, "quant.sf")

# MultiQC report location.
multiqc_path = join(out_path, "multiqc")
multiqc_file = join(multiqc_path, "multiqc_report.html")

# Terminus results: config["term_out"] if given, otherwise
# <out_path>/term_out, followed by the same parameter-based sub-directories
# that are used for the salmon results.
term_dir = join(
    config.get("term_out", join(out_path, "term_out")),
    "mode_gcbias=" + str(gc_bias),
    "posttype={}_npost={}_tf={}".format(post_type, npost, tf),
)

# "no_threshold0" variant: per-sample group files, plus a single cluster
# file that sits directly in the variant directory.
term_nothr0_dir = join(term_dir, "no_threshold0")
term_nothr0_samp = join(term_nothr0_dir, "{sample}")
term_nothr0_group = join(term_nothr0_samp, "group_nwk.txt")
term_nothr0_clust = join(term_nothr0_dir, "cluster_nwk.txt")

# "no_threshold" variant (only the directory name is defined here).
term_nothr_dir = join(term_dir, "no_threshold")

# "old" variant (Hirak terminus): group and cluster files are both
# per-sample.
term_old_dir = join(term_dir, "old")
term_old_samp = join(term_old_dir, "{sample}")
term_old_group = join(term_old_samp, "groups.txt")
term_old_clust = join(term_old_samp, "clusters.txt")


#### Only one sample
def fastq_inp_one(wildcards, fast_path):
    """Return the fastq file(s) of the sample named by ``wildcards.sample``.

    Args:
        wildcards: object exposing a ``sample`` attribute (the sample id).
        fast_path: directory that holds the fastq files.

    Returns:
        ``[<id>_1.fastq, <id>_2.fastq]`` inside ``fast_path`` when the
        module-level ``paired`` flag is truthy, otherwise ``[<id>.fastq]``.
    """
    suffixes = ("_1.fastq", "_2.fastq") if paired else (".fastq",)
    return [join(fast_path, wildcards.sample + suffix) for suffix in suffixes]
        
#### Multi sample
def fastq_inp_multi(wildcards, fast_path):
    """Return the fastq files of every sample listed in ``sras``.

    The previous paired-end branch could never produce a result: it called
    ``str.format`` with a positional argument on the named ``{sample}``
    field (``KeyError``) and returned the value of ``list.extend``, which
    is always ``None``.

    Args:
        wildcards: accepted for signature compatibility with
            ``fastq_inp_one``; not used, because every sample is returned.
        fast_path: directory that holds the fastq files.

    Returns:
        A list of paths inside ``fast_path``. When the module-level
        ``paired`` flag is truthy: all ``<id>_1.fastq`` files followed by
        all ``<id>_2.fastq`` files; otherwise one ``<id>.fastq`` per sample.
    """
    if paired:
        # First mates of all samples, then second mates of all samples.
        return [
            join(fast_path, "{}{}".format(sample, mate))
            for mate in ("_1.fastq", "_2.fastq")
            for sample in sras
        ]
    return [join(fast_path, "{}.fastq".format(sample)) for sample in sras]

# Aggregate target (no command of its own): requests the no_threshold0
# cluster file, the per-sample cluster files of the "old" terminus variant
# and the MultiQC report.
rule run_terminus:
    input:
        nothr0=term_nothr0_clust,
        thr_hirak=expand(term_old_clust, sample=sras),
        mqc=multiqc_file

# Runs `treeterminus consensus` once the no_threshold0 group file of every
# sample exists. The command is pointed at the salmon directory (-d) and at
# the no_threshold0 directory (-o); the declared output is the cluster file
# inside that directory.
# NOTE(review): the command runs after `cd` into the treeterminus checkout,
# so sal_dir / term_nothr0_dir presumably have to be absolute -- confirm.
rule run_term_nothr0_collapse:
    input:
        expand(term_nothr0_group, sample=sras)
    output:
        term_nothr0_clust
    params:
        treeterm_path=config["treeterm_path"],
        input=sal_dir,
        output=term_nothr0_dir
    shell:
        """
            cd {params.treeterm_path}
            target/release/treeterminus consensus -d {params.input} -o {params.output}
            cd -
        """

# Aggregate target: the no_threshold0 group file of every sample.
rule run_term_nothr0_group:
    input:
        expand(term_nothr0_group, sample=sras)

# Per-sample step: runs `treeterminus group` on one sample's salmon
# directory with `--thr false --mean_inf false`. The tool is given the
# shared no_threshold0 directory as -o, while the declared output is
# <no_threshold0>/<sample>/group_nwk.txt.
# NOTE(review): presumably the tool creates the per-sample sub-directory
# itself -- confirm.
rule _run_term_nothr0_group:
    input:
        sal_quant_dir
    output:
        term_nothr0_group
    params:
        treeterm_path=config["treeterm_path"],
        output=term_nothr0_dir
    shell:
        """
            cd {params.treeterm_path}
            target/release/treeterminus group -d {input} -o {params.output}  --thr false --mean_inf false
            cd -
        """

# Runs `terminus collapse -c 0.5` once over the salmon directories of all
# samples (they are passed together after -d), after every sample's
# groups.txt exists. Declares one clusters.txt per sample as output.
rule run_term_collapse:
    input:
        inp_term=expand(term_old_group, sample=sras),
        inp_sal=expand(sal_quant_dir, sample=sras)
    output:
        expand(term_old_clust, sample=sras)
    params:
        term_path=config["term_path"],
        output=term_old_dir
    shell:
        """
            cd {params.term_path}
            target/release/terminus collapse -c 0.5 -d {input.inp_sal} -o {params.output}
            cd -
        """

# Aggregate target: the "old" terminus groups.txt of every sample.
rule run_term_group:
    input:
        expand(term_old_group, sample=sras)

# Per-sample step: runs `terminus group -m 0.1 --tolerance 0.001` on one
# sample's salmon directory. The binary is addressed by its full path below
# config["term_path"], so no `cd` is needed here.
rule _run_term_group:
    input:
        sal_quant_dir
    output:
        term_old_group
    params:
        term_path=config["term_path"] + "/target/release/terminus",
        output=term_old_dir
    shell:
        """
          {params.term_path} group -m 0.1 --tolerance 0.001 -d {input} -o {params.output}
        """


# Builds the MultiQC report after the quant.sf of every sample exists.
# The command loads the `multiqc` environment module and scans "." (the
# working directory of the job); the inputs are used for ordering only and
# are not passed to the command.
rule multiqc:
    input:
        sal_quants=expand(sal_quant_file, sample=sras)
        # fastqc=expand(fastqc_file, sample=sras)
    output:
        multiqc_file
    params:
        multiqc_out_dir=multiqc_path
    shell:
        """
            module load multiqc
            multiqc . -o {params.multiqc_out_dir}
        """

# Aggregate target: the copied quant.sf of every sample.
rule copy_salmon:
    input:
        expand(sal_quant_file, sample=sras)

# Per-sample step: copies the sample's quantification directory from the
# original run (below config["orig_path"]) into this run's salmon directory,
# which provides <sal_dir>/<sample>/quant.sf.
# The command is given through the `shell:` directive instead of a `run:`
# block that only wrapped a single shell() call; the executed command is
# unchanged.
# NOTE(review): cpus=10 / mem=32000 look generous for a plain `cp` --
# confirm whether these resource values are intended.
rule _copy_salmon:
    input:
        # The fastq files are not read by the copy; listing them only makes
        # the rule depend on their presence.
        inp_fastq=lambda wc: fastq_inp_one(wc, fastq_path),
        orig_quant_file=orig_quant_file
    output:
        out_sal=sal_quant_file
    resources:
        cpus=10,
        mem=32000
    params:
        out_dir=sal_dir,
        orig_quant_dir=orig_quant_dir
    shell:
        "cp -rf {params.orig_quant_dir} {params.out_dir}"