import sys
import argparse
from Bio import SeqIO

def parseArgument():
    """Parse and validate the command-line arguments.

    Returns:
        The argparse namespace holding every option defined below.

    Exits through ``parser.error`` (SystemExit) when the options needed by
    the selected mode are missing: --DISTANCE without --usePeak, or
    --peakStartCol/--peakEndCol with --usePeak.
    """
    parser = argparse.ArgumentParser(
        description="For each SNP, get the sequence of its location +/- distance with the binding and non-binding alleles and record in fasta format")
    parser.add_argument("--SNPFILENAME", required=True,
                        help='File with SNPs, their locations, and their alleles')
    parser.add_argument("--GENOMEFILENAME", required=True,
                        help='File with genome that SNPs are from')
    parser.add_argument("--DISTANCE", type=int, required=False,
                        help='Distance around SNPs to obtain')
    parser.add_argument("--usePeak", action="store_true", required=False,
                        help='Instead of using the distances around SNPs to define the sequences, use the peaks containing the SNPs')
    parser.add_argument("--BINDALLELESEQFILENAME", required=True,
                        help='Output file with sequence surrounding first allele')
    parser.add_argument("--NONBINDALLELEFILENAME", required=True,
                        help='Output file with sequence surrounding second allele')
    parser.add_argument("--chromCol", type=int, required=False, default=0,
                        help='Column number with the chromosome information, zero-indexed')
    parser.add_argument("--noChr", action="store_true", required=False,
                        help='The chromosome names have no chr')
    parser.add_argument("--positionCol", type=int, required=False, default=1,
                        help='Column number with the position information, zero-indexed')
    # The help text used to repeat the --positionCol wording; this column holds the allele
    parser.add_argument("--alleleCol", type=int, required=False, default=2,
                        help='Column number with the first allele, zero-indexed; the second allele is read from the next column unless --onlyAlternateAllele is set')
    parser.add_argument("--onlyAlternateAllele", action="store_true", required=False,
                        help='Only the alternate allele is present in the file')
    parser.add_argument("--SNPNameCol", type=int, required=False, default=None,
                        help='Column number with the SNP name, zero-indexed')
    parser.add_argument("--peakStartCol", type=int, required=False, default=None,
                        help='Column number with the peak start information, zero-indexed')
    parser.add_argument("--peakEndCol", type=int, required=False, default=None,
                        help='Column number with the peak end information, zero-indexed')
    options = parser.parse_args()
    # Fail early with a readable message instead of a TypeError on None later on
    if options.usePeak:
        if options.peakStartCol is None or options.peakEndCol is None:
            parser.error("--usePeak requires --peakStartCol and --peakEndCol")
    elif options.DISTANCE is None:
        parser.error("--DISTANCE is required unless --usePeak is set")
    return options

def getSNPLocation(SNPLine, options):
    """Parse one whitespace-separated line of the SNP file.

    Args:
        SNPLine: One line of the SNP file.
        options: Namespace providing chromCol, positionCol, alleleCol,
            SNPNameCol, onlyAlternateAllele, usePeak, peakStartCol,
            peakEndCol and noChr.

    Returns:
        A four-item list ``[(chromosome, position), alleles, name, peak]``:
        position is an int; alleles holds one allele when
        options.onlyAlternateAllele is set and otherwise the alleles from
        columns alleleCol and alleleCol + 1; peak is an (int, int) tuple when
        options.usePeak is set and None otherwise. A blank line yields the
        placeholder ``[("", 0), ("", ""), "", None]``.
    """
    if not SNPLine.strip():
        # Blank line: return a placeholder with the same four-item shape as a
        # real record so that callers can always unpack four values
        return [("", 0), ("", ""), "", None]
    SNPLineElements = SNPLine.split()
    chromosome = SNPLineElements[options.chromCol]
    positionText = SNPLineElements[options.positionCol]
    firstAllele = SNPLineElements[options.alleleCol]
    if options.SNPNameCol is not None:
        # Get the SNP name
        SNPName = SNPLineElements[options.SNPNameCol]
    else:
        # Make the SNP name from the chromosome (as written in the file), position, and first allele
        SNPName = chromosome + ":" + positionText + "_" + firstAllele
    if options.onlyAlternateAllele:
        SNPAlleles = [firstAllele]
    else:
        # Both alleles are in the file, in adjacent columns
        SNPAlleles = [firstAllele, SNPLineElements[options.alleleCol + 1]]
    peakLocation = None
    if options.usePeak:
        peakLocation = (int(SNPLineElements[options.peakStartCol]), int(SNPLineElements[options.peakEndCol]))
    if options.noChr:
        # Add chr to the beginning of the chromosome name
        chromosome = "chr" + chromosome
    return [(chromosome, int(positionText)), SNPAlleles, SNPName, peakLocation]
	
def makeSNPLocationDict(options):
    """Group the SNPs in options.SNPFILENAME by chromosome.

    Args:
        options: Namespace providing SNPFILENAME plus the column options
            read by getSNPLocation.

    Returns:
        A dict mapping each chromosome name to a list of
        ``[SNPLocation, SNPAlleles, SNPName, peakLocation]`` records, in file
        order.
    """
    chromToSNPsDict = {}
    # The with-block closes the file even if a line fails to parse
    with open(options.SNPFILENAME) as SNPFile:
        for line in SNPFile:
            strippedLine = line.strip()
            if not strippedLine:
                # Blank lines carry no SNP; skip them rather than filing a
                # placeholder record under an empty chromosome name
                continue
            [SNPLocation, SNPAlleles, SNPName, peakLocation] = getSNPLocation(strippedLine, options)
            # Create the chromosome's list on first sight, then append the SNP
            chromToSNPsDict.setdefault(SNPLocation[0], []).append(
                [SNPLocation, SNPAlleles, SNPName, peakLocation])
    return chromToSNPsDict

