'''
Identifies small RNAs derived from sense-antisense duplexes with 2-nt, 3' overhangs
using alignment_table from tinyRNA as input.

Author: Tai Montgomery, Nov. 22, 2019, updated for tinyRNA output on 6/27/23

Command line usage:
python3 duplex_finder.py -i 'alignment_table' -o 'output_file'
'''

import argparse

def main():
	'''
	Script entry point: reads the input and output paths from the
	command line and runs the duplex search on them
	'''
	paths = arg_parse()
	duplex_finder(*paths)
	

def arg_parse():
	'''
	Reads the command line and returns (infile, outfile)

	Both -i/--infile and -o/--outfile are required; argparse exits with a
	usage message when either is missing.
	'''
	cli = argparse.ArgumentParser()
	# (short flag, long flag, help text) for each required path option
	options = (
		('-i', '--infile', "input file in csv format"),
		('-o', '--outfile', "output file"),
	)
	for short_flag, long_flag, help_text in options:
		cli.add_argument(short_flag, long_flag, required=True, help=help_text)

	parsed = cli.parse_args()
	return parsed.infile, parsed.outfile

def csv_parser(infile):
	'''
	Parses features in csv to dictionary

	Every non-blank line is split on commas and stored under the key
	"{chrom}_{start}_{end}_{strand}", taken from the 5th, 7th, 8th and 6th
	columns. The stored value is a list of 11 strings: the first 10 columns
	unchanged, followed by all remaining columns re-joined with commas, so
	the trailing feature columns keep the exact text they have in the file.

	The header line is not special-cased; it is stored like any other row
	under a key built from its column names. When two lines yield the same
	key, the later line replaces the earlier one.

	Parameters:
		infile: path of the comma-separated alignment table

	Returns:
		dict mapping each locus key to its list of 11 column strings

	Raises:
		ValueError: if a non-blank line has fewer than 10 columns
	'''
	print(f'parsing csv: {infile}....', end = '', flush=True)
	data = {}

	with open(infile) as file:
		for line_number, line in enumerate(file, start=1):
			line = line.rstrip()
			if not line:
				# Blank lines (e.g. a trailing empty line) carry no alignment
				continue

			fields = line.split(',')
			if len(fields) < 10:
				raise ValueError(
					f"{infile}, line {line_number}: expected at least 10 "
					f"comma-separated columns, found {len(fields)}"
				)

			chrom, strand, start, end = fields[4:8]
			seq_id = f"{chrom}_{start}_{end}_{strand}"
			# Everything past the tenth column is feature annotation; keep
			# its commas so the row can be written back out intact
			hits = ','.join(fields[10:])
			data[seq_id] = fields[:10] + [hits]

	print('done.', flush=True)
	return data
	
		
def duplex_finder(infile, outfile, *, overhang=2):
	'''
	Identifies dicer duplexes and writes data to output

	A read is reported when the table also contains a read at the locus its
	duplex partner would occupy: same chromosome, opposite strand, and both
	coordinates shifted by `overhang` nt (downward for the partner of a '+'
	read, upward for the partner of a '-' read). Only partners of identical
	length are therefore detected.

	Header lines (first column contains 'Sequence') are copied to the
	output first. Matching rows follow in the order csv_parser returned
	them, so repeated runs on the same input give identical output.
	Rows whose strand is neither '+' nor '-' are ignored when computing
	partner loci; without a strand flip such a row would match itself.

	Parameters:
		infile: path of the comma-separated alignment table
		outfile: path of the file to write; overwritten if it exists
		overhang: length in nt of the 3' overhang on each strand (default 2)

	Returns:
		1, always

	Raises:
		ValueError: if a stranded row has fewer than 8 columns or a
			non-integer Start/End
		Any error raised by csv_parser or by opening either file propagates.
	'''
	d1 = csv_parser(infile)
	print(f"Finding dicer products {infile}....", end = '', flush=True)

	with open(infile) as input_file, open(outfile, 'w') as output_file:
		# Loci at which a duplex partner is expected, one per stranded read
		partner_ids = set()
		for line in input_file:
			if not line.strip():
				continue

			fields = line.rstrip().split(',')
			if 'Sequence' in fields[0]:
				# Header: copy through, making sure it ends its own line
				output_file.write(line.rstrip('\n') + '\n')
				continue

			chrom, strand, start, end = fields[4:8]
			if strand == '+':
				shift, partner_strand = -overhang, '-'
			elif strand == '-':
				shift, partner_strand = overhang, '+'
			else:
				# No opposite strand to search; skipping avoids a self-match
				continue

			partner_ids.add(
				f"{chrom}_{int(start) + shift}_{int(end) + shift}_{partner_strand}"
			)

		# Walk the parsed table in its own order rather than iterating a set
		# intersection, whose order varies from run to run
		for key, record in d1.items():
			if key in partner_ids:
				output_file.write(','.join(map(str, record[:11])) + '\n')

	print("done.",flush=True)
	return 1

# Run the command-line workflow only when executed as a script, not on import
if __name__ == "__main__":
	main()