'''
Simulates motif sequences for a subfamily based on a single consensus or a group of ortholog references (starting points for simulated mutations)
Requires:
	-fastaLike file for element-reference alignment (elements to be mutated)
	-association file for matching elements to reference
	-transition/transversion rates file
	-motif annotation file (FIMO output, SINGLE MOTIF ONLY, used to find motif locations in consensus/references)
Usage: python3 create_simulated_subfamily_motifs.py <fastaLike alignment file> <association file> <transition/transversion rate file> <motif annotation file> <output file basename>
'''

import sys
import random
import statistics as stats

#Exactly five positional arguments are required; otherwise exit and show the module docstring as the usage text
if len(sys.argv) != 6:
	sys.exit(__doc__)

#Positional arguments, in the order given on the usage line
align_file, assoc_file, rates_file, motif_file, output_basename = sys.argv[1:]

def reverse_complement(seq):
	'''
	Return the reverse complement of a nucleotide sequence.

	A<->T and C<->G are swapped (upper- and lowercase, case preserved) and the
	order of the characters is reversed. Any other character (for example N or
	a gap) is copied through unchanged rather than raising KeyError, which matches
	how the rest of this script leaves non-standard nucleotides untouched.
	'''
	#One translation table and a single pass instead of per-base string concatenation
	comp = str.maketrans('ACGTacgt', 'TGCAtgca')
	return seq[::-1].translate(comp)

#Base produced by a transition: purines swap with each other (A<->G), pyrimidines swap with each other (T<->C)
transition_di = dict(zip('AGTC', 'GACT'))

#Candidate bases for a transversion: a purine (A/G) becomes a pyrimidine (T/C) and vice versa
#(use random.choice on the value to randomly pick one)
transversion_di = {base: ('TC' if base in 'AG' else 'AG') for base in 'AGTC'}

#Build the element -> [references] lookup from the association file
#Each line is tab-separated: the first field is the reference, every following field is an element ('.' entries are ignored)
assoc_di = {}
with open(assoc_file, 'r') as f:
	for line in f:
		ref, *elements = line.rstrip('\n').split('\t')
		for element in elements:
			if element != '.':
				#An element listed under several references accumulates all of them
				assoc_di.setdefault(element, []).append(ref)

#Get transition/transversion rates
#Two layouts are accepted, selected by the first header column:
#	'Species': rows are species, rate type, rate (one set of rates per species)
#	'Type':    rows are rate type, rate (a single set of rates stored under 'overall')
#rates_di maps each species (or 'overall') to [transition rate, transversion rate]; 'Conserved' rows are ignored
species_list = []
rates_di = {}
rate_counts_di = {} #species -> [number of Transitions rows, number of Transversions rows], used for validation below
rate_slots = {'Transitions': 0, 'Transversions': 1} #Position of each rate type inside the rates_di lists
with open(rates_file, 'r') as f:
	header = f.readline().strip()
	header_fields = header.split('\t')
	if header_fields[0] == 'Species': #Species separated transition/transversion rates
		separate_species = True
	elif header_fields[0] == 'Type': #No separate species, just one set of overall transition/transversion rates
		separate_species = False
		#Register 'overall' up front so a file without any rate rows is still caught by the check below
		species_list.append('overall')
		rates_di['overall'] = [None, None]
		rate_counts_di['overall'] = [0, 0]
	else: #Unknown layout: stop here instead of failing later on an undefined separate_species
		print('Unrecognized header in ' + rates_file + ' (first column should be "Species" or "Type")')
		sys.exit(__doc__)
	for line in f:
		if line.strip() == '': #Tolerate blank lines
			continue
		fields = line.rstrip().split('\t')
		if separate_species:
			species = fields[0]
			rate_type = fields[1]
			rate = float(fields[2])
		else:
			species = 'overall'
			rate_type = fields[0]
			rate = float(fields[1])
		if species not in rates_di:
			species_list.append(species)
			rates_di[species] = [None, None]
			rate_counts_di[species] = [0, 0]
		if rate_type in rate_slots: #'Conserved' (and any other type) is not needed for the simulation
			slot = rate_slots[rate_type]
			rates_di[species][slot] = rate
			rate_counts_di[species][slot] += 1
