#!/usr/bin/env python

"""
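Calculates the site occupancy of each supermatrix at a series of thresholds (0,5,10,15,..,100 percent)
Input is a directory of NEXUS (.nex) supermatrices, an output table path, an output directory and a write threshold
Output is a tab-separated table containing, per supermatrix, the number of sites whose occupancy meets each threshold,
plus a copy of each supermatrix reduced to the sites that meet the write threshold
    0 5 10 15 20 ... 100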
mat1
mat2
mat3
"""

import numpy as np

import Bio
from Bio import AlignIO
from Bio.Alphabet import IUPAC, Gapped

from collections import defaultdict

import csv
import sys

import os

"""
Get all sequence files in a directory
"""
def get_all_files(directory):
	"""
	Collect the NEXUS files found directly inside `directory`.

	Only entries whose name ends with ".nex" are kept. Each one is returned
	joined onto `directory`, in the order os.listdir reports them.
	"""
	nexus_paths = []
	for entry in os.listdir(directory):
		if not entry.endswith(".nex"):
			continue
		nexus_paths.append(os.path.join(directory, entry))
	return nexus_paths



"""
Reads an alignment file using Bio.AlignIO
Returns a numpy array of the binary sequences (1 if non gap, 0 if gap)
"""
def read_sequence_file(filename,_format,verbose):
	"""
	Load one alignment and encode it as a presence/absence matrix.

	filename: path of the alignment file
	_format: Bio.AlignIO format name; "nexus" is read with a gapped protein alphabet
	verbose: print progress messages when true

	Returns a 2-D uint8 numpy array with one row per sequence and one column
	per site, holding 1 for a residue and 0 for a gap character
	('X', '?' or '-'). An alignment without sequences gives an empty array.
	"""
	GAP_CHARS = frozenset(['X', '?', '-'])
	if verbose:
		print("Loading sequence")
	# The alignment is fully parsed by AlignIO.read, so the handle can be
	# closed as soon as the call returns.
	with open(filename) as in_fp:
		if _format=="nexus":
			alignment = AlignIO.read(in_fp, _format,alphabet=Gapped(IUPAC.protein))
		else:
			alignment = AlignIO.read(in_fp, _format)
	if verbose:
		print("Sequence Loaded")
		print("Checking for gap characters")
	# Collect the rows first and build the matrix once; appending to a numpy
	# array row by row copies the whole matrix on every iteration.
	binary_rows = []
	for seq_num, sequence in enumerate(alignment, start=1):
		if verbose:
			print("Loading sequence {}".format(seq_num))
		# Create binary array for non-gap (1) / gap (0).
		binary_rows.append(np.array([(0 if x in GAP_CHARS else 1) for x in sequence.seq],dtype=np.uint8))
	if not binary_rows:
		return np.array([])
	return np.array(binary_rows)


"""
Create numpy array of occupancy fractions (0 to 1), one per site
Go site by site through sequences_binary array - loop through columns
Count number of '1' out of total number of sequences
"""
def find_site_occupancy(sequences_binary,verbose):
	"""
	Compute, for every site (column), the fraction of sequences that have a residue.

	sequences_binary: array of 0/1 values, one row per sequence and one
		column per site; a 1-D input is treated as a single sequence
	verbose: print a progress message when true

	Returns a 1-D numpy array of floats in [0, 1], one entry per site.
	"""
	if verbose:
		print("Calculating site occupancy")
	# Plain ndarrays are used instead of the discouraged np.matrix class;
	# atleast_2d keeps the "rows are sequences" layout for 1-D and empty input.
	presence = np.atleast_2d(np.asarray(sequences_binary))
	# Number of non-gap characters per column, divided by the number of rows.
	sites_sum = presence.sum(axis=0)
	return sites_sum/presence.shape[0]


"""
Threshold is given as a percentage (0-100) and converted to a fraction before comparing
Goes site by site through site_occ to see whether the site's occupancy is at or above threshold
If so, then the site occupancy (fraction of non-gaps) is at least the given percentage
We count site if this is case
"""
def count_sites_above_threshold(site_occ,threshold,verbose):
	"""
	Count the sites whose occupancy is at least `threshold` percent.

	site_occ: iterable of per-site occupancy fractions
	threshold: cut-off expressed as a percentage (e.g. 50 for 50%)
	verbose: print a progress message when true

	Returns the number of entries of site_occ that are >= threshold/100.
	"""
	if verbose:
		print("Recording site occupancy for size {}%".format(threshold))
	# Occupancies are fractions, so bring the percentage onto the same scale.
	cutoff = threshold/100
	return sum(1 for occupancy in site_occ if occupancy >= cutoff)


"""
Writes output to file
"""
def write_file(compiled_dict,output,row,verbose):
	"""
	Write the per-file site counts as a tab-separated table.

	compiled_dict: mapping of file name -> {threshold: number of sites}
	output: path of the table to create (overwritten if it exists)
	row: header row, written first and unchanged
	verbose: print a progress message when true

	Each following row holds the file name and then its counts ordered by
	ascending threshold.
	"""
	if verbose:
		print("Writing output to file {}".format(output))
	# newline='' hands line-ending control to the csv module, as its
	# documentation asks for files passed to csv.writer.
	with open(output,'w',newline='') as out_fp:
		tsv_writer = csv.writer(out_fp,delimiter="\t")
		tsv_writer.writerow(row)
		for filename, counts in compiled_dict.items():
			tsv_writer.writerow([filename] + [counts[threshold] for threshold in sorted(counts)])


def gather_results(site_occ,verbose):
	"""
	Tabulate how many sites meet each occupancy threshold.

	site_occ: per-site occupancy fractions
	verbose: print progress messages when true

	Returns a defaultdict mapping every threshold 0, 5, 10, ..., 100 to the
	count reported by count_sites_above_threshold for that threshold.
	"""
	if verbose:
		print("Gathering results")
	levels = range(0,101,5)
	counts_by_level = defaultdict(list)
	counts_by_level.update(
		(level, count_sites_above_threshold(site_occ,level,verbose))
		for level in levels
	)
	return counts_by_level





def write_reduced_supermatrix(filename,_format,site_occ,threshold,out_mat,verbose):
	"""
	Write a copy of an alignment restricted to sufficiently occupied sites.

	filename: path of the alignment to reduce
	_format: Bio.AlignIO format name; "nexus" is read with a gapped protein alphabet
	site_occ: per-site occupancy fractions, indexed like the alignment columns
	threshold: minimum occupancy as a whole-number percentage (int or numeric string)
	out_mat: output path without extension
	verbose: print progress messages when true

	Columns whose occupancy is below threshold/100 are removed from every
	sequence. The result is written to out_mat + ".phy" (phylip),
	out_mat + ".fas" (fasta) and out_mat + ".nex" (nexus).
	"""
	if verbose:
		print("Reducing matrix to site occupancy: {}".format(threshold))
	threshold = int(threshold)/100
	#Find sites which do not meet occupancy criteria
	dropped_sites = {i for i, occupancy in enumerate(site_occ) if occupancy < threshold}

	# The alignment is fully parsed by AlignIO.read, so the handle can be
	# closed as soon as the call returns.
	with open(filename) as in_fp:
		if _format=="nexus":
			alignment = AlignIO.read(in_fp, _format,alphabet=Gapped(IUPAC.protein))
		else:
			alignment = AlignIO.read(in_fp, _format)

	for seq_num, sequence in enumerate(alignment, start=1):
		if verbose:
			print("Loading sequence {}".format(seq_num))
		# Keep only the characters at the retained column positions.
		reduced_seq = ''.join(
			residue for i, residue in enumerate(str(sequence.seq))
			if i not in dropped_sites
		)
		sequence.seq = Bio.Seq.Seq(reduced_seq,alphabet=Gapped(IUPAC.protein))

	# Same three outputs, in the same order: phylip, fasta, nexus.
	for extension, out_format in ((".phy",'phylip'),(".fas",'fasta'),(".nex",'nexus')):
		with open(out_mat+extension,'w') as out_fp:
			AlignIO.write(alignment,out_fp,out_format)


def main(directory,output,out_dir,write_threshold,verbose=True):
	"""
	Run the site-occupancy analysis on every ".nex" file in a directory.

	For each file: load it as a binary presence matrix, compute the per-site
	occupancy, count the sites meeting each threshold, and write a reduced
	copy of the supermatrix (sites meeting write_threshold) into out_dir.
	Finally the per-file counts are written as one table to `output`.

	directory: folder scanned for ".nex" supermatrices
	output: path of the summary table
	out_dir: folder receiving the reduced supermatrices
	write_threshold: occupancy percentage used for the reduced supermatrices
	verbose: print progress messages when true
	"""
	_format = "nexus"

	file_list = get_all_files(directory)
	compiled_dict = defaultdict(dict)
	# Header row of the table. Pre-set to the empty corner cell so that a
	# directory without any ".nex" file still yields a (header-only) table
	# instead of failing on an unassigned name.
	row = [""]
	for filename in file_list:
		#Load sequences
		sequences_binary = read_sequence_file(filename,_format,verbose)
		#Calculate site occupancy as fraction of non-gap characters per site
		site_occ = find_site_occupancy(sequences_binary,verbose)

		results_dict = gather_results(site_occ,verbose)
		if verbose:
			for threshold in sorted(results_dict):
				print("Threshold: {}% \t Num Sites: {}".format(threshold,results_dict[threshold]))

		# Output name: input basename up to its first ".", followed by the
		# occupancy suffix; the format extensions are appended by the writer.
		out_mat = os.path.basename(filename).split(".")[0]
		out_mat += "siteoccupancy_{}".format(str(write_threshold))
		out_mat = os.path.join(out_dir,out_mat)
		row = [""] + sorted(results_dict)
		compiled_dict[filename] = results_dict
		write_reduced_supermatrix(filename,_format,site_occ,write_threshold,out_mat,verbose)

	write_file(compiled_dict,output,row,verbose)




if __name__ == '__main__':
	# Command line: <directory> <output> <out_dir> <write_threshold>
	# Report missing arguments with a usage line rather than an IndexError.
	if len(sys.argv) < 5:
		sys.exit("Usage: {} <directory> <output> <out_dir> <write_threshold>".format(sys.argv[0]))
	directory = sys.argv[1]
	output = sys.argv[2]
	out_dir = sys.argv[3]
	write_threshold = sys.argv[4]
	verbose=False
	main(directory,output,out_dir,write_threshold,verbose)








































































