#!/usr/bin/env python3

###############################
#
# Copyright Stanford University 2016
# Author: John Bell
#
# this script integrates multiple single-sample vcf files; it
# requires that every file use the GT:BX:PS genotype FORMAT and that
# all chromosomes be numeric and numerically sorted
# 
# note that "integrate" here means print ONLY the positions 
# where all input samples have a call
#
# note usage below is nn1 = non-normal 1, nn2 = non-normal 2 (because could be dyspl-malig, malig-met, met-met, etc.)
#
# INPUT:  multiple headerless vcf files, each simplified to GT:BX:PS genotypes
# OUTPUT: a headerless vcf (written to stdout) containing only the positions
#         shared by all inputs, with one sample column per input file
#
################################


import contextlib
import sys

# Every input data line must carry exactly this FORMAT column.
GT_FORMAT = "GT:BX:PS"
# A single-sample VCF data line: 8 fixed columns + FORMAT + one sample column.
N_COLUMNS = 10


class VcfFormatError(ValueError):
    """Raised when an input line violates this script's input requirements."""


def _records(lines, name):
    """Parse and validate the data lines of one single-sample VCF.

    Args:
        lines: any iterable of text lines (e.g. an open file).
        name: label used in error messages (normally the file path).

    Yields:
        ``((chrom, pos), fields)`` for each data line.  ``chrom`` and ``pos``
        are ints, and ``fields`` is the list of tab-separated columns with the
        line terminator removed.  Blank lines and ``#`` header lines are
        skipped.

    Raises:
        VcfFormatError: wrong column count, a FORMAT other than GT:BX:PS,
            non-numeric CHROM/POS, or a position that goes backwards.
    """
    previous = None
    for line_no, line in enumerate(lines, start=1):
        # Strip '\r' as well, so CRLF input cannot leak into the sample column.
        text = line.rstrip("\r\n")
        if not text or text.startswith("#"):
            continue
        fields = text.split("\t")
        if len(fields) != N_COLUMNS:
            raise VcfFormatError(
                f"more than one sample or weird line at line {line_no} "
                f"of {name}:\n{text}"
            )
        if fields[8] != GT_FORMAT:
            raise VcfFormatError(
                f"gt fields in line {line_no} of {name} not simplified:\n{text}"
            )
        try:
            key = (int(fields[0]), int(fields[1]))
        except ValueError:
            raise VcfFormatError(
                f"non-numeric chromosome or position at line {line_no} "
                f"of {name}:\n{text}"
            ) from None
        # The merge in intersect() is only correct on sorted input.  Check the
        # order here rather than silently losing shared positions.
        if previous is not None and key < previous:
            raise VcfFormatError(
                f"{name} is not numerically sorted at line {line_no}:\n{text}"
            )
        previous = key
        yield key, fields


def intersect(sources, names=None):
    """Yield one merged line per position present in every source.

    Args:
        sources: iterables of VCF text lines (open files, lists of str), each
            sorted by numeric (CHROM, POS).
        names: optional labels for error messages, one per source; defaults
            to "input 1", "input 2", ...

    Yields:
        Tab-joined lines without a trailing newline.  Each line holds the
        first 9 columns of the LAST source's line, followed by the sample
        column of every source in argument order.  (The original script's
        comment describes the last file as the tumor file.)

    Raises:
        VcfFormatError: from record validation, raised lazily when the
            offending line is reached.
    """
    sources = list(sources)
    if names is None:
        names = [f"input {n}" for n in range(1, len(sources) + 1)]
    streams = [_records(src, name) for src, name in zip(sources, names)]
    if not streams:
        return

    # Sorted-merge intersection: stop as soon as any source is exhausted.
    current = [next(stream, None) for stream in streams]
    while None not in current:
        keys = [key for key, _ in current]
        lowest = min(keys)
        if all(key == lowest for key in keys):
            site = current[-1][1][:9]
            samples = [fields[9] for _, fields in current]
            yield "\t".join(site + samples)
            to_advance = range(len(streams))
        else:
            # With sorted inputs, a position below some other stream's
            # current position can never be shared by all streams.
            to_advance = [i for i, key in enumerate(keys) if key == lowest]
        for i in to_advance:
            current[i] = next(streams[i], None)


def main(argv=None):
    """Command-line entry point: intersect the VCF files named in *argv*.

    Args:
        argv: file paths (defaults to ``sys.argv[1:]``); two or more needed.

    Returns:
        Process exit status: 0 on success, 1 on usage, I/O or format errors.
        Errors are reported on stderr so they never corrupt the VCF on stdout.
    """
    paths = sys.argv[1:] if argv is None else list(argv)
    if len(paths) < 2:
        print("need at least two vcf files", file=sys.stderr)
        return 1
    try:
        with contextlib.ExitStack() as stack:
            handles = [stack.enter_context(open(path)) for path in paths]
            for line in intersect(handles, paths):
                print(line)
    except (OSError, VcfFormatError) as err:
        print(err, file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())

#------------------------ end of if three files are input ----------------------------#
