'''
Finds the regions of each SHARPR sequence whose scores go above the mean + 3*stdev of the Basal control.
Requires the Basal control mean+stdev file to calculate the cutoff at each position. Smoothing the values beforehand is recommended (e.g. averaging each position with the 5 bases on either side).
Outputs the element name, start, and stop (exclusive), as 0-based positions, in tab-delimited format.
Usage: python3 find_SHARPR_peaks_above_confidence_interval.py <input SHARPR seq file> <Basal mean+stdev file> <output file>
'''

import sys

# Number of standard deviations above the Basal mean used as the cutoff.
# Can be changed to whatever threshold you want.
NUM_STDEV = 3

# Value stored for 'NA' scores. NaN compares False against every cutoff, so an
# NA position never belongs to a region but still occupies its slot, keeping
# later scores aligned with their per-position cutoffs and coordinates.
_NA = float('nan')

# Conventional prefix on SHARPR header names ('>element_<name>').
_ELEMENT_PREFIX = 'element_'


def read_cutoffs(handle, num_stdev=NUM_STDEV):
	"""Return the per-position cutoffs mean + num_stdev * stdev.

	*handle* is an open Basal control file. Its first line holds the
	per-position means and its second line holds mean+stdev. Both lines are
	tab-delimited, and their first column (a label) is skipped. The stdev at
	each position is recovered as (mean+stdev) - mean.

	Raises ValueError if the two lines contain different numbers of positions.
	"""
	means = handle.readline().rstrip('\r\n').split('\t')[1:]
	means_plus = handle.readline().rstrip('\r\n').split('\t')[1:]
	if len(means) != len(means_plus):
		raise ValueError(
			'Basal file has %d mean values but %d mean+stdev values'
			% (len(means), len(means_plus)))
	return [float(mean) + num_stdev * (float(plus) - float(mean))
			for mean, plus in zip(means, means_plus)]


def read_sharpr_scores(handle):
	"""Return the SHARPR records in *handle* as a list of (name, scores) pairs.

	Each record is a header line starting with '>' followed by one
	tab-delimited line of scores. A leading 'element_' is removed from the
	name. 'NA' scores are kept as NaN rather than dropped, so each score stays
	at its original position. Records are returned in file order, and
	duplicate names are kept as separate records.
	"""
	records = []
	for line in handle:
		if not line.startswith('>'):
			continue
		name = line.rstrip('\r\n')[1:]
		# str.lstrip('>element_') would strip any of those *characters*
		# (e.g. '>element_note' -> 'ote'), so remove the prefix explicitly.
		if name.startswith(_ELEMENT_PREFIX):
			name = name[len(_ELEMENT_PREFIX):]
		score_line = next(handle, '').rstrip('\r\n')
		fields = score_line.split('\t') if score_line else []
		scores = [_NA if field == 'NA' else float(field) for field in fields]
		records.append((name, scores))
	return records


def find_regions(scores, cutoffs):
	"""Return the maximal runs where scores[i] > cutoffs[i].

	Each run is a (start, stop) pair of 0-based positions with *stop*
	exclusive. NaN scores never exceed a cutoff, so they end or skip runs.

	Raises ValueError if there are more scores than cutoffs.
	"""
	if len(scores) > len(cutoffs):
		raise ValueError(
			'sequence has %d positions but only %d cutoffs are available'
			% (len(scores), len(cutoffs)))
	regions = []
	start = None  # start of the currently open run, or None if no run is open
	for i, (score, cutoff) in enumerate(zip(scores, cutoffs)):
		if score > cutoff:
			if start is None:
				start = i
		elif start is not None:
			regions.append((start, i))
			start = None
	# A run still open at the end extends to the last position.
	if start is not None:
		regions.append((start, len(scores)))
	return regions


def main(argv=None):
	"""Run the script.

	Usage: <prog> <input SHARPR seq file> <Basal mean+stdev file> <output file>.
	Writes one 'name<TAB>start<TAB>stop' line per above-cutoff region.
	*argv* defaults to sys.argv. Exits with the module docstring when the
	argument count is wrong.
	"""
	argv = sys.argv if argv is None else argv
	if len(argv) != 4:
		sys.exit(__doc__)
	_, sharpr_path, basal_path, out_path = argv

	with open(basal_path, 'r') as f:
		cutoffs = read_cutoffs(f)
	with open(sharpr_path, 'r') as f:
		records = read_sharpr_scores(f)

	with open(out_path, 'w') as o:
		for name, scores in records:
			for start, stop in find_regions(scores, cutoffs):
				o.write('\t'.join((name, str(start), str(stop))) + '\n')


if __name__ == '__main__':
	main()

