#!/usr/bin/env python

import subprocess
import argparse
from termcolor import colored
from subprocess import Popen, PIPE
import pandas as pd

# Command-line interface.
# NOTE: the parser is built and parse_args() is called at module level, so the
# options are parsed as soon as this file is executed or imported.
parser = argparse.ArgumentParser(description="A tool helps to find the potential modification sites")
# Required inputs: the reference and the reads that are mapped against it.
parser.add_argument("--ref", type=str, metavar="", required=True, help="Input reference (FASTA)")
parser.add_argument("--read", type=str, metavar="", required=True, help="Input reads (FASTQ)")
# Kept as a string because the value is placed directly into the
# minimap2/samtools argument lists built in reads_preprocess().
parser.add_argument("--cpu", type=str, metavar="", required=False, help="CPU number (default:10)", default="10")
# Not validated here (no `choices`); the __main__ block checks the value.
parser.add_argument("--method", type=str, metavar="", required=False, help="The strategy for detecting potential modification sites, cutoff or top (default: cutoff)", default="cutoff")
# Only read by the __main__ block when --method is "cutoff".
parser.add_argument("--cut", type=float, metavar="", required=False, help="Cutoff value [0, 1] (default:0.35)", default=0.35)
# Only read by the __main__ block when --method is "top".
parser.add_argument("--num", type=int, metavar="", required=False, help="Top N sites with the Difference index, sorted from the largest to the smallest (default:2000)", default=2000)
# Compared against the 4th pileup column in get_direction_info().
parser.add_argument("--min_depth", type=int, metavar="", required=False, help="The minimum depth (default:50)", default=50)
# Per-strand threshold applied in calculate_distance().
parser.add_argument("--min_depth_strand", type=int, metavar="", required=False, help="The minimum depth for forward strand and reverse strand (default:25)", default=25)
# Module-level namespace read by the functions below.
args = parser.parse_args()

def cmd_shell(cammands, string):
    """Run a whitespace-separated command line and report failures in red.

    The command's standard output is discarded. Standard error is captured so
    that, when the command exits with a non-zero status, its error text can be
    printed (the text is not shown when the command succeeds).

    Args:
        cammands: Command line as one string. It is split on whitespace; no
            shell quoting rules are applied, so individual arguments must not
            contain spaces.
        string: Unused; kept so existing callers keep working.

    Returns:
        None.
    """
    process = Popen(
        cammands.split(),
        stdout=subprocess.DEVNULL,
        stderr=PIPE,  # without a pipe communicate() would only return None
        universal_newlines=True,
    )
    # communicate() drains the stderr pipe and waits for the process to exit,
    # so a separate wait() call is not needed.
    _, err = process.communicate()

    if process.returncode != 0:
        print(colored(err, "red"))

def reads_preprocess():
    """Map the reads to the reference and write a per-position pileup.

    Runs the following external commands in order, stopping at the first one
    that fails: ``samtools faidx``, ``minimap2`` (``map-ont`` preset),
    ``samtools view``, ``samtools sort``, ``rm`` of the intermediate
    ``mapping.sam``/``mapping.bam``, and ``samtools mpileup``.

    Inputs come from the module-level ``args`` (``ref``, ``read``, ``cpu``).
    The results are ``mapping.sort.bam`` and ``mapping.mpileup.txt`` in the
    current working directory.

    Raises:
        Exception: If any command exits with a non-zero status. The original
            ``subprocess.CalledProcessError`` is attached as the cause.
    """
    cpu = args.cpu
    steps = [
        ["samtools", "faidx", args.ref],
        ["minimap2", "-ax", "map-ont", "-o", "mapping.sam", "-t", cpu, args.ref, args.read],
        ["samtools", "view", "-bS", "-F4", "-@", cpu, "-o", "mapping.bam", "mapping.sam"],
        ["samtools", "sort", "-@", cpu, "-o", "mapping.sort.bam", "mapping.bam"],
        ["rm", "mapping.sam", "mapping.bam"],
        # NOTE(review): --no-output-ins and --no-output-del are each given
        # twice on purpose. Per the samtools mpileup manual, repeating the
        # flag presumably also drops the "+len"/"-len" markers, which keeps
        # the base column free of indel text -- confirm against the installed
        # samtools version before changing.
        [
            "samtools", "mpileup", "-q", "30",
            "--no-output-ins", "--no-output-ins",
            "--no-output-del", "--no-output-del",
            "--no-output-ends",
            "-o", "mapping.mpileup.txt", "mapping.sort.bam",
        ],
    ]

    for cmd in steps:
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            # The child's output is not captured (it goes straight to the
            # terminal), so e.output is always None; report the command and
            # its exit status instead.
            print("Command failed with exit status {}: {}".format(e.returncode, " ".join(cmd)))
            raise Exception("Data processing failed") from e

