#!/usr/bin/env python3

###############################
#
# Copyright Stanford University 2016
# Author: John Bell
#
# This script converts a phased basic file into a list of haplotype 
# blocks if there are two input samples
#
# INPUT:  phased basic file with 2 position columns followed 
# by six columns/sample:  # hap 1 barcodes, hap 1 barcodes list, 
#   # hap 2 barcodes, hap 2 barcodes list, genotype, PS value 
# OUTPUT:  haplotype blocks file with 21 columns 
# 
# ----------- Notes: 
#
# a block is reported if, in a region where all PS values remain 
# constant, there are at least --size phased hets (default 100; 
# phased in both samples at the same position).  The phased hets do 
# _not_ need to be at the beginning or end of the block
#
# this script takes data from the "basic" format, which is still 
# 1 snv per row, and converts the data into blocks.
#
# Blocks are calculated according to areas with consistent phase 
# ids (the last column for each sample in the input file),
# and a block requires at least --size (default 100) phased hets 
# somewhere within it, where phased het means phased on both samples
#
# Output structure is chr, beginning of block, end of block, 
# block length (end - beginning + 1), phase ids for each sample, 
# total SNVs, total # phased hets, proportion of matching haplotypes 
# between samples, then counts of barcodes for each sample:  total, 
# total unique (i.e. different barcode), total in haplotype 1, total 
# unique in haplotype 1, total in haplotype 2, total unique in haplotype 2
#
# The general structure of the program is to check if the current 
# Phase ids match the previous SNVs.  If not, then check to 
# see if there's enough phased SNVs to print the block, and 
# if there are, count up the total barcodes and unique barcodes 
# for each haplotype and print everything including the overall 
# proportion of matching haplotypes.  Also, re-initialize all 
# the count variables.
#
# After assessing whether the phase ids match (and printing 
# and re-initializing variables if needed), the line gets read.  
# If all samples have phased heterozygotes, each haplotype 
# array has the barcodes for that haplotype accreted to it.  
# Only the actual barcode is retained.  Genotypes are also accreted 
# in separate arrays so their amount of congruency between samples 
# can be calculated.
# 
# ----------- End of notes
#
################################

import re,sys,string,getopt,math

def usage():
   """Print the command-line synopsis to stdout and terminate the script.

   Never returns: sys.exit() raises SystemExit (with no exit status given).
   """
   synopsis = "usage:  phased_basic_to_haplotype_barcode_summary.double.py [--size [100]] <phased basic file>\n where size is the number of het phased SNPs required\n"
   print(synopsis)
   sys.exit()

#-------------------------- GET OPTIONS IF THEY EXIST -----------------------------#

# default threshold for the number of phased het SNPs a block must contain
size_threshold = 100

# getopt returns the recognised options and, separately, the remaining
# positional arguments.  The input file is taken from the positional
# arguments rather than from sys.argv[-1]: sys.argv always holds at least
# the script name, so indexing it can never fail and a missing file would
# otherwise go unnoticed (the script name or the --size value would be
# opened instead).
try:
   opts, args = getopt.getopt(sys.argv[1:], '', ['size='])
except getopt.GetoptError:
   usage()
for opt, arg in opts:
   if opt == '--size':
      try:
         size_threshold = int(arg)
      except ValueError:
         # --size must be an integer
         usage()

if not args:
   # no input file on the command line
   usage()
# as before, the last command-line argument names the input file
infile = args[-1]
  
#----------------------------- END OF OPTIONS SECTION -------------------

#----------------------------- FUNCTION DEFINITIONS ---------------------

def bc_breakdown(thisarr):
   """Count the number of tokens of each barcode.

   Args:
      thisarr: iterable of hashable barcode tokens.

   Returns:
      dict mapping each distinct barcode to the number of times it occurs.
      The result is always a dict (empty for empty input), so it can be
      passed straight to get_hash_vals.
   """
   code_cts = {}
   for code in thisarr:
      code_cts[code] = code_cts.get(code, 0) + 1
   return code_cts

def get_hash_vals(bc_list):
   """Summarise a barcode -> count mapping.

   Only entries with a count greater than zero contribute.

   Returns:
      [total, uniques]: the sum of the positive counts and the number of
      keys that have a positive count.
   """
   positive_counts = [count for count in bc_list.values() if count > 0]
   return [sum(positive_counts), len(positive_counts)]

def bcs_to_counts(thislist):
   """Turn a list of barcodes into [total barcodes, number of distinct barcodes].

   Chains bc_breakdown (per-barcode tallies) into get_hash_vals (totals).
   """
   return get_hash_vals(bc_breakdown(thislist))

def compare_haps(arr1, arr2):
   """Return the proportion of positions at which two genotype lists agree.

   The result is rounded (half up) to three decimal places.

   Interpretation: a value near 1 means the haplotypes are highly congruent
   in the same orientation, a value near 0 means they are highly congruent
   but in opposite orientation, and a value in the middle suggests a
   switch error.

   Args:
      arr1, arr2: equal-length sequences of genotype strings, one entry
         per SNV in the block.

   Returns:
      float proportion of matching entries, or NaN when both sequences are
      empty (there is nothing to compare, so the proportion is undefined).

   Exits the script (SystemExit) if the sequences differ in length.
   """
   if len(arr1) != len(arr2):
      sys.exit('arrays in compare_haps are different lengths')
   if not arr1:
      # avoids a ZeroDivisionError for a block without phased hets,
      # which can reach this function when the size threshold is 0
      return float('nan')
   matches = sum(1 for geno1, geno2 in zip(arr1, arr2) if geno1 == geno2)
   # floor(x * 1000 + .5) / 1000 rounds half up; the built-in round()
   # rounds half to even and would give different values
   return math.floor(matches / len(arr1) * 1000 + .5) / 1000

