'''
Simulates the full sequences for a subfamily based on a single consensus or a group of ortholog references (starting points for simulated mutations)
Keeps gaps for the purpose of testing the same number of elements per motif (elements with gaps in motifs are not considered later)
Requires:
	-fastaLike file for element-reference alignment (elements to be mutated)
	-association file for matching elements to reference
	-transition/transversion rates file
Usage: python3 create_simulated_subfamily_motifs.py <fastaLike alignment file> <association file> <transition/transversion rate file> <output file basename>
'''

import sys
import random
import statistics as stats

#Exactly four positional arguments are required; otherwise show the usage text and stop
if len(sys.argv) != 5:
	sys.exit(__doc__)

#argv[1:] holds exactly four entries here: alignment, association, rates, output basename
align_file, assoc_file, rates_file, output_basename = sys.argv[1:]

#Function that returns the reverse complement of a sequence
def reverse_complement(seq):
	'''
	Return the reverse complement of an upper-case DNA sequence.
	Handles A/C/G/T, IUPAC ambiguity codes (N, R, Y, K, M, S, W, B, V, D, H) and alignment gaps ('-', kept as gaps).
	Raises KeyError on any other character.
	'''
	comp = {'A':'T', 'T':'A', 'C':'G', 'G':'C',
		'N':'N', 'R':'Y', 'Y':'R', 'K':'M', 'M':'K', 'S':'S', 'W':'W',
		'B':'V', 'V':'B', 'D':'H', 'H':'D', '-':'-'}
	return ''.join(comp[base] for base in reversed(seq))

#Transition partner of each base (purine <-> purine, pyrimidine <-> pyrimidine)
transition_di = dict(zip('AGTC', 'GACT'))

#The two bases reachable from each base by a transversion (pick one with random.choice)
transversion_di = {base: ('TC' if base in 'AG' else 'AG') for base in 'AGTC'}

#Get reference-element associations
#Each line is: reference<TAB>element<TAB>element...; a '.' field is a placeholder and is ignored
#assoc_di maps each element name to the list of references it was listed under, in file order
assoc_di = {}
with open(assoc_file, 'r') as assoc_handle:
	for raw_line in assoc_handle:
		ref, *elements = raw_line.rstrip('\n').split('\t')
		for element in (e for e in elements if e != '.'):
			assoc_di.setdefault(element, []).append(ref)

#Get transition/transversion rates
#Header 'Species': data rows are species<TAB>rate type<TAB>rate (species-specific rates)
#Header 'Type': data rows are rate type<TAB>rate (one overall set of rates, stored under 'overall')
#Rate types other than 'Transitions'/'Transversions' (e.g. 'Conserved') are ignored
#rates_di maps each species to [transition rate, transversion rate]
species_list = []
rates_di = {}
collected_rates = {} #species -> {'Transitions': [rates...], 'Transversions': [rates...]}
with open(rates_file, 'r') as f:
	header_fields = f.readline().strip().split('\t')
	if header_fields[0] == 'Species': #Species separated transition/transversion rates
		separate_species = True
	elif header_fields[0] == 'Type': #No separate species, just one set of overall transition/transversion rates
		separate_species = False
		species_list.append('overall')
		collected_rates['overall'] = {'Transitions': [], 'Transversions': []}
	else: #Without a recognized header separate_species would be undefined and the script would crash later
		print('Unrecognized header in ' + rates_file + ' (first column should be "Species" or "Type")')
		sys.exit(__doc__)
	for line in f:
		fields = line.rstrip().split('\t')
		if fields == ['']: #Skip blank lines
			continue
		if separate_species:
			species, rate_type, rate = fields[0], fields[1], float(fields[2])
		else:
			species, rate_type, rate = 'overall', fields[0], float(fields[1])
		if species not in collected_rates:
			collected_rates[species] = {'Transitions': [], 'Transversions': []}
			species_list.append(species)
		if rate_type in collected_rates[species]: #Only transition/transversion rates are kept
			collected_rates[species][rate_type].append(rate)