def calculate_distance(chrom, pos, string, cutoff, min_depth_strand=None):
    """Score the strand difference of the base composition at one position.

    Upper-case ``A/T/G/C`` characters in ``string`` are counted as the forward
    strand and lower-case ``a/t/g/c`` as the reverse strand; every other
    character is ignored. For each base the absolute difference between its
    forward and reverse frequency is rounded to three decimals; the
    difference index is half of the sum of the four values, again rounded to
    three decimals.

    Args:
        chrom: Sequence name, written to the first output column.
        pos: 1-based position (int or numeric string).
        string: Base column of one pileup line.
        cutoff: Minimum difference index for the site to be reported.
        min_depth_strand: Minimum number of counted bases required on each
            strand. Defaults to ``args.min_depth_strand`` when omitted.

    Returns:
        A tab-separated line without trailing newline holding
        ``chrom, pos-1, pos, index, dif_A, dif_T, dif_G, dif_C`` and the
        comma-joined counts ``A,T,G,C,a,t,g,c``; ``None`` when either strand
        is too shallow or the index is below ``cutoff``.
    """
    if min_depth_strand is None:
        min_depth_strand = args.min_depth_strand

    # str.count scans in C, which is much cheaper than tallying every
    # character of a deep pileup column in a Python loop.
    counts = {base: string.count(base) for base in "ATGCatgc"}
    forward_depth = sum(counts[base] for base in "ATGC")
    reverse_depth = sum(counts[base] for base in "atgc")

    if forward_depth < min_depth_strand or reverse_depth < min_depth_strand:
        return None
    if forward_depth == 0 or reverse_depth == 0:
        # Only reachable with a threshold <= 0; frequencies are undefined.
        return None

    def strand_gap(base):
        forward_freq = counts[base] / forward_depth
        reverse_freq = counts[base.lower()] / reverse_depth
        return float(format(abs(forward_freq - reverse_freq), '.3f'))

    abs_a, abs_t, abs_g, abs_c = (strand_gap(base) for base in "ATGC")
    # Summation order (A, T, C, G) is kept so float results stay bit-identical.
    manhattan_distance = abs_a + abs_t + abs_c + abs_g
    difference_index = float(format(manhattan_distance / 2, '.3f'))

    if difference_index < float(cutoff):
        return None

    fields = [
        str(chrom),
        str(int(pos) - 1),  # 0-based start for the BED-style interval
        str(pos),
        str(difference_index),
        str(abs_a),
        str(abs_t),
        str(abs_g),
        str(abs_c),
        ",".join(str(counts[base]) for base in "ATGCatgc"),
    ]
    return "\t".join(fields)

def get_direction_info(cutoff, output_name):
    """Scan ``mapping.mpileup.txt`` and write the sites that pass ``cutoff``.

    Each pileup line is split on tabs. Lines whose 4th column is at least
    ``args.min_depth`` are scored with ``calculate_distance`` using columns
    1, 2 and 5 (sequence name, position, bases). Every non-empty result is
    written as one line below a fixed header.

    Args:
        cutoff: Minimum difference index; converted with ``float``.
        output_name: Path of the tab-separated table to create (overwritten).

    Returns:
        None.
    """
    header = "Chr\tStart\tEnd\tDifference_index\tDif_A\tDif_T\tDif_G\tDif_C\tA,T,G,C,a,t,g,c"
    threshold = float(cutoff)
    min_depth = args.min_depth

    # The with-blocks close both files; no explicit close() is required.
    with open(output_name, "w") as fo:
        fo.write(header + "\n")
        with open("mapping.mpileup.txt") as ff:
            for line in ff:
                if not line.strip():
                    # A blank line has no columns to index.
                    continue
                fields = line.rstrip("\n").split("\t")
                if int(fields[3]) < min_depth:
                    continue

                # Score once and reuse the result instead of recomputing it
                # for the write.
                record = calculate_distance(fields[0], fields[1], fields[4], threshold)
                if record:
                    fo.write(record + "\n")

