'''
Calculates total motif conservation probability (mean) and total expected number of motif conservations.
Also outputs the observed number of motif conservations and observed proportion.
Separates based on genome if applicable (genome included in the last part of the fasta name with ";" separator).
Requires:
	-fastaLike file for element-reference alignment
	-association file for matching elements to reference
	-transition/transversion rates file
	-motif annotation file (FIMO output, SINGLE MOTIF ONLY)
	-kmers motif file (FIMO output of all kmers of same motif length scanned for motif presence, SAME MOTIF AS MOTIF ANNOTATION FILE)
Usage: python3 calc_neutral_expectation_motif_conservation_individual.py <fastaLike alignment file> <association file> <transition/transversion rate file> <motif annotation file> <kmers file> <output file>
'''

import sys
import statistics as stats

#Command-line arguments: the script name plus six file paths (see module docstring); otherwise print usage and exit
if len(sys.argv) != 7:
	sys.exit(__doc__)

align_file, assoc_file, rates_file, motif_file, kmers_file = sys.argv[1:6]
out_file = sys.argv[-1]

#Function that returns the reverse complement of a sequence
def reverse_complement(seq):
	'''
	Return the reverse complement of an uppercase DNA sequence.
	Supports A, C, G, T and the ambiguity base N (which complements to itself).
	Raises KeyError for any other character (e.g. lowercase bases or gaps), so callers uppercase first.
	'''
	comp = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N'}
	#join over a generator avoids the quadratic cost of repeated string concatenation
	return ''.join(comp[base] for base in reversed(seq))

#Get reference-element associations
#Each tab-separated line is: reference, then the elements associated with it ('.' = placeholder, ignored)
#assoc_di maps element name -> list of references it was listed under (more than one => ambiguous, skipped later)
assoc_di = {}
with open(assoc_file, 'r') as f:
	for line in f:
		ref, *elements = line.rstrip('\n').split('\t')
		for element in elements:
			if element != '.':
				assoc_di.setdefault(element, []).append(ref)

#Get transition/transversion rates
#Two tab-separated layouts are supported, distinguished by the first header column:
#  'Species' -> rows are: species, rate type, rate (separate rates per species)
#  'Type'    -> rows are: rate type, rate (one set of rates, stored under the pseudo-species 'overall')
#Only 'Transitions' and 'Transversions' rows are used; 'Conserved' (and any other type) is ignored because
#the conserved rate is derived later as 1 - transitions - transversions.
#Result: species_list (first-seen order) and rates_di[species] = [transition_rate, transversion_rate]
species_list = []
rates_di = {}
raw_rates = {} #species -> {'Transitions': [rates seen], 'Transversions': [rates seen]}, validated below
with open(rates_file, 'r') as f:
	header = f.readline().strip()
	header_fields = header.split('\t')
	if header_fields[0] == 'Species': #Species separated transition/transversion rates
		separate_species = True
	elif header_fields[0] == 'Type': #No separate species, just one set of overall transition/transversion rates
		separate_species = False
		species_list.append('overall')
		raw_rates['overall'] = {'Transitions': [], 'Transversions': []}
	else: #Without this, separate_species would be undefined and the script would crash much later
		print('Unrecognized header in ' + rates_file + ' (first column should be "Species" or "Type")')
		sys.exit(__doc__)
	for line in f:
		if line.strip() == '': #tolerate blank lines (e.g. trailing newline at end of file)
			continue
		fields = line.rstrip().split('\t')
		if separate_species:
			species = fields[0]
			rate_type = fields[1]
			rate = float(fields[2])
		else:
			species = 'overall'
			rate_type = fields[0]
			rate = float(fields[1])
		if species not in raw_rates:
			raw_rates[species] = {'Transitions': [], 'Transversions': []}
			species_list.append(species)
		if rate_type in raw_rates[species]: #'Conserved' and unknown types are skipped
			raw_rates[species][rate_type].append(rate)
#Check that there is exactly one transition rate and one transversion rate for each species
#(checking each type separately catches duplicated or missing rows that a simple count of 2 would miss)
for species in species_list:
	for rate_type in ('Transitions', 'Transversions'):
		num_rates = len(raw_rates[species][rate_type])
		if num_rates != 1:
			print(species + ' has ' + str(num_rates) + ' ' + rate_type + ' rates (should have exactly one) in ' + rates_file)
			sys.exit(__doc__)
	rates_di[species] = [raw_rates[species]['Transitions'][0], raw_rates[species]['Transversions'][0]]