#Check that there is exactly one transition rate and exactly one transversion rate for each species
#(counting each type separately so that e.g. two transition rows and no transversion row is rejected)
for species in species_list:
	transition_count, transversion_count = rate_counts_di[species]
	if transition_count != 1 or transversion_count != 1:
		print(species + ' should have exactly 1 transition and 1 transversion rate in ' + rates_file + ' (found ' + str(transition_count) + ' transition and ' + str(transversion_count) + ' transversion)')
		sys.exit(__doc__)

#Get motif annotations
#The first line is treated as a header and skipped; lines starting with '#' and blank lines are ignored
#Tab-separated columns used (0-based): 0 and 1 = motif identifiers, 2 = element, 3 = start, 4 = stop, 5 = strand, 9 = matched sequence
#motif_di maps each element to a list of [start, stop, strand, motif_seq] with 0-based start/stop
motif_di = {}
motif_starts = [] #Used to get the location of the most likely orthologous motif site if more than 1 of the motif in the same element
with open(motif_file, 'r') as f:
	header = f.readline().rstrip()
	prev_motif = ''
	for line in f:
		if line.startswith('#') or line.strip() == '':
			continue
		fields = line.rstrip().split('\t')
		motif = (fields[1] + ' ' + fields[0]).strip()
		#Only a single motif is supported per annotation file
		if prev_motif != '' and prev_motif != motif:
			print('More than 1 motif (' + prev_motif + ' and ' + motif + ') in ' + motif_file)
			sys.exit(__doc__)
		element = fields[2]
		start = int(fields[3]) - 1 #Change to 0-base
		stop = int(fields[4]) - 1 #Change to 0-base
		strand = fields[5]
		#Store the motif as it reads on the forward strand so it can be compared directly with the alignment
		if strand == '+':
			motif_seq = fields[9].upper()
		else:
			motif_seq = reverse_complement(fields[9].upper())
		motif_di.setdefault(element, []).append([start, stop, strand, motif_seq])
		motif_starts.append(start)
		prev_motif = motif
#Without any annotation there is no motif to simulate (and no mode / motif length to compute)
if not motif_starts:
	print('No motif annotations found in ' + motif_file)
	sys.exit(__doc__)
#Get the most likely motif start based on abundance
most_common_motif_start = stats.mode(motif_starts)
#Get motif length (taken from the last annotation read)
motif_len = len(motif_seq)

