'''
Pairs each tile's log2 mean enrichment score with the sum of its SHARPR per-base scores
Usage: python3 correlate_sharprSum_tiles.py <input tile basalNorm file> <input SHARPR seq file> <output file>
'''

import sys
import math
import statistics as stats

# Exactly three positional arguments are required (see the usage text in the
# module docstring); otherwise print that usage text and exit.
if len(sys.argv[1:]) != 3:
	sys.exit(__doc__)

# Parse the tile file (argv[1]). Each tab-separated line holds a tile ID in
# column 0 and one or more numeric values in the remaining columns.
#   di_tile:  element -> list indexed by tile number; each entry is
#             log2(mean of that line's values), or 'NA' for tile numbers
#             that do not appear in the file.
#   elements: element names in first-seen order (drives output order).
di_tile = {}
elements = []
with open(sys.argv[1], 'r') as f:
	for line in f:
		if not line.strip():
			# Skip blank lines (e.g. a trailing newline); they would otherwise
			# crash stats.mean() on an empty list.
			continue
		fields = line.rstrip('\n').split('\t')
		tile_id = fields[0]
		element = tile_id.split('|')[0]
		if 'reverse' in tile_id:
			# Keep reverse-orientation tiles under a separate element key.
			element = 'reverse' + element
		logFC = math.log2(stats.mean([float(x) for x in fields[1:]]))
		if element not in di_tile:
			di_tile[element] = []
			elements.append(element)
		# Tile number: last '_'-separated token of the last '|' field,
		# truncated at the first ';'.
		tile = int(tile_id.split('|')[-1].split('_')[-1].split(';')[0])
		tiles = di_tile[element]
		# Pad absent tile numbers with 'NA', then store by index so that a
		# tile always sits at position == tile number, even when tiles are
		# listed out of order (appending would misplace them).
		if len(tiles) <= tile:
			tiles.extend(['NA'] * (tile + 1 - len(tiles)))
		tiles[tile] = logFC

def _sharpr_tile_sums(scores, tile_len=160, offset=10):
	"""Return the summed SHARPR score of each tile of one element.

	scores: per-base SHARPR values as strings; 'NA' entries are skipped, so
	    base positions are counted over the non-NA values only.
	tile_len: tile length in bp.
	offset: step in bp between consecutive tile starts.

	Returns a list of float sums, one per tile, in tile order.
	"""
	values = [float(s) for s in scores if s != 'NA']
	n_bp = len(values)
	# Sum scores in consecutive offset-bp bins; the last bin may be shorter.
	bins = [sum(values[i:i + offset]) for i in range(0, n_bp, offset)]
	per_tile = tile_len // offset
	n_tiles = math.ceil((n_bp - tile_len) / offset) + 1
	# Each tile covers per_tile consecutive bins (a non-positive n_tiles
	# yields no tiles).
	sums = [sum(bins[i:i + per_tile]) for i in range(n_tiles)]
	if n_bp % offset == 0:
		# There was a mistake in tile generation: when the element length is a
		# multiple of offset, an extra tile, offset bp shorter, was created
		# over the last (tile_len - offset) bp. Emit a sum for it so the
		# indices keep lining up with the tile file.
		sums.append(sum(bins[len(bins) - (per_tile - 1):]))
	return sums


# Parse the SHARPR file (argv[2]): a '>' header line followed by one line of
# tab-separated per-base scores. Only elements present in the tile file are kept.
#   di_sharpr: element -> list of per-tile SHARPR sums.
di_sharpr = {}
with open(sys.argv[2], 'r') as f:
	for line in f:
		if not line.startswith('>'):
			continue
		# Drop the first '_'-separated token (including '>') and normalize
		# '-full' to '_full' -- presumably to match tile-file element names.
		element = '_'.join(line.rstrip('\n').split('_')[1:]).replace('-full', '_full')
		# The score line directly follows its header; consume it even when the
		# element is skipped so it is not mistaken for another record.
		seq = f.readline().rstrip('\n').split('\t')
		if element in di_tile:  # same key set as `elements`, O(1) lookup
			di_sharpr[element] = _sharpr_tile_sums(seq)

# Write one row per tile that has an enrichment score (tiles stored as 'NA'
# are skipped). The SHARPR column is 'NA' when the element is missing from
# the SHARPR file or has fewer SHARPR tile sums than tiles, instead of
# aborting the whole run with KeyError/IndexError.
with open(sys.argv[-1], 'w') as o:
	o.write('TileName\tEnrichScore\tSHARPR\n')
	for element in elements:
		if element not in di_sharpr:
			print('warning: no SHARPR scores for ' + element, file=sys.stderr)
		sharpr_sums = di_sharpr.get(element, [])
		for i, score in enumerate(di_tile[element]):
			if score == 'NA':
				continue
			sharpr = sharpr_sums[i] if i < len(sharpr_sums) else 'NA'
			tilename = element + '|tile_' + str(i)
			o.write(tilename + '\t' + str(score) + '\t' + str(sharpr) + '\n')

