#!/usr/bin/env python3


###############################
#
# Copyright Stanford University 2015
# Author: John Bell
#
# This script creates haplotype lists based on two files of format: 
# variant file:
# var_chr8    69519777        T       T       40068451637    
#
# blocks file:
# 1       3024265 3173228 104     102805896       m       M       102778567       m       M       103024265       M       m 
# and also needs a third input saying which sample to use from blocks file (1 2 3 )
#
# Input:  command-line arguments, in this order: (1) a barcode blocks file (output of blocks_to_list.py), (2) a variant file (output of phased_vcf_to_haplotype_prelist_numbers.py), and (3) an integer saying which sample to use from the blocks file
# Output:  a tab-separated text list of haplotypes by position, printed to standard output
#
################################

#--------------------------- I/O --------------------------------#

import sys,string,re

# read in all block lines; the blocks file is the first command-line argument
try:
   blocks_path = sys.argv[1]
except IndexError:
   sys.exit("need input file of blocks info")

try:
   # 'with' closes the handle as soon as the lines are in memory
   with open(blocks_path, 'r') as infile:
      # drop blank lines so the line count is right whether or not the file
      #  ends with '\n' (or with several empty lines)
      block_lines = [ln for ln in infile.read().split("\n") if ln.strip()]
except OSError as err:
   sys.exit("cannot read blocks file: " + str(err))

# the initialization section reads block_lines[0], so an empty file cannot be processed
if not block_lines:
   sys.exit("blocks file contains no block lines")

# 0-based index of the last block line
end_of_blocks_file = len(block_lines) - 1

# which sample's columns to use: third command-line argument, 1-based
try:
   which_sample = int(sys.argv[3])
except IndexError:
   sys.exit('need third input saying which sample (1,2,3 etc.)')
except ValueError:
   sys.exit('third input saying which sample must be an integer (1,2,3 etc.)')

# a value below 1 would produce column indices that point at the wrong fields
if which_sample < 1:
   sys.exit('third input saying which sample must be 1 or greater (1,2,3 etc.)')

#--------------------------- end of I/O -------------------------#

#--------------------------- initializations --------------------#

# Work out which columns to read for the chosen sample.
# Each sample takes up three consecutive columns in both files, so every
#  index is shifted by three for each sample before the requested one.
sample_shift = 3 * (which_sample - 1)

# blocks file: for sample 1 the columns are 4 (id), 5 (hap1), 6 (hap2)
b_id_loc = sample_shift + 4
b_hap1_loc, b_hap2_loc = b_id_loc + 1, b_id_loc + 2

# variant file: for sample 1 the columns are 2 (hap1), 3 (hap2), 4 (id)
var_hap1_loc = sample_shift + 2
var_hap2_loc, var_id_loc = var_hap1_loc + 1, var_hap1_loc + 2

# Load the first block line so the loop has a block to compare against
#  from its very first variant.
b_ct = 0
b_bits = block_lines[b_ct].split()
b_chr = int(b_bits[0])
b_beg_pos, b_end_pos = int(b_bits[1]), int(b_bits[2])
b_id, b_hap1, b_hap2 = b_bits[b_id_loc], b_bits[b_hap1_loc], b_bits[b_hap2_loc]

#--------------------------- end of initializations -------------#

#--------------------------- pattern definitions ----------------#

# matches a newline character; substituting '' with it deletes newlines from a field
newline_nuke = re.compile(r"\n")

#--------------------------- end of pattern definitions ---------#

#--------------------------- loop -------------------------------#

# Walk the variant file (second command-line argument) in step with the block list.
# For every variant one tab-separated row is printed: chromosome, position and the
#  two haplotypes, or 'na' 'na' when the variant does not belong to a block.
# 'with' makes sure the variant file is closed when the loop ends or the script exits.
with open(sys.argv[2], 'r') as var_file:
   for line in var_file:

      # tolerate blank lines (e.g. a stray empty line at the end of the file)
      if not line.strip():
         continue

      # elements of variant info
      bits = line.split('\t')
      var_chr = int(bits[0])
      pos = int(bits[1])
      hap1 = bits[var_hap1_loc]
      hap2 = bits[var_hap2_loc]
      phase_id_raw = bits[var_id_loc]
      # the id may be the last field on the line, so strip the newline before comparing
      phase_id = newline_nuke.sub(r'', phase_id_raw)

      # move forward through the blocks until the current block no longer ends
      #  before this variant, or until we are sitting on the last block
      while b_ct < end_of_blocks_file and (var_chr > b_chr or (var_chr == b_chr and pos > b_end_pos)):
         # now increment block info
         b_ct = b_ct + 1
         b_bits = block_lines[b_ct].split()
         b_chr = int(b_bits[0])
         b_beg_pos = int(b_bits[1])
         b_end_pos = int(b_bits[2])
         b_id = b_bits[b_id_loc]
         b_hap1 = b_bits[b_hap1_loc]
         b_hap2 = b_bits[b_hap2_loc]

      if var_chr == b_chr and pos >= b_beg_pos and pos <= b_end_pos and phase_id == b_id:
         # match between the ids, showing that the current variant belongs to the block we have info for
         if b_hap1 == 'm':
            # we put the "smaller" haplotype first by convention; hence if the block haplotype 1 is minor, we needn't change anything
            # also recall that by convention if the two counts are equal in the block, we give both 'm', so it doesn't matter which is "small"
            out_hap1 = hap1
            out_hap2 = hap2
         else:
            out_hap1 = hap2
            out_hap2 = hap1
      elif var_chr < b_chr or (var_chr == b_chr and pos < b_beg_pos) or b_ct >= end_of_blocks_file:
         # variant has no block to go into: it lies before the current block (so it
         #  fell between two blocks), or we are on the last block and it did not match.
         # The last-block test is applied after advancing as well, so a variant located
         #  beyond the final block is printed unhaplotyped instead of killing the script.
         out_hap1 = 'na'
         out_hap2 = 'na'
      else:
         # variant sits inside the position range of a block that is not the last one but
         #  carries a different id; that should not happen, so print what we're looking at and die
         print('something is wrong with these:')
         print(var_chr, b_chr, pos, b_beg_pos, pos, b_end_pos, phase_id, b_id)
         print(line)
         print(block_lines[b_ct])
         sys.exit('so we die until fixed')

      # at the end of every row we print the variant one way or another:
      print(var_chr,pos,out_hap1,out_hap2,sep='\t')

#--------------------------- end of loop ------------------------# 