def _substituteAllele(seqChars, alleleIndex, allele):
    """Write allele into seqChars (a list of characters) at alleleIndex, in place.

    The bases after the SNP base are first shifted right by len(allele) - 1
    and the overflow at the end of the window is dropped, so a multi-base
    allele does not change the window length as long as it fits inside it.
    """
    seqLength = len(seqChars)
    seqChars[alleleIndex + len(allele):seqLength] = \
        seqChars[alleleIndex + 1:seqLength - len(allele) + 1]
    # Assigning a string to a list slice stores one character per element
    seqChars[alleleIndex:alleleIndex + len(allele)] = allele


def _referenceMatchesAnAllele(seqChars, alleleIndex, alleles):
    """Return True if the upper-cased bases at alleleIndex equal any of alleles."""
    return any("".join(seqChars[alleleIndex:alleleIndex + len(allele)]).upper() == allele
               for allele in alleles)


def getSequencesSurroundingSNPs(options):
    """Write the sequence around each SNP, once per allele, in fasta format.

    For every SNP whose chromosome is present in options.GENOMEFILENAME, one
    record is written to options.BINDALLELESEQFILENAME and one to
    options.NONBINDALLELEFILENAME. The window is the SNP position +/-
    options.DISTANCE, or the SNP's peak when options.usePeak is set. With
    both alleles in the SNP file, the first allele goes into the binding
    sequence and the second into the non-binding sequence; with
    options.onlyAlternateAllele, the binding sequence is the unmodified
    reference and the listed allele goes into the non-binding sequence.

    Args:
        options: Namespace as returned by parseArgument.

    Raises:
        ValueError: If both alleles are given and the reference sequence at
            the SNP matches neither of them.
    """
    chromToSNPsDict = makeSNPLocationDict(options)
    # The with-block closes both outputs even if a SNP fails validation
    with open(options.BINDALLELESEQFILENAME, 'w+') as bindAlleleSeqFile, \
            open(options.NONBINDALLELEFILENAME, 'w+') as nonBindAlleleSeqFile:
        for seqRecord in SeqIO.parse(options.GENOMEFILENAME, "fasta"):
            # Iterate through the chromosomes and get the sequences surrounding each SNP in each chromosome
            if seqRecord.id not in chromToSNPsDict:
                # The current chromosome is not in the dictionary, so skip it
                continue
            for [SNPLocation, SNPAlleles, SNPName, peakLocation] in chromToSNPsDict[seqRecord.id]:
                # Iterate through all SNPs on the current chromosome
                chromosome, position = SNPLocation
                if options.usePeak:
                    # peakLocation is only populated in peak mode, so it must not be touched otherwise
                    if position == peakLocation[0]:
                        # The SNP seems to be in the base before the peak, suggesting that there is an annotation error
                        print("Potential annotation error for SNP " + SNPName + " in peak " + str(peakLocation[0]))
                        continue
                    windowStart, windowEnd = peakLocation
                    alleleIndex = position - windowStart - 1
                    # Put the peak coordinates in the fasta header
                    header = ">" + chromosome + ":" + str(windowStart) + "-" + str(windowEnd)
                else:
                    # Use the same distance around each allele
                    # NOTE(review): a SNP closer than DISTANCE to the start of the
                    # chromosome gives a negative slice start, which Python counts
                    # from the end of the sequence -- confirm such SNPs cannot occur
                    windowStart = position - options.DISTANCE
                    windowEnd = position + options.DISTANCE
                    alleleIndex = options.DISTANCE - 1
                    # Use the coordinates of the sequence in the fasta entry in the header
                    header = ">" + chromosome + ":" + str(windowStart + 1) + "-" + str(windowEnd)
                referenceSeq = str(seqRecord.seq[windowStart:windowEnd])
                bindAlleleSeq = list(referenceSeq)
                nonBindAlleleSeq = list(referenceSeq)
                if not options.onlyAlternateAllele:
                    # Both alleles are present; one of them should be the reference
                    if not _referenceMatchesAnAllele(bindAlleleSeq, alleleIndex, SNPAlleles[:2]):
                        raise ValueError("Reference sequence at SNP " + str(SNPName) + " (" + chromosome + ":" +
                                         str(position) + ") matches neither allele: " + ", ".join(SNPAlleles[:2]))
                    _substituteAllele(bindAlleleSeq, alleleIndex, SNPAlleles[0])
                    _substituteAllele(nonBindAlleleSeq, alleleIndex, SNPAlleles[1])
                else:
                    # Assume that the binding allele is the reference allele
                    _substituteAllele(nonBindAlleleSeq, alleleIndex, SNPAlleles[0])
                if SNPName is not None:
                    # Add a suffix to each header with the SNP name
                    header = header + "_" + SNPName
                bindAlleleSeqFile.write(header + "\n")
                nonBindAlleleSeqFile.write(header + "\n")
                bindAlleleSeqFile.write("".join(bindAlleleSeq).upper() + "\n")
                nonBindAlleleSeqFile.write("".join(nonBindAlleleSeq).upper() + "\n")

if __name__ == "__main__":
    # Script entry point: read the command-line options, then write the
    # binding and non-binding allele fasta files
    options = parseArgument()
    getSequencesSurroundingSNPs(options)
