import pandas as pd
import numpy as np
import os

# Workflow configuration: every setting below can be overridden in downsample.yaml.
configfile: 'downsample.yaml'

DS = config.get('DS', ['10'])                      # target depths (X coverage), as strings
GENOME_SIZE = config.get('GENOME_SIZE', 3.1e9)     # genome length used to convert depth -> bases
MANIFEST = config.get('MANIFEST', 'manifest.tab')  # tab-separated manifest; '#' lines ignored

# Manifest rows are keyed by (SAMPLE, SEQ_TYPE); each row carries a FOFN column.
manifest_df = pd.read_csv(MANIFEST, sep='\t', comment='#').set_index(['SAMPLE', 'SEQ_TYPE'])


def find_fofn(wildcards):
	"""Input function for rule `sample`: FOFN path the manifest lists for this (sample, seq_type)."""
	manifest_key = (wildcards.sample, wildcards.seq_type)
	return manifest_df.at[manifest_key, 'FOFN']

# Target rule: one compressed, indexed fastq per manifest row at every requested
# depth. The inner expand() zips sample with seq_type (paired, not a product);
# the outer expand() then crosses those pairs with each depth in DS.
rule all:
	input:
		expand(expand('{seq_type}/{sample}/{seq_type}_{{ds}}X.fastq.gz', zip, sample=manifest_df.index.get_level_values('SAMPLE'), seq_type=manifest_df.index.get_level_values('SEQ_TYPE')), ds=DS)


rule sample:
	input:
		fofn = find_fofn
	output:
		regions = 'tmp/{sample}/{seq_type}_{ds}.tab'
	threads: 1
	resources:
		mem = 8,
		hrs = 72
	run:
		df = pd.DataFrame()
		
		with open(input.fofn, 'r') as infile:
			for line in infile:
				fai_df = pd.read_csv(line.rstrip()+'.fai', sep='\t', header=None, usecols=[0,1], names=['read_name', 'len'])
				fai_df['source'] = line.rstrip()
				df = df.append(fai_df)

		df = df.reset_index(drop=True)
		
		exp_cov = int(wildcards.ds) * GENOME_SIZE

		cov_df = df.sample(frac=1).reset_index(drop=True)

		cov_sum = np.cumsum(cov_df['len'])

		out_df = cov_df.iloc[cov_sum.loc[cov_sum <= exp_cov].index]

		out_df.to_csv(output.regions, sep='\t', index=False)


# Stage each source fastq onto node-local scratch, pull out the sampled reads
# with `samtools fqidx`, and append them all into one uncompressed fastq.
rule extract_reads:
	input:
		regions = rules.sample.output.regions
	output:
		reg = temp('tmp/{sample}/{seq_type}_{ds}_reg.tab'),
		reads = temp('{seq_type}/{sample}/{seq_type}_{ds}X.fastq')
	threads: 1
	resources:
		mem = 8,
		hrs = 72
	run:
		df = pd.read_csv(input.regions, sep='\t')
		for file in df['source'].unique():
			file_base = os.path.basename(file)
			# Copy the fastq plus its sibling index files (the trailing glob) to scratch.
			shell(f'rsync -av {file}* {resources.tmpdir}')
			reg_df = df.loc[df['source'] == file]
			# One read name per line; output.reg is overwritten for each source file.
			reg_df[['read_name']].to_csv(output.reg, sep='\t', header=False, index=False)
			# seqtk seq -l0 unwraps sequences onto single lines; `>>` accumulates
			# reads from every source file into one scratch fastq.
			# NOTE(review): only seqtk is module-loaded here -- confirm samtools
			# is available in the base environment.
			shell(f'module load seqtk/1.3; samtools fqidx -r {output.reg} {resources.tmpdir}/{file_base} | seqtk seq -l0 >> {resources.tmpdir}/$( basename {output.reads} )')
			# Free scratch space before staging the next source file.
			shell(f'rm {resources.tmpdir}/{file_base}')
		# Move the assembled fastq from scratch to its final (temp) location.
		shell(f'rsync -av {resources.tmpdir}/$( basename {output.reads} ) {output.reads}')





# bgzip-compress the downsampled fastq and build the samtools index so the
# final fastq.gz supports random access by read name.
rule compress_and_index:
	input:
		reads = rules.extract_reads.output.reads
	output:
		reads = '{seq_type}/{sample}/{seq_type}_{ds}X.fastq.gz'
	threads: 1
	resources:
		mem = 8,
		hrs = 72
	shell:
		'''
		bgzip -c {input.reads} > {output.reads}
		samtools fqidx {output.reads}
		'''