#Get motif annotations (FIMO-style tab-separated output; exactly one motif allowed)
#motif_di: element name -> list of [0-based start, 0-based stop, strand, motif sequence],
#  where the sequence is the matched sequence, reverse-complemented for non-'+' hits
#motif_starts: every annotated start, used to get the location of the most likely orthologous motif site
#  if more than 1 of the motif in the same element
motif_di = {}
motif_starts = []
with open(motif_file, 'r') as f:
	header = f.readline().rstrip()
	prev_motif = ''
	for line in f:
		if line.startswith('#') or not line.strip(): #skip comment and blank lines
			continue
		fields = line.rstrip().split('\t')
		motif = (fields[1] + ' ' + fields[0]).strip()
		if prev_motif and prev_motif != motif: #a second, different motif name means the file is not single-motif
			print('More than 1 motif (' + prev_motif + ' and ' + motif + ') in ' + motif_file)
			sys.exit(__doc__)
		element = fields[2]
		start = int(fields[3]) - 1 #Change to 0-base
		stop = int(fields[4]) - 1 #Change to 0-base
		strand = fields[5]
		matched_seq = fields[9].upper()
		motif_seq = matched_seq if strand == '+' else reverse_complement(matched_seq)
		motif_di.setdefault(element, []).append([start, stop, strand, motif_seq])
		motif_starts.append(start)
		prev_motif = motif
#Get the most likely motif start based on abundance (the anchor used to pick one motif when a reference has several)
#Counts are computed explicitly instead of relying on statistics.mode, whose tie behavior changed in Python 3.8
#(it now silently returns the first mode instead of raising StatisticsError).
if not motif_starts: #no motif annotations at all, so there is nothing to anchor on (and no motif name)
	print('No motif annotations found in ' + motif_file)
	sys.exit(__doc__)
start_counts = {}
for position in motif_starts:
	start_counts[position] = start_counts.get(position, 0) + 1
max_count = max(start_counts.values())
#Start positions (not indices) that share the highest count, sorted for deterministic messages
most_common_motif_starts = sorted(position for position, count in start_counts.items() if count == max_count)
if len(most_common_motif_starts) == 1:
	most_common_motif_start = most_common_motif_starts[0]
elif len(most_common_motif_starts) == 2: #Two motif starts that are equally as likely, take the mean if they are within 10bp of each other
	start1, start2 = most_common_motif_starts
	if start2 - start1 <= 10: #sorted and distinct, so start2 > start1
		most_common_motif_start = stats.mean(most_common_motif_starts)
	else:
		print('Two most common motif starts (' + str(start1) + ' and ' + str(start2) + ') are more than 10bp apart')
		sys.exit()
else: #More than two motif starts that are equally as likely
	print('More than two most common motif starts')
	sys.exit()

#Get all kmers that are called motifs (FIMO-style output of all kmers scanned with the same motif)
#matching_kmers keeps first-seen order (the probability sum later iterates over it);
#seen_kmers gives O(1) duplicate checks instead of scanning the list for every line
matching_kmers = []
seen_kmers = set()
with open(kmers_file, 'r') as f:
	header = f.readline()
	for line in f:
		if line.startswith('#') or line.strip() == '': #skip comment and blank lines
			continue
		fields = line.rstrip().split('\t')
		kmer_motif = (fields[1] + ' ' + fields[0]).strip()
		if kmer_motif != motif:
			print('kmer motif (' + kmer_motif + ') does not match the motif annotation file motif (' + motif + ')')
			sys.exit(__doc__)
		strand = fields[5]
		if strand == '+':
			kmer_seq = fields[9].upper()
		elif strand == '-':
			kmer_seq = reverse_complement(fields[9].upper())
		else:
			print(fields[2] + ' in ' + kmers_file + ' does not have strand info')
			sys.exit(__doc__)
		if kmer_seq not in seen_kmers: #Only evaluate a specific kmer sequence once
			seen_kmers.add(kmer_seq)
			matching_kmers.append(kmer_seq)

#Calculate the probability of motif conservation for each reference-element alignment in the fastaLike input file
#Adds up the probabilities for each species (or overall if no separate species transition/transversion rates) to get the species motif conservation expectation
#Each record in the fastaLike file is a '>name' line, then the element sequence line, then the aligned reference line