#Check that there is exactly one transition and one transversion rate for each species
#Each type is counted separately, so e.g. two transversion rates and no transition rate is rejected
#instead of silently using a transversion rate as the transition rate
for species in species_list:
	transition_rates = collected_rates[species]['Transitions']
	transversion_rates = collected_rates[species]['Transversions']
	if len(transition_rates) != 1 or len(transversion_rates) != 1:
		print(species + ' has ' + str(len(transition_rates)) + ' transition rate(s) and ' + str(len(transversion_rates)) + ' transversion rate(s) in ' + rates_file + ' (should have exactly one of each)')
		sys.exit(__doc__)
	rates_di[species] = [transition_rates[0], transversion_rates[0]]

#Function that simulates mutations starting from an aligned reference sequence
#seq is the element's aligned sequence and align is the aligned reference sequence (same length)
#Returns the simulated aligned sequence:
#	positions that are gaps in either the reference or the element stay gaps,
#	non-ACGT reference bases are copied unchanged,
#	every other reference base undergoes a transition with probability transition_rate,
#	a transversion (one of the two possible bases, chosen uniformly) with probability transversion_rate,
#	and is otherwise conserved
def simulate_mutations(seq, align, transition_rate, transversion_rate):
	transversion_thresh = transition_rate + transversion_rate
	randomized_bases = []
	for seq_base, reference_base in zip(seq, align):
		if reference_base == '-' or seq_base == '-': #Keep gaps (in reference or element) as gaps
			randomized_bases.append('-')
		elif reference_base not in 'ACGT': #Non-standard nucleotide in reference: use it as is
			randomized_bases.append(reference_base)
		else: #Randomly decide if transition, transversion, or conservation
			random_roll = random.random()
			if random_roll < transition_rate:
				randomized_bases.append(transition_di[reference_base])
			elif random_roll < transversion_thresh:
				randomized_bases.append(random.choice(transversion_di[reference_base]))
			else:
				randomized_bases.append(reference_base)
	return ''.join(randomized_bases)

#Randomly mutate motif positions and surrounding bases using the given transition/transversion rates
#Output mutated sequences to fastaLike (with unchanged references, for motif conservation) and fasta (for motif scanning) files
#Each element in the fastaLike alignment file spans 3 lines: '>name', the element's aligned sequence, the aligned reference sequence
list_unanalyzed_elements = [] #[element name, reason] for every skipped element
outputted_references = set() #References already written to the fasta output
with open(align_file, 'r') as f, open(output_basename + '.fastaLike', 'w') as o1, open(output_basename + '.fa', 'w') as o2:
	for line in f:
		if not line.startswith('>'):
			continue #Sequence lines of skipped elements are ignored here
		name = line.lstrip('>').rstrip()
		if name not in assoc_di: #element is not associated with any reference in the association file
			list_unanalyzed_elements.append([name, 'no reference'])
			continue #Skip because there is no known reference point for the alignment
		if len(assoc_di[name]) > 1: #more than one reference for the same element
			list_unanalyzed_elements.append([name, 'multiple references'])
			continue #Skip because it's unknown which is the reference point to use, don't want to double count the same element for motif conservation
		reference = assoc_di[name][0]
		if separate_species:
			if ';' not in name: #No ";" to separate name from species => no species info
				list_unanalyzed_elements.append([name, 'no species'])
				continue #Skip if no species info because can't pick a transition/transversion rate to use
			species = name.split(';')[-1]
			if species not in rates_di: #If no transition/transversion rate available for species, skip
				list_unanalyzed_elements.append([name, 'no species transition/transversion rate'])
				continue
		else:
			species = 'overall'
		transition_rate, transversion_rate = rates_di[species]
		seq = f.readline().rstrip().upper()
		align = f.readline().rstrip().upper()
		if len(seq) != len(align): #Malformed alignment: positions can't be paired up
			list_unanalyzed_elements.append([name, 'sequence/reference length mismatch'])
			continue
		#Simulate mutations in the reference sequence
		randomized_seq = simulate_mutations(seq, align, transition_rate, transversion_rate)
		#Write simulated sequences to output
		o1.write('>' + name + '\n' + randomized_seq + '\n' + align + '\n')
		o2.write('>' + name + '\n' + randomized_seq.replace('-', '') + '\n')
		if reference not in outputted_references: #Write each reference's ungapped sequence only once
			o2.write('>' + reference + '\n' + align.replace('-', '') + '\n')
			outputted_references.add(reference)

#print('Unanalyzed elements and reason')
#for instance in list_unanalyzed_elements:
#	print(instance[0] + '\t' + instance[1])

