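# Snakemake pipeline for simulating and quantifying a brain gene expression
# experiment with TreeTerminus

# Inputs: the positional arguments that rule sim_expression passes to
# scripts/simulate_brain_expression.R, written as the R assignments they map to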

# fc.min <- as.numeric(args[1])
# fc.max <- as.numeric(args[2])
# frag.len <- as.numeric(args[3])
# lib.size <- as.numeric(args[4])
# seed <- as.numeric(args[5])
# set.seed(seed)

# fasta.fp <- args[6]
# txdb.fp <- args[7]      # sqlite file
# cortexTpm.fp <- args[8]
# meanDisp.fp <- args[9] # dispersion
# fitPar.fp <- args[10]   # gc bias params

# output.txps.fp <- args[11]
# output.expr.fp <- args[12]

from os.path import join

PREFIX = config.get('output_prefix', '')
SEED = config['seed']

REPLICATES = config['replicates']

FC_MIN = config['fc_min']
FC_MAX = config['fc_max']
FRAG_LEN = config['frag_len']
LIB_SIZE = config['lib_size']
TXP_FASTA_FP = config['txp_fasta']
TXDB_FP = config['txdb']
CORTEX_TPM_FP = config['cortex_tpms']
MEANDISP_FP = config['meandisp']
GC_PARAMS = config['gc_params']
USE_GC = config['gc_bias']

output_dir = join(PREFIX, 'output', f'seed={SEED}_fc={FC_MIN}:{FC_MAX}')

# 1. Expr simulation outputs
expr_rda_fp = join(output_dir, 'sim_expr.rda')
out_txps_fa_fp = join(output_dir, 'transcripts.fa')

# 2. Read simulation outputs
reads_dir = join(output_dir, 'reads_gc={}'.format(str(USE_GC)))
read_rep_dir = join(reads_dir, '{rep}')
read_fasta_file=join(reads_dir, '{rep}', 'sample_0{cond}_{read}.fasta')
read_fasta_file_rgen=join(reads_dir, '{{rep}}', 'sample_0{cond}_{read}.fasta')
read_shuff_fasta_file_rgen =join(reads_dir, '{{rep}}', 'sample_0{cond}_{read}_shuffled.fa.gz')
read_fasta_file_sal        =join(reads_dir, '{{rep}}', 'sample_0{{cond}}_{read}_shuffled.fa.gz')

# 3. Salmon outputs

nrep=config['nreplicates'] ##number of posterior samples
post_type=config['posterior_type']

SAL_IDX_PATH = config['sal_idx_path']
SAL_PATH = config['sal_path']

tf = -1
if(post_type=='gibbs'):
    tf=config['thinning_factor']
    post_dir=join(output_dir, 'post_type={}_nrep={}_tf={}').format(post_type, nrep, tf) ##thinning factor additional for gibbs
else:
    post_dir=join(output_dir, 'post_type={}_nrep={}').format(post_type, nrep)

sal_dir = join(post_dir, 'salmon_quants')
sal_quant_dir = join(sal_dir, '{rep}_{cond}')
sal_quant_file = join(sal_quant_dir, 'quant.sf')

read_fasta_file_pe1 = read_fasta_file_sal.format(read=1)
read_fasta_file_pe2 = read_fasta_file_sal.format(read=2)

##
# Terminus and Tree terminus
##

term_dir = join(post_dir, 'terminus')
treeterm_nothr_dir = join(term_dir, 'no_threshold')
treeterm_nothr0_dir = join(term_dir, 'no_threshold0')
treeterm_thr_dir = join(term_dir, 'threshold')
treeterm_mean_inf0_dir = join(term_dir, 'no_threshold_meaninf_0')
term_orig_dir = join(term_dir, 'old') ### terminus

treeterm_nothr_sample = join(treeterm_nothr_dir, '{rep}_{cond}')
treeterm_nothr0_sample = join(treeterm_nothr0_dir, '{rep}_{cond}')
treeterm_thr_sample = join(treeterm_thr_dir, '{rep}_{cond}')
term_sample = join(term_orig_dir, '{rep}_{cond}')

treeterm_thr_group = join(treeterm_thr_sample, 'groups.txt')
treeterm_nothr_group = join(treeterm_nothr_sample, 'groups.txt')
treeterm_nothr0_group = join(treeterm_nothr0_sample, 'groups.txt')
treeterm_mean_inf0_group = join(treeterm_mean_inf0_dir, 'cluster_nwk.txt')
term_group =  join(term_sample, 'groups.txt')

treeterm_nothr0_clust = join(treeterm_nothr0_dir, 'cluster_nwk.txt')
treeterm_thr_clust = join(treeterm_thr_dir, 'cluster_nwk.txt')
treeterm_nothr_clust = join(treeterm_nothr_dir, 'cluster_nwk.txt')
term_clust = join(term_sample, 'clusters.txt')

##
# Treeterm outputs
##

rule final_outputs:
    input:
        treeterm_thr_clust,
        treeterm_nothr_clust,
        treeterm_nothr0_clust,
        treeterm_mean_inf0_group,
        expand(term_clust, rep = REPLICATES, cond = [1,2])


### 
# 1) Expression simulation
###

rule sim_expression:
    input:
        fasta = TXP_FASTA_FP,
        txdb = TXDB_FP,
        tpms = CORTEX_TPM_FP,
        disp = MEANDISP_FP,
        gc = GC_PARAMS,
    output:
        txps = out_txps_fa_fp,
        expr = expr_rda_fp,
    params:
        fc_min=FC_MIN,
        fc_max=FC_MAX,
        frag_len=FRAG_LEN,
        lib_size=LIB_SIZE,
    shell:
        '''
            Rscript --vanilla scripts/simulate_brain_expression.R \
                {params.fc_min} \
                {params.fc_max} \
                {params.frag_len} \
                {params.lib_size} \
                {SEED} \
                {input.fasta} \
                {input.txdb} \
                {input.tpms} \
                {input.disp} \
                {input.gc} \
                {output.txps} \
                {output.expr}
        '''

### 
# 2) Read simulation
###

rule all_reads:
    input:
        expand(expand(read_shuff_fasta_file_rgen, cond=[1,2], read=[1,2]), rep=REPLICATES)

rule shuf_reads:
    input:
        shuffle_script= 'scripts/shuffle.sh',
        wrapper= 'scripts/shuffle_wrapper.sh',
        reads=expand(read_fasta_file_rgen,  cond=[1,2], read=[1,2]),
    params:
        read_path = read_rep_dir,
    output:
        expand(read_shuff_fasta_file_rgen, cond=[1,2], read=[1,2])
    shell:
        '''
            bash {input.wrapper} {input.shuffle_script} {params.read_path}
        '''

rule all_reads_:
    input:
        expand(expand(read_fasta_file_rgen, cond=[1,2], read=[1,2]), rep = REPLICATES)

rule make_reads:
    input:
        expr=expr_rda_fp,
        fasta=out_txps_fa_fp,
    output:
        expand(read_fasta_file_rgen,  cond=[1,2], read=[1,2])
    params:
        out_dir = reads_dir,
    shell:
        '''
            Rscript --vanilla scripts/simulate_reads.R {wildcards.rep} {input.fasta} {params.out_dir} {input.expr}
        '''

###
# 3. Salmon Quant
##

rule run_salmon:
    input:
        expand(sal_quant_dir, rep=REPLICATES, cond=[1,2])

rule _run_salmon:
    input:
        inp_fastq1 = read_fasta_file_pe1,
        inp_fastq2 = read_fasta_file_pe2,
        index = join(SAL_IDX_PATH, 'pos.bin')
    output:
        out_dir=directory(sal_quant_dir)
    resources: cpus=8, mem=32000
    params:
        #partition = get_partition('run_salmon'),
        sal_path = SAL_PATH,
        sal_ind = SAL_IDX_PATH,
        nrep = nrep,
        tf = tf
    run:
        if post_type == 'gibbs':
            if USE_GC:
                shell('{params.sal_path} quant -i {params.sal_ind} -l A -p {resources.cpus} --gcBias \
                --numGibbsSamples {params.nrep} --thinningFactor {tf} -d \
                -o {output.out_dir} -1 {input.inp_fastq1} -2 {input.inp_fastq2}')
            else:
                shell('{params.sal_path} quant -i {params.sal_ind} -l A -p {resources.cpus} \
                --numGibbsSamples {params.nrep} --thinningFactor {tf} -d \
                -o {output.out_dir} -1 {input.inp_fastq1} -2 {input.inp_fastq2}')
            
        else:
            if USE_GC:
                shell('{params.sal_path} quant -i {params.sal_ind} -l A -p {resources.cpus} --gcBias \
                --numBootstraps {params.nrep} -d \
                -o {output.out_dir} -1 {input.inp_fastq1} -2 {input.inp_fastq2}')
            else:
                shell('{params.sal_path} quant -i {params.sal_ind} -l A -p {resources.cpus} \
                --numBootstraps {params.nrep} -d \
                -o {output.out_dir} -1 {input.inp_fastq1} -2 {input.inp_fastq2}')

# rule build_sal_ind:
#     input:ancient(txp_ref)
#     output:join(sal_ind_path, 'pos.bin')
#     params:
#         out_ind_path = sal_ind_path,
#         #partition = get_partition('build_sal_ind'),
#         sal_path = config['sal_path']
#     resources:cpus=10, mem=32000
#     run:
#         if keep_dup:
#             shell('{params.sal_path} index --keepDuplicates -p {resources.cpus} -t {input} -i {params.out_ind_path}')
#         else:
#             shell('{params.sal_path} index -p {resources.cpus} -t {input} -i {params.out_ind_path}')

##
# 4. TreeTerminus and terminus
##

rule run_term_collapse:
    input:
        inp_term = expand(term_group, rep = REPLICATES, cond = [1,2]),
        inp_sal = expand(sal_quant_dir, rep = REPLICATES, cond = [1,2])
    output:
        expand(term_clust,  rep = REPLICATES, cond = [1,2])
    params:
        term_path = config['term_path'],
        output = term_orig_dir,
        #partition = get_partition('run_term_collapse')
    shell:
        '''
            {params.term_path} collapse -c 0.5 -d {input.inp_sal} -o {params.output}
        '''
rule run_term_group:
    input:expand(term_group, rep = REPLICATES, cond = [1,2])

rule _run_term_group:
    input:sal_quant_dir
    output:term_group
    params:
        term_path = config['term_path'], 
        output = term_orig_dir,
        #partition = get_partition('_run_term_group')
    shell:
        '''
          {params.term_path} group -m 0.1 --tolerance 0.001 -d {input} -o {params.output}
        '''

rule run_treeterm_mean_inf0:
    input:
        expand(sal_quant_dir, rep = REPLICATES, cond = [1,2])
    output:
        treeterm_mean_inf0_group
    params:
        input = sal_dir,
        treeterm_path = config['treeterm_path'], 
        output = treeterm_mean_inf0_dir,
    shell:
        '''
            {params.treeterm_path} group -m 0.1 --tolerance 0.001 \
            -d {params.input} -o {params.output}
        '''

rule run_treeterm_cons_nothr0:
    input:
        group_files = expand(treeterm_nothr0_group, rep = REPLICATES, cond = [1,2]),
        quant_files = expand(sal_quant_dir, rep = REPLICATES, cond = [1,2])
    output:treeterm_nothr0_clust
    params:
        treeterm_path = config['treeterm_path'],
        input = sal_dir,
        output = treeterm_nothr0_dir,
    shell:
        '''
            {params.treeterm_path} consensus -d {params.input} -o {params.output}
        '''        
rule run_treeterm_group_nothr0:
    input:
        expand(treeterm_nothr0_group, rep = REPLICATES, cond = [1,2])
rule _run_term_group_nothr0:
    input:
        sal_quant_dir
    output:
        treeterm_nothr0_group
    params:
        treeterm_path = config['treeterm_path'],
        output = treeterm_nothr0_dir,

    shell:
        '''
            {params.treeterm_path} group -d {input} -o {params.output} --mean_inf false
        '''

rule run_treeterm_cons_nothr:
    input:
        group_files = expand(treeterm_nothr_group, rep = REPLICATES, cond = [1,2]),
        quant_files = expand(sal_quant_dir, rep = REPLICATES, cond = [1,2])
    output:treeterm_nothr_clust
    params:
        treeterm_path = config['treeterm_path'],
        input = sal_dir,
        output = treeterm_nothr_dir,
    shell:
        '''
            {params.treeterm_path} consensus -d {params.input} -o {params.output}
        '''

rule run_treeterm_group_nothr:
    input:
        expand(treeterm_nothr_group, rep = REPLICATES, cond = [1,2])
rule _run_treeterm_group_nothr:
    input:
        sal_quant_dir
    output:
        treeterm_nothr_group
    params:
        treeterm_path = config['treeterm_path'],
        output = treeterm_nothr_dir,
    shell:
        '''
            {params.treeterm_path} group \
                -d {input} -o {params.output} \
                --mean_inf false --thr false --inf_perc 0.25
        '''

rule run_treeterm_cons_thr:
    input:
        group_files = expand(treeterm_thr_group, rep = REPLICATES, cond = [1,2]),
        quant_files = expand(sal_quant_dir, rep = REPLICATES, cond = [1,2])
    output:treeterm_thr_clust
    params:
        treeterm_path = config['treeterm_path'],
        input = sal_dir,
        output = treeterm_thr_dir,
    shell:
        '''
            {params.treeterm_path} consensus \
                -d {params.input} -o {params.output}
        '''

rule run_treeterm_group_thr:
    input:
        expand(treeterm_thr_group, rep = REPLICATES, cond = [1,2])
rule _run_treeterm_group_thr:
    input:
        sal_quant_dir
    output:
        treeterm_thr_group
    params:
        treeterm_path = config['treeterm_path'],
        output = treeterm_thr_dir,

    shell:
        '''
            {params.treeterm_path} group \
                -d {input} -o {params.output} \
                --mean_inf false --thr true --inf_perc 0.25
        '''
