#!/bin/env python

#This script associates phased, het GATK SNPs with the correct phased deletion event flanking region.
#It segregates in-phase and out-of-phase SNPs into two corresponding files.
#The first argument is the regions file. The second argument is the file containing phased GATK variants.
#The files should have been processed as mentioned in the README.
#The third and fourth arguments are the output VCF files for in-phase and out-of-phase SNPs, respectively.

from __future__ import print_function

import os
import subprocess
import sys

# Header lines written verbatim at the top of both output files.
VCF_HEADER = (
    "##fileformat=VCFv4.1",
    '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">',
    '##INFO=<ID=DELGT,Number=1,Type=String,Description="Deletion phased genotype">',
    "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA19240",
)

# One output record: CHROM, POS, REF, ALT, FILTER, then RP (SNP position
# relative to the region edge), RGN (region string), DELGT (the region's
# phased genotype) and finally the SNP genotype. ID and QUAL are always ".".
RECORD_FORMAT = "{0}\t{1}\t.\t{2}\t{3}\t.\t{4}\tRP={5};RGN={6};DELGT={7}\tGT\t{8}"


def _parse_bounds(region):
    """Return the (start, end) integers of a "chrom:start-end" region string.

    Raises ValueError when either bound is not an integer. A region lacking
    the ":" or "-" separator raises IndexError.
    """
    bounds = region.split(":")[1].split("-")
    return int(bounds[0]), int(bounds[1])


def _query_tabix(vcf_path, region):
    """Return the lines that ``tabix vcf_path region`` prints, as a list.

    The command is run from an argument list, so neither the path nor the
    region string is interpreted by a shell, and the output is read from a
    pipe instead of a shared temporary file in the working directory.
    A non-zero tabix exit status is reported on stderr; whatever output was
    produced is still returned.
    """
    proc = subprocess.Popen(
        ["tabix", vcf_path, region],
        stdout=subprocess.PIPE,
        universal_newlines=True,  # text output on both Python 2 and 3
    )
    output, _ = proc.communicate()
    if proc.returncode != 0:
        print(
            "tabix exited with status {0} for region {1}".format(
                proc.returncode, region
            ),
            file=sys.stderr,
        )
    return output.splitlines()


def main(argv=None, fetch=_query_tabix):
    """Split phased SNPs into in-phase and out-of-phase VCF files.

    argv  -- [regions file, phased variants file, in-phase output,
             out-of-phase output]; defaults to sys.argv[1:].
    fetch -- callable(vcf_path, region) returning the VCF record lines for a
             region; defaults to querying tabix.

    Each non-blank regions line is "<chrom:start-end> <genotype>". Regions
    alternate between left and right flank, starting with left: RP is the SNP
    position minus the region end for a left flank and minus the region start
    for a right flank. Only single-base REF/ALT records inside the region are
    kept, and a coordinate already seen in any earlier record is skipped.

    A SNP genotype of exactly "0|1" is counted as maternal; any other value is
    treated as "1|0" (paternal). A SNP is in phase when its side matches the
    region genotype: paternal when the region genotype starts with "1",
    maternal otherwise.

    Prints "snpCount is N" and returns the per-class tallies as a dict with
    keys pInPhs, pOutPhs, mInPhs and mOutPhs.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) < 4:
        sys.exit("usage: REGIONS PHASED_VCF_GZ IN_PHASE_VCF OUT_PHASE_VCF")
    regions_path, vcf_path, in_phase_path, out_phase_path = argv[:4]

    tallies = dict(pInPhs=0, pOutPhs=0, mInPhs=0, mOutPhs=0)
    seen = set()  # "chrom pos" of every SNP encountered, to drop repeats
    is_left_flank = True

    # NOTE: the outputs are opened in append mode, so re-running against
    # existing files adds a second header block after the old records.
    with open(in_phase_path, "a") as in_phase, open(out_phase_path, "a") as out_phase:
        for handle in (in_phase, out_phase):
            for header_line in VCF_HEADER:
                print(header_line, file=handle)

        with open(regions_path) as regions:
            for line in regions:
                fields = line.split()
                if not fields:
                    continue  # a blank line is not a region; keep alternation
                region, del_gt = fields[0], fields[1]
                paternal_del = del_gt[0] == "1"  # "1|0" form

                try:
                    start, end = _parse_bounds(region)
                except ValueError:
                    # Report the offending line and skip it rather than
                    # reusing the previous region's bounds. The line still
                    # occupies a left/right slot, so the flank flips.
                    print(line)
                    is_left_flank = not is_left_flank
                    continue

                anchor = end if is_left_flank else start
                for record in fetch(vcf_path, region):
                    cols = record.split()
                    chrom, ref, alt = cols[0], cols[3], cols[4]
                    pos = int(cols[1])
                    if len(ref) != 1 or len(alt) != 1:
                        continue  # not a SNP
                    coord = chrom + " " + cols[1]
                    if coord in seen:
                        continue
                    seen.add(coord)
                    if not start <= pos <= end:
                        continue

                    genotype = cols[9]
                    maternal = genotype == "0|1"
                    in_phase_snp = (not paternal_del) if maternal else paternal_del
                    key = ("m" if maternal else "p") + (
                        "InPhs" if in_phase_snp else "OutPhs"
                    )
                    tallies[key] += 1
                    print(
                        RECORD_FORMAT.format(
                            chrom, pos, ref, alt, cols[6],
                            pos - anchor, region, del_gt, genotype,
                        ),
                        file=in_phase if in_phase_snp else out_phase,
                    )

                is_left_flank = not is_left_flank

    # Every kept SNP increments exactly one tally, so the sum is the total.
    print("snpCount is {0}".format(sum(tallies.values())))
    return tallies


if __name__ == "__main__":
    main()
