#!/usr/bin/env python3

###############################
#
# Copyright Stanford University 2018
# Author: John Bell
#
# This script converts a vcf file to a list of haplotype information (which is
#   larger and which is smaller) with the format:
#   chr - pos - (hap 1 base - hap 2 base - phase id) x number of samples
# 
# INPUT: a vcf file
# OUTPUT:  a tab-separated text file with 2 position columns (chr, pos) and then
#   3 columns (hap 1 base, hap 2 base, phase id) for each sample
#
###############################

#--------------------------- I/O --------------------------------#

import sys,string,re

# The input VCF path is the first (and only required) command-line argument.
if len(sys.argv) < 2:
   sys.exit("need input vcf")
filename = sys.argv[1]

# Build the output path from the input path.  Every occurrence of 'vcf' in
# the path (directory components included) is swapped for 'lists'; a path
# that does not contain 'vcf' at all simply gets a '.lists' suffix.
if 'vcf' in filename:
   outname = filename.replace('vcf', 'lists')
else:
   outname = filename + '.lists'

# Output rows are written to this handle by the main loop.
outhandle = open(outname, 'w')

#--------------------------- end of I/O -------------------------#

#--------------------------- pattern definitions -----------------------#

# Matches every newline character; used to strip the line terminator from a field.
newline_nuke = re.compile(r"\n")

#--------------------------- end of pattern definitions -----------------------#

#--------------------------- function definitions -----------------------#

def get_phase_id_loc(field_info):
   """Return the index of the 'PS' entry in a colon-separated field string.

   field_info -- colon-separated list of sub-field names, e.g. 'GT:AD:PS'.

   Returns the zero-based position of the first entry that is exactly 'PS'.

   Raises IndexError when no entry equals 'PS'.  The exception type matches
   what the previous implementation raised when it ran off the end of the
   list; the message now names the offending field string.
   """
   fields = field_info.split(':')
   try:
      # list.index scans every entry, so an empty sub-field in front of
      # 'PS' (e.g. 'GT::PS') no longer cuts the search short.
      return fields.index('PS')
   except ValueError:
      raise IndexError(
         "no PS entry in field info {!r}".format(field_info)
      ) from None

# end of definition

#-------------------------- end of function definitions -----------------------#

#-------------------------- loop --------------------------------# 

# Convert every data record of the input into one tab-separated output row:
#   chr, pos, then (hap 1 base, hap 2 base, phase id) for each sample column.
# The input is opened in a `with` block so its handle is closed once the
# loop finishes or a malformed record raises part-way through the file.
with open(filename, 'r') as vcf_in:
   for line in vcf_in:
      # lines starting with '#' are header/meta lines; blank lines carry no
      # record either (and would otherwise fail on the column lookups below)
      if line.startswith('#') or not line.strip():
         continue

      bits = line.split('\t')
      ref_all = bits[3]
      alt_all = bits[4]

      # column 8 lists the per-sample sub-field names; find where PS sits
      phase_id_loc = get_phase_id_loc(bits[8])

      row = [bits[0], bits[1]]

      # sample columns start at index 9
      for sample in bits[9:]:
         gt_bits = sample.split(':')

         # The line is not stripped before splitting, so when PS is the final
         # sub-field of the last sample it still carries the trailing newline.
         # NOTE(review): a sample column with fewer sub-fields than column 8
         # declares raises IndexError here -- confirm inputs never do that.
         phase_id = newline_nuke.sub(r'', gt_bits[phase_id_loc])

         # pieces to add: the two haplotype bases for this sample
         genotype = gt_bits[0]
         if genotype == '0|1':
            bases = [ref_all, alt_all]
         elif genotype == '1|0':
            bases = [alt_all, ref_all]
         elif genotype == '1|1':
            bases = [alt_all, alt_all]
         elif genotype in ('1|2', '2|1', '0|2', '2|0'):
            # genotypes involving a second alternate allele are only flagged
            bases = ['gt-2', 'gt-2']
         else:
            # anything else (including '0|0' and '/'-separated calls)
            bases = ['NP', 'NP']

         row.extend(bases)
         row.append(phase_id)

      print('\t'.join(row), file=outhandle)

#-------------------------- end of loop -------------------------#