#----------------------------- END OF FUNCTION DEFINITIONS ---------------------

# column names for the 21-column output, written once before any block rows
header_line = "chr\tbeg_pos\tend_pos\tdist\tnormal_PS\ttumor_PS\tall_SNVs\tphased_het\thaplotypes_match\tnorm_total\tnorm_unique\tnorm_hap1_total\tnorm_hap1_unique\tnorm_hap2_total\tnorm_hap2_unique\ttum_total\ttum_unique\ttum_hap1_total\ttum_hap1_unique\ttum_hap2_total\ttum_hap2_unique"
print(header_line)

#------------------------------- INITIALIZATIONS FOR LOOP -----------------------
# Removes the "-<digit>_<number>[_<number>...]" score suffix that follows each
# barcode, so that only the barcode itself is retained.  Compiled once here
# rather than once per input line.
_BARC_SUFFIX = re.compile(r"""-\d(_\d+)+""")

# genotype strings that count as a phased heterozygote
_PHASED_HETS = ('0|1', '1|0')

# the loop reads columns 0..13, so every data line needs at least this many
_MIN_COLUMNS = 14


def _new_block(bits):
   # start an empty block; coordinates and PS ids come from its first line
   return {
      'chrom': bits[0],
      'begpos': int(bits[1]),
      'endpos': int(bits[1]),
      'norm_ps': bits[7],
      'tumor_ps': bits[13],
      'all_snvs': 0,        # every line in the block
      'phased_snps': 0,     # lines that are phased hets in both samples
      'norm_haps': [],      # genotypes of the phased hets, normal sample
      'tum_haps': [],       # genotypes of the phased hets, tumor sample
      'n_hap1_bcs': [],
      'n_hap2_bcs': [],
      't_hap1_bcs': [],
      't_hap2_bcs': [],
   }


def _add_barcodes(bc_list, raw):
   # append the barcodes of one ';'-separated column to bc_list;
   # 'n/a' marks a column without barcodes and is skipped
   if raw != 'n/a':
      bc_list.extend(_BARC_SUFFIX.sub(r'', raw).split(';'))


def _report_block(block):
   # print one output row for the block if it holds enough phased hets
   if block is None or block['phased_snps'] < size_threshold:
      return

   # determine haplotype relations
   haplo_relation = compare_haps(block['norm_haps'], block['tum_haps'])

   # total and unique barcode counts for each of the four haplotypes
   n_hap1_total, n_hap1_unique = bcs_to_counts(block['n_hap1_bcs'])
   n_hap2_total, n_hap2_unique = bcs_to_counts(block['n_hap2_bcs'])
   t_hap1_total, t_hap1_unique = bcs_to_counts(block['t_hap1_bcs'])
   t_hap2_total, t_hap2_unique = bcs_to_counts(block['t_hap2_bcs'])

   # what is size of block?
   dist = block['endpos'] - block['begpos'] + 1

   dataline = [
      block['chrom'], block['begpos'], block['endpos'], dist,
      block['norm_ps'], block['tumor_ps'],
      block['all_snvs'], block['phased_snps'], haplo_relation,
      # per-sample sums are hap1 + hap2 (unique counts are added, not merged)
      n_hap1_total + n_hap2_total, n_hap1_unique + n_hap2_unique,
      n_hap1_total, n_hap1_unique, n_hap2_total, n_hap2_unique,
      t_hap1_total + t_hap2_total, t_hap1_unique + t_hap2_unique,
      t_hap1_total, t_hap1_unique, t_hap2_total, t_hap2_unique,
   ]
   print('\t'.join(str(field) for field in dataline))


# A block is a run of consecutive lines sharing the same chromosome and the
# same PS value in both samples.  The three values are compared as a tuple:
# joining the PS strings without a separator could make different pairs look
# identical, and PS values alone can coincide across a chromosome boundary.
block = None        # block being accumulated; None until the first data line
block_key = None

#-------------------------------- THIS IS THE MAIN LOOP ------------------------# 

with open(infile, "r") as handle:
   for line_no, line in enumerate(handle, 1):
      # drop the line terminator up front so the last column (tumor PS)
      # compares equal whether or not the final line ends in a newline
      line = line.rstrip('\n')
      if not line.strip():
         # skip blank lines
         continue
      bits = line.split('\t')
      if len(bits) < _MIN_COLUMNS:
         sys.exit('%s line %d: expected at least %d tab-separated columns, found %d'
                  % (infile, line_no, _MIN_COLUMNS, len(bits)))

      new_key = (bits[0], bits[7], bits[13])
      if new_key != block_key:
         # end of the current block: print it if it has enough phased
         # hets, then start a new block at this line
         _report_block(block)
         block = _new_block(bits)
         block_key = new_key

      block['all_snvs'] += 1

      # only SNVs that are phased hets in both samples contribute
      # genotypes and barcodes
      if bits[6] in _PHASED_HETS and bits[12] in _PHASED_HETS:
         block['phased_snps'] += 1
         block['norm_haps'].append(bits[6])
         block['tum_haps'].append(bits[12])
         _add_barcodes(block['n_hap1_bcs'], bits[3])
         _add_barcodes(block['n_hap2_bcs'], bits[5])
         _add_barcodes(block['t_hap1_bcs'], bits[9])
         _add_barcodes(block['t_hap2_bcs'], bits[11])

      # new end position is most recent position
      block['endpos'] = int(bits[1])

#-------------------------------------- END OF MAIN LOOP ------------------------# 

#--------------------------------------- print last one now that loop is over

# the loop only reports a block when the next one starts, so the final block
# is still pending here (block is None when the input had no data lines)
_report_block(block)
#--------------------------------------- end of print last call

