#!/usr/bin/env python3

###############################
#
# Copyright Stanford University 2016
# Author: John Bell
#
# This script converts a VCF file to "phased basic" format,
# which is useful for determining overlapping barcodes
#
# INPUT:  vcf file with simplified GT:BX:PS format
# OUTPUT:  file with barcode information split up, along with
# genotype and phase id info in separate fields, to simplify processing
#  in the script which converts SNVs to blocks
# 
# ----------- Notes: 
#
# the barcode listing  (BX in format) is according to ref + 
# alt, in the same way as the allele counts in AD.  So the barcodes 
# listing placement is determined by the haplotype value.  If 
# 1|1, leave in same order as in vcf.  1/0:  treat as 1|0 in 
# terms of formatting results.  leave out 0/0 or 0|0 "calls"
#
# ----------- End of notes
#
################################

import os
import re
import string
import sys

#--------------------------- I/O ---------------------------------------#

# The input VCF path is the first (and only) command-line argument.
if len(sys.argv) < 2:
   sys.exit('usage: ' + sys.argv[0] + ' <input.vcf>')
filename = sys.argv[1]

# Derive the output name by swapping "vcf" for "phased_basic".  Only the
# file-name component is rewritten: replacing "vcf" in directory names too
# would point the output at a different (possibly nonexistent) directory.
in_dir, in_base = os.path.split(filename)
if 'vcf' in in_base:
   outname = os.path.join(in_dir, in_base.replace('vcf', 'phased_basic'))
else:
   outname = filename + '.phased_basic'

# Opened here for writing; the main loop prints the converted lines to it.
outhandle = open(outname, 'w')

#--------------------------- END OF I/O --------------------------------#

#--------------------------- PATTERNS ----------------------------------#

# Matches each newline character; used with .sub('') to strip them from values.
newline_nuke = re.compile(r'\n')

#--------------------------- END OF PATTERNS ---------------------------#

#--------------------------- FUNCTION DEFINITIONS ----------------------#

def _sample_value(fields, sample, key):
   """Return the sample's value for FORMAT key *key*, or None if unavailable.

   None is returned both when *key* is not listed in *fields* and when the
   sample column has fewer entries than the FORMAT column.
   """
   try:
      return sample[fields.index(key)]
   except (ValueError, IndexError):
      return None


def _barcode_summary(bx_lists, idx):
   """Return ``(count, text)`` for the barcode list at position *idx*.

   A list is treated as populated only when it contains an underscore
   (NOTE(review): presumably every real barcode entry carries one --
   confirm against the BX producer).  A populated list is returned
   verbatim with the number of ';'-separated entries; anything else,
   including a missing position, yields ``(0, 'n/a')``.
   """
   if idx < len(bx_lists) and '_' in bx_lists[idx]:
      entry = bx_lists[idx]
      return len(entry.split(';')), entry
   return 0, 'n/a'


def gt_to_basic(field_info, gt_info):
   """Turn one sample column of a VCF line into "phased basic" fields.

   Parameters:
      field_info -- the ':'-separated FORMAT column (e.g. "GT:BX:PS").
      gt_info    -- the ':'-separated sample column; newlines are removed,
                    so the last column of a line can be passed as-is.

   Returns a tuple ``(values, flag)``:
      values -- list of strings.  For genotypes 0|1, 1|1, 0/1, 1/1 it is
                [count1, barcodes1, count2, barcodes2, genotype, PS] with
                the two comma-separated BX lists in VCF order; for 1|0 and
                1/0 the two (count, barcodes) pairs are swapped.  For every
                other genotype only [genotype, PS] is returned.
      flag   -- 1 when the genotype is one of the six handled above;
                -10 when it is 0|0, 0/0, ".", "1", or contains a 2 or 3;
                0 for anything else.  The caller sums the flags over all
                samples and prints the line only if the sum is positive.

   BX and PS are located by key name in FORMAT rather than at fixed
   positions.  A missing BX/PS key, a sample with fewer entries than
   FORMAT, or a BX value with fewer than two lists is reported as
   'n/a' (count 0) instead of raising.
   """
   fields = field_info.split(':')
   sample = gt_info.replace('\n', '').split(':')
   genotype = sample[0]

   bx_value = _sample_value(fields, sample, 'BX')
   bx_lists = bx_value.split(',') if bx_value is not None else []
   first_ct, first_list = _barcode_summary(bx_lists, 0)
   second_ct, second_list = _barcode_summary(bx_lists, 1)

   ps_value = _sample_value(fields, sample, 'PS')
   if ps_value is None:
      ps_value = 'n/a'

   local_str_list = []
   if (genotype in ('0|0', '0/0', '.', '1')
         or '2' in genotype or '3' in genotype):
      # Large negative value so one such sample outweighs the +1 flags
      # of the other samples when the caller sums them.
      this_print = -10
   elif genotype in ('0|1', '1|1', '0/1', '1/1'):
      local_str_list.extend(
         [str(first_ct), first_list, str(second_ct), second_list])
      this_print = 1
   elif genotype in ('1|0', '1/0'):
      # Swap the pairs so the output order follows the haplotype order.
      local_str_list.extend(
         [str(second_ct), second_list, str(first_ct), first_list])
      this_print = 1
   else:
      this_print = 0

   local_str_list.append(genotype)
   local_str_list.append(ps_value)

   return (local_str_list, this_print)
# end of definition

#-------------------------- END OF FUNCTION DEFINITIONS -----------------------#

#-------------------------- MAIN LOOP -----------------------------------------#

# Stream the VCF line by line.  Header/comment lines ('#') are skipped, as
# are lines with fewer than nine tab-separated columns (e.g. a blank
# trailing line), which have no FORMAT column to interpret.  Both handles
# are closed when the loop finishes or fails, so the output is flushed.
try:
   with open(filename, 'r') as inhandle:
      for line in inhandle:
         if line.startswith('#'):
            continue

         bits = line.split('\t')
         if len(bits) < 9:
            continue

         field_info = bits[8]
         # Output starts with the first two columns of the VCF line,
         # followed by the converted fields of every sample column.
         str_to_print = [bits[0], bits[1]]
         printit = 0
         for sample_info in bits[9:]:
            (this_list, this_does_it_print) = gt_to_basic(field_info, sample_info)
            printit += this_does_it_print
            str_to_print.extend(this_list)

         # Each sample contributes 1, 0 or -10; the line is written only
         # when the total over all samples is positive.
         if printit > 0:
            print('\t'.join(str_to_print), file=outhandle)
finally:
   outhandle.close()

#-------------------------- END OF MAIN LOOP ----------------------------------#