#Randomly mutate motif positions and surrounding bases using the given transition/transversion rates
#Output mutated sequences to fastaLike (with unchanged references, for motif conservation) and fasta (for motif scanning) files
#The alignment file is read as 3-line records: '>' + element name, aligned element sequence, aligned reference sequence
#Elements that cannot be simulated are collected in list_unanalyzed_elements as [name, reason]
list_unanalyzed_elements = []
outputted_references = set() #References already written to the fasta output (each one is written only once)
with open(align_file, 'r') as f, open(output_basename + '.fastaLike', 'w') as o1, open(output_basename + '.fa', 'w') as o2:
	for line in f:
		if not line.startswith('>'):
			continue #Sequence lines belonging to a skipped record end up here and are ignored
		name = line.lstrip('>').rstrip()
		if name not in assoc_di: #element is not associated with any reference in the association file
			list_unanalyzed_elements.append([name, 'no reference'])
			continue #Skip because there is no known reference point for the alignment
		elif len(assoc_di[name]) > 1: #more than one reference for the same element
			list_unanalyzed_elements.append([name, 'multiple references'])
			continue #Skip because it's unknown which is the reference point to use, don't want to double count the same element for motif conservation
		reference = assoc_di[name][0]
		if separate_species:
			if ';' not in name: #No ";" to separate name from species => no species info
				list_unanalyzed_elements.append([name, 'no species'])
				continue #Skip if no species info because can't pick a transition/transversion rate to use
			species = name.split(';')[-1]
			if species not in species_list: #If no transition/transversion rate available for species, skip
				list_unanalyzed_elements.append([name, 'no species transition/transversion rate'])
				continue
		else:
			species = 'overall'
		transition_rate = rates_di[species][0]
		transversion_rate = rates_di[species][1]
		seq = f.readline().rstrip().upper() #Aligned element sequence
		align = f.readline().rstrip().upper() #Aligned reference sequence
		if len(seq) != len(align): #Both rows are indexed by the same alignment column below, so they must have equal length
			list_unanalyzed_elements.append([name, 'element and reference alignment lengths differ'])
			continue
		if reference not in motif_di: #no motif in the reference, skip because there is no motif to conserve
			list_unanalyzed_elements.append([name, 'no motif in reference'])
			continue
		reference_motif_info = motif_di[reference]
		motif_index = 0
		if len(reference_motif_info) > 1: #more than 1 motif found in the reference
			#Pick the motif whose start is closest to the most common motif start position
			#Motif info is by default ordered by p-value, so if there are multiple motifs with the same distance, the earliest instance is kept
			best_dist = len(align.replace('-', ''))
			for i, motif_info in enumerate(reference_motif_info):
				motif_dist = abs(motif_info[0] - most_common_motif_start)
				if motif_dist < best_dist:
					best_dist = motif_dist
					motif_index = i
		reference_motif_start, reference_motif_stop, reference_motif_strand, reference_motif_seq = reference_motif_info[motif_index]
		#Check that reference has the expected motif sequence in the alignment and
		#that the element-reference alignment doesn't have gaps in the motif region (would likely affect motif call if there is insertion/deletion)
		#While walking the alignment, remember the ALIGNMENT COLUMNS of the first and last motif base:
		#seq and align are both indexed by column when the mutations are simulated below
		reference_position = 0 #0-based position in the ungapped reference
		gap_reason = ''
		align_ref_motif_seq = ''
		motif_start_col = None
		motif_stop_col = None
		for i, reference_base in enumerate(align):
			if reference_position > reference_motif_stop:
				break #Past the motif, nothing left to check
			in_motif = reference_position >= reference_motif_start
			if reference_base == '-': #Gap in reference => insertion in element (does not advance the reference position)
				#NOTE(review): an insertion directly before the first motif base (reference_position == start) is also
				#treated as being inside the motif region, as in the original logic - confirm this is intended
				if in_motif:
					gap_reason = 'insert in element at reference position ' + str(reference_position)
					break
				continue
			if in_motif:
				if seq[i] == '-': #element deletion in the motif region
					gap_reason = 'deletion in element at reference position ' + str(reference_position)
					break
				if reference_position == reference_motif_start:
					motif_start_col = i
				if reference_position == reference_motif_stop: #Separate check so a single-base motif sets both columns
					motif_stop_col = i
				align_ref_motif_seq += reference_base
			reference_position += 1
		if gap_reason: #if there was a gap in the motif region, skip the element
			list_unanalyzed_elements.append([name, gap_reason])
			continue
		#motif sequence from alignment and motif file don't agree (also covers a motif that is not fully contained in the alignment)
		if align_ref_motif_seq != reference_motif_seq or motif_start_col is None or motif_stop_col is None:
			list_unanalyzed_elements.append([name, 'motif sequence from alignment and motif file don\'t agree'])
			continue
		#At this point, no gaps in motif region so the element becomes a possible case for motif conservation
		#Simulate mutations in motif region (here defined as motif position +/- (motif length - 1)), measured in alignment columns
		randomization_start = motif_start_col - (motif_len - 1)
		randomization_stop = motif_stop_col + (motif_len - 1)
		transition_thresh = transition_rate
		transversion_thresh = transition_rate + transversion_rate
		randomized_bases = []
		for i, element_base in enumerate(seq):
			if i < randomization_start or i > randomization_stop: #Outside the simulated window: keep the element as is
				randomized_bases.append(element_base)
				continue
			reference_base = align[i]
			if reference_base not in transition_di:
				#Gap in reference (possible just outside the motif) or non-standard nucleotide: use the reference character as is
				randomized_bases.append(reference_base)
				continue
			#Reference has a standard base: randomly decide if transition, transversion, or conservation
			random_roll = random.random()
			if random_roll < transition_thresh:
				randomized_bases.append(transition_di[reference_base])
			elif random_roll < transversion_thresh: #Randomly decide which of the two transversions to make
				randomized_bases.append(random.choice(transversion_di[reference_base]))
			else:
				randomized_bases.append(reference_base)
		randomized_seq = ''.join(randomized_bases)
		#Write simulated sequences to output
		o1.write('>' + name + '\n')
		o1.write(randomized_seq + '\n')
		o1.write(align + '\n')
		o2.write('>' + name + '\n')
		o2.write(randomized_seq.replace('-', '') + '\n')
		if reference not in outputted_references:
			o2.write('>' + reference + '\n')
			o2.write(align.replace('-', '') + '\n')
			outputted_references.add(reference)

#Report every skipped element together with the reason it was skipped (tab-separated, one per line)
print('Unanalyzed elements and reason')
for element_name, skip_reason in list_unanalyzed_elements:
	print('\t'.join((element_name, skip_reason)))