def _motif_conservation_probability(ref_motif_seq, kmers, conserved_rate, transition_rate, transversion_rate):
	'''
	Return the probability that ref_motif_seq mutates into any of the given motif kmers,
	treating each position as mutating independently.
	Per position: conserved_rate if the base is unchanged, transition_rate for its single
	transition partner (A<->G, C<->T), and transversion_rate/2 for each of its two possible
	transversion partners. transversion_rate is the TOTAL transversion rate (the conserved rate is
	1 - transitions - transversions), so it is split between the two transversion targets; this
	makes the probabilities of all possible ACGT kmers of the same length sum to 1.
	Exits if a kmer length differs from the reference motif length.
	'''
	transition_pairs = {('A', 'G'), ('G', 'A'), ('C', 'T'), ('T', 'C')}
	specific_transversion_rate = transversion_rate / 2
	total_probability = 0
	for kmer_seq in kmers:
		if len(kmer_seq) != len(ref_motif_seq): #would otherwise raise IndexError or silently ignore positions
			print('kmer ' + kmer_seq + ' is not the same length as the reference motif ' + ref_motif_seq)
			sys.exit(__doc__)
		num_conserved = 0
		num_transitions = 0
		num_transversions = 0
		for ref_base, kmer_base in zip(ref_motif_seq, kmer_seq):
			if ref_base == kmer_base:
				num_conserved += 1
			elif (ref_base, kmer_base) in transition_pairs:
				num_transitions += 1
			else:
				num_transversions += 1
		total_probability += conserved_rate**num_conserved * transition_rate**num_transitions * specific_transversion_rate**num_transversions
	return total_probability