def get_top_N_sites(top_number):
    """Keep the rows with the largest ``Difference_index`` values.

    Reads the tab-separated ``difference_index_table.txt`` and writes the
    ``top_number`` highest-scoring rows, header included, to
    ``potential_modification_site.txt``.

    Args:
        top_number: Number of rows to keep; converted with ``int``.
    """
    table = pd.read_csv("difference_index_table.txt", sep="\t")
    best_rows = table.nlargest(n=int(top_number), columns="Difference_index")
    best_rows.to_csv("potential_modification_site.txt", sep="\t", index=False)


def get_potential_modification_bed():
    """Export the site table as a three-column, header-less BED file.

    Reads the tab-separated ``potential_modification_site.txt`` and writes its
    ``Chr``, ``Start`` and ``End`` columns to
    ``potential_modification_site.bed``.
    """
    bed_columns = ["Chr", "Start", "End"]
    sites = pd.read_csv("potential_modification_site.txt", sep="\t")
    sites[bed_columns].to_csv(
        "potential_modification_site.bed",
        sep="\t",
        header=False,
        index=False,
    )

def get_enrich_read(flank=10):
    """Write a BED file of windows extended around every reported site.

    Reads the tab-separated ``potential_modification_site.txt``, widens each
    interval by ``flank`` bases on both sides and writes
    ``Chr, Start - flank, End + flank`` (no header) to ``enrichment.bed``.
    Windows that would start before the beginning of the sequence are
    dropped.

    Args:
        flank: Number of bases added on each side of the site (default 10).
    """
    sites = pd.read_csv("potential_modification_site.txt", sep="\t")
    windows = sites.assign(
        sequence_start=sites["Start"] - flank,
        sequence_end=sites["End"] + flank,
    )[["Chr", "sequence_start", "sequence_end"]]

    # BED starts are 0-based, so a window starting at 0 is complete and valid;
    # only negative starts fall off the sequence.
    in_range = windows[windows["sequence_start"] >= 0]
    in_range.to_csv("enrichment.bed", index=False, sep="\t", header=False)

def get_enriched_motif_info():
    """Extract the reference sequence of every window in ``enrichment.bed``.

    Runs ``bedtools getfasta`` with ``args.ref`` as the FASTA input and writes
    the result to ``enrichment.fa``.

    Raises:
        Exception: If bedtools exits with a non-zero status. The original
            ``subprocess.CalledProcessError`` is attached as the cause.
    """
    cmd = ["bedtools", "getfasta", "-bed", "enrichment.bed", "-fi", args.ref, "-fo", "enrichment.fa"]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        # Output is not captured, so e.output is always None; report the
        # command and its exit status instead.
        print("Command failed with exit status {}: {}".format(e.returncode, " ".join(cmd)))
        raise Exception("Getting enrichment sequences failed.") from e

if __name__ == "__main__":
    # Reject an unknown method before any expensive step runs; otherwise the
    # steps below would fail on a missing site table or silently reuse a
    # stale one from an earlier run.
    if args.method not in ("top", "cutoff"):
        print("Please recheck your method!!!")
        raise SystemExit(1)

    # Mapping and pileup are identical for both methods.
    reads_preprocess()

    if args.method == "top":
        # Score every site (cutoff 0), then keep the N highest.
        get_direction_info(0, "difference_index_table.txt")
        get_top_N_sites(args.num)
    else:
        # "cutoff": write only the sites at or above the threshold.
        get_direction_info(args.cut, "potential_modification_site.txt")

    get_potential_modification_bed()
    get_enrich_read()
    get_enriched_motif_info()
