'''
Finds the motifs that are enriched in +annotation elements relative to -annotation elements.
Uses Fisher's exact test to determine enrichment. Uses the Benjamini-Hochberg procedure to control the false discovery rate across all annotation/motif tests combined.
Contingency table:		+annotation	-annotation
	elements_wMotif		    ##		    ##
	elements_without	    ##		    ##
Usage: python3 find_motifs_assoc_wAnnotation_fishers.py <input motif file> <input annotation file> <output file>
'''

import sys
import numpy as np
import scipy.stats

# Exactly three command-line arguments are required (motif file, annotation file,
# output file); otherwise print the usage text from the module docstring and exit.
if len(sys.argv[1:]) != 3:
	sys.exit(__doc__)

# Adjusted p-values must be strictly below this cutoff to be written out; change as necessary.
fdr = 0.05

def p_adjust_bh(p):
	'''
	Return Benjamini-Hochberg adjusted p-values for a sequence of raw p-values.

	p: sequence of p-values (anything NumPy can convert to a float array).
	Returns a NumPy float array of adjusted p-values in the same order as the input.
	'''
	# np.asfarray was removed in NumPy 2.0; asarray with a float dtype is the equivalent call.
	p = np.asarray(p, dtype=float)
	# Indices that order the p-values from largest to smallest.
	by_descend = p.argsort()[::-1]
	# Inverse permutation, used to put the results back in input order.
	by_orig = by_descend.argsort()
	# Scaling factor n / rank, where the largest p-value has rank n and the smallest rank 1.
	steps = float(len(p)) / np.arange(len(p), 0, -1)
	# The running minimum keeps adjusted values from increasing as p decreases; cap at 1.
	q = np.minimum(1, np.minimum.accumulate(steps * p[by_descend]))
	return q[by_orig]

# Load the motif table: the header row names the motifs and every following row holds
# one element's per-motif values, which are kept here as unparsed strings.
di_motifs = {}
order_elements = []
with open(sys.argv[1], 'r') as f:
	motif_columns = f.readline().rstrip('\n').split('\t')[1:]
	# Only the text before the first underscore of each header field is used as the motif name.
	list_motifs = [column.split('_')[0] for column in motif_columns]
	num_motifs = len(list_motifs)
	for row in f:
		element_id, *motif_values = row.rstrip('\n').split('\t')
		# The element name is the first column truncated at the first ';'.
		element_name = element_id.split(';')[0]
		di_motifs[element_name] = motif_values
		# Remember file order so elements can later be visited in the same sequence.
		order_elements.append(element_name)

# Classify every element / annotation / motif combination into one of the four
# contingency-table cells. di_fishers_cat maps an element name to a list with one entry
# per annotation column, each entry being a list with one cell label per motif.

# Cell label for each (has motif, overlaps annotation) combination.
fishers_cat_labels = {
	(True, True): '+motif+annotation',
	(True, False): '+motif-annotation',
	(False, True): '-motif+annotation',
	(False, False): '-motif-annotation',
}

di_fishers_cat = {}
with open(sys.argv[2], 'r') as f:
	annotation_header = f.readline().rstrip('\n')
	# Annotation columns start at the seventh field of the header.
	list_annotations = annotation_header.split('\t')[6:]
	num_annotations = len(list_annotations)
	for line in f:
		fields = line.rstrip('\n').split('\t')
		# Fields 0, 1, 2 and 5 form the element name, which must match a key loaded from
		# the motif file (a missing name raises KeyError); fields 3 and 4 are unused here.
		name = '{}:{}-{}({})'.format(fields[0], fields[1], fields[2], fields[5])
		# Parse the motif counts once per element rather than once per annotation column.
		# A value of at least 1 means the element has the motif.
		has_motif = [int(count) >= 1 for count in di_motifs[name]]
		element_fishers_cat = []
		for element_annotation in fields[6:]:
			# Parse the annotation flag once per column rather than once per motif.
			# A value of at least 1 means the element overlaps the annotation.
			overlaps_annotation = int(element_annotation) >= 1
			element_fishers_cat.append(
				[fishers_cat_labels[(motif_present, overlaps_annotation)] for motif_present in has_motif]
			)
		di_fishers_cat[name] = element_fishers_cat

# Run Fisher's exact test for every annotation/motif pair, then write the pairs whose
# Benjamini-Hochberg adjusted p-value is below fdr.

# Position of each cell label inside a motif's four-entry count list:
# [+motif+annotation, +motif-annotation, -motif+annotation, -motif-annotation].
count_slot = {
	'+motif+annotation': 0,
	'+motif-annotation': 1,
	'-motif+annotation': 2,
	'-motif-annotation': 3,
}

with open(sys.argv[-1], 'w') as o:
	o.write('Annotation\tMotif\tOdds_ratio\tp-adjusted\t+Motif+Annotation\t+Motif-Annotation\t-Motif+Annotation\t-Motif-Annotation\n')
	fishers_tests = []
	list_pvals = []
	for i, annotation in enumerate(list_annotations):
		# One contingency table (flattened to four counts) per motif for this annotation.
		category_counts = [[0, 0, 0, 0] for _ in list_motifs]
		for name in order_elements:
			for j, counts in enumerate(category_counts):
				slot = count_slot.get(di_fishers_cat[name][i][j])
				# Labels outside the four known cells are ignored.
				if slot is not None:
					counts[slot] += 1
		# Rows of the 2x2 table are +motif / -motif; columns are +annotation / -annotation.
		# fisher_exact is called with its default alternative (two-sided per the SciPy
		# documentation), so reported pairs may have an odds ratio below 1.
		for motif_name, counts in zip(list_motifs, category_counts):
			odds, pval = scipy.stats.fisher_exact([counts[:2], counts[2:]])
			list_pvals.append(pval)
			fishers_tests.append([annotation, motif_name, str(odds)] + [str(count) for count in counts])
	# p-values from all annotations are pooled into a single Benjamini-Hochberg correction.
	list_padj = p_adjust_bh(list_pvals)
	for fishers_test, padj in zip(fishers_tests, list_padj):
		if padj < fdr:
			# The adjusted p-value goes between the odds ratio and the four counts.
			o.write('\t'.join(fishers_test[:3] + [str(padj)] + fishers_test[3:]) + '\n')