probability_motif_conservation = [0] * len(species_list) #per species: expected number of conserved motifs (sum of probabilities)
total_possible_motif_conservation = [0] * len(species_list) #per species: elements that passed every check
observed_motif_conservation = [0] * len(species_list) #per species: elements with a motif annotated at the expected location
ref_motif_probabilities = [{} for _ in species_list] #per species: reference -> conservation probability (cache)
list_unanalyzed_elements = [] #[element name, reason the element was skipped]
with open(align_file, 'r') as f:
	for line in f:
		if not line.startswith('>'):
			continue #sequence lines of records skipped before they were read are passed over here
		name = line.lstrip('>').rstrip()
		if name not in assoc_di: #element is not associated with any reference in the association file
			list_unanalyzed_elements.append([name, 'no reference'])
			continue #Skip because there is no known reference point for the alignment
		elif len(assoc_di[name]) > 1: #more than one reference for the same element
			list_unanalyzed_elements.append([name, 'multiple references'])
			continue #Skip because it's unknown which is the reference point to use, don't want to double count the same element for motif conservation
		reference = assoc_di[name][0]
		if separate_species:
			if ';' not in name: #No ";" to separate name from species => no species info
				list_unanalyzed_elements.append([name, 'no species'])
				continue #Skip if no species info because can't pick a transition/transversion rate to use
			species = name.split(';')[-1]
			if species not in species_list: #If no transition/transversion rate available for species, skip
				list_unanalyzed_elements.append([name, 'no species transition/transversion rate'])
				continue
		else:
			species = 'overall'
		transition_rate, transversion_rate = rates_di[species]
		conserved_rate = 1 - transition_rate - transversion_rate
		species_index = species_list.index(species)
		seq = f.readline().rstrip().upper() #element sequence (gapped)
		align = f.readline().rstrip().upper() #reference sequence aligned to the element (gapped)
		if reference not in motif_di: #no motif in the reference, skip because there is no motif to conserve
			list_unanalyzed_elements.append([name, 'no motif in reference'])
			continue
		reference_motif_info = motif_di[reference]
		if len(reference_motif_info) > 1: #more than 1 motif found in the reference
			#Pick the motif closest to the most common motif start. Motif info is by default ordered by p-value,
			#so a strict '<' keeps the earliest instance when several motifs are equally close.
			best_index = 0
			best_dist = len(align.replace('-', ''))
			for i, motif_info in enumerate(reference_motif_info):
				motif_dist = abs(motif_info[0] - most_common_motif_start)
				if motif_dist < best_dist:
					best_dist = motif_dist
					best_index = i
			motif_index = best_index
		else: #only 1 motif found in the reference
			motif_index = 0
		reference_motif_start, reference_motif_stop, reference_motif_strand, reference_motif_seq = reference_motif_info[motif_index]
		#Check that reference has the expected motif sequence in the alignment and
		#that the element-reference alignment doesn't have gaps in the motif region (would likely affect motif call if there is insertion/deletion).
		#Also map the reference motif start/stop onto element (ungapped) coordinates.
		reference_position = 0 #number of reference bases consumed so far
		element_position = 0 #number of element bases consumed so far
		align_gap = False
		gap_reason = ''
		align_ref_motif_seq = ''
		element_motif_start = 0
		element_motif_stop = 0
		for i, align_base in enumerate(align):
			if align_base == '-': #gap in the reference => insertion in the element
				if seq[i] != '-':
					element_position += 1
				#NOTE(review): an insertion just before the first motif base (reference_position == reference_motif_start)
				#is treated as inside the motif, while one just after the last base is not; confirm this asymmetry is intended.
				if reference_motif_start <= reference_position <= reference_motif_stop: #element insertion in the motif region
					align_gap = True
					gap_reason = 'insert in element at reference position ' + str(reference_position)
					break
			else:
				if reference_position == reference_motif_start:
					element_motif_start = element_position
				if reference_position == reference_motif_stop: #separate 'if' so a 1bp motif also gets its stop set
					element_motif_stop = element_position
				if seq[i] != '-':
					element_position += 1
				if reference_motif_start <= reference_position <= reference_motif_stop: #in motif region
					if seq[i] == '-': #element deletion in the motif region
						align_gap = True
						gap_reason = 'deletion in element at reference position ' + str(reference_position)
						break
					align_ref_motif_seq += align_base
				reference_position += 1
		if align_gap: #if there was a gap in the motif region, skip the element
			list_unanalyzed_elements.append([name, gap_reason])
			continue
		if align_ref_motif_seq != reference_motif_seq: #motif sequence from alignment and motif file don't agree
			list_unanalyzed_elements.append([name, 'motif sequence from alignment and motif file don\'t agree'])
			continue
		#At this point, no gaps in motif region so the element becomes a possible case for motif conservation
		total_possible_motif_conservation[species_index] += 1
		#Observed conservation: any element motif annotation overlapping the expected (mapped) location counts once
		for element_motif in motif_di.get(name, []):
			annotated_start = element_motif[0]
			annotated_stop = element_motif[1]
			if element_motif_start <= annotated_start <= element_motif_stop or annotated_start <= element_motif_start <= annotated_stop:
				observed_motif_conservation[species_index] += 1
				break
		#Calculate motif conservation probability for the reference motif once per species, then reuse it
		if reference not in ref_motif_probabilities[species_index]:
			ref_motif_probabilities[species_index][reference] = _motif_conservation_probability(reference_motif_seq, matching_kmers, conserved_rate, transition_rate, transversion_rate)
		probability_motif_conservation[species_index] += ref_motif_probabilities[species_index][reference]

#Write per-species rates and conservation summaries; only species with at least one analyzable element are reported.
#Probability_conserved is the unweighted mean over distinct references (not Expected_conserved / Total_possible_elements).
with open(out_file, 'w') as o:
	o.write('Motif: ' + motif + '\n')
	o.write('Species\tTransition_rate\tTransversion_rate\n')
	reported_species = [(idx, sp) for idx, sp in enumerate(species_list) if total_possible_motif_conservation[idx] != 0]
	for idx, sp in reported_species:
		o.write(sp + '\t' + '\t'.join(map(str, rates_di[sp])) + '\n')
	o.write('\n')
	o.write('Species\tTotal_possible_elements\tProbability_conserved\tExpected_conserved\tObserved_conserved\tObserved_proportion\n')
	for idx, sp in reported_species:
		num_possible = total_possible_motif_conservation[idx]
		num_observed = observed_motif_conservation[idx]
		mean_probability = stats.mean(list(ref_motif_probabilities[idx].values()))
		row = [
			sp,
			str(num_possible),
			str(mean_probability),
			str(probability_motif_conservation[idx]),
			str(num_observed),
			str(num_observed / num_possible),
		]
		o.write('\t'.join(row) + '\n')

print('Unanalyzed elements and reason')
for element_name, reason in list_unanalyzed_elements:
	print(element_name + '\t' + reason)

