#!/usr/bin/env python3

###############################
#
# Copyright Stanford University 2016
# Author: John Bell
#
# This script converts a haplotype blocks file into per-sample haplotype
#   information (which haplotype is larger and which is smaller), with the format:
#   chr - beg pos - end pos - phased hets - (phase id - 1st hap - 2nd hap) x # samples
#   where each hap field is m (smaller) or M (larger); on a tie both are m
# 
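# INPUT: a tab-separated haplotype blocks file (path given as the first
#   command-line argument) with 1, 2 or 3 samples (13, 21 or 29 columns)
# OUTPUT: tab-separated text on standard output with 3 position columns, a
#   column giving the # of phased hets, and then 3 columns (phase id & the
#   two haplotype labels) for each sample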
#
###############################

#--------------------------- I/O --------------------------------#

import sys,string,re

# Read the input file named by the first command-line argument and split it
# into lines.  Splitting on "\n" leaves a trailing empty string when the file
# ends with a newline; the main loop skips empty lines.
try:
   in_path = sys.argv[1]
except IndexError:
   sys.exit("need input file")

# Use a context manager so the file handle is closed once it has been read
# (the handle was previously left open for the life of the process).
with open(in_path, 'r') as infile:
   lines = infile.read().split("\n")

#--------------------------- end of I/O -------------------------#

#--------------------------- initializations --------------------#

# Determine the number of samples from the column count of the first data
# line.  Comment lines, the "chr..." header line and blank lines are skipped,
# using the same filter as the main loop, so extra comment lines or a missing
# header no longer throw the count off.  Each recognised layout is
# 5 + 8 * (number of samples) columns wide.
samples_for_word_count = {13: 1, 21: 2, 29: 3}

first_data_line = next(
   (ln for ln in lines
    if not (ln.startswith('#') or ln.startswith('chr') or ln == '')),
   '')
num_words = len(first_data_line.split())

num_samples = samples_for_word_count.get(num_words, 0)
if first_data_line and num_samples == 0:
   # An unrecognised layout would otherwise silently produce output with no
   # per-sample columns and a phased-hets value read from the wrong column.
   sys.exit("unrecognized input format: %d columns on first data line "
            "(expected 13, 21 or 29)" % num_words)

# Column index (0-based) of the phased-hets count in each data line.
phased_hets_pos = 5 + num_samples

#-------------------------- end of initializations --------------#

#-------------------------- pattern definitions -----------------------#

newline_nuke = re.compile(r'\n')  # matches a single newline character; not referenced by the loop below

#-------------------------- end of pattern definitions -----------------------#

#-------------------------- loop --------------------------------#

for line in lines:
   # Skip comment lines, the column-header line and blank lines (splitting
   # the file on "\n" leaves an empty string after the final newline).
   # NOTE(review): a data row whose chromosome field starts with "chr"
   # (e.g. "chr1") would also be skipped here -- confirm the input uses
   # bare chromosome names.
   if line.startswith(('#', 'chr')) or line == '':
      continue

   bits = line.split('\t')
   # Leading output columns: chromosome, begin, end, and # phased hets.
   fields = [bits[0], bits[1], bits[2], bits[phased_hets_pos]]

   for c in range(1, num_samples + 1):
      phase_id = bits[3 + c]
      # Column holding the first haplotype's value for sample c (simplifies
      # to 2*num_samples + 6*c + 2); the second haplotype's value is two
      # columns further right.
      hap1_unique_pos = 6 + 2*num_samples - 1 + 4 + 6*(c - 1) - 1
      hap1_unique = int(bits[hap1_unique_pos])
      hap2_unique = int(bits[hap1_unique_pos + 2])
      # Label each haplotype m (smaller) or M (larger).
      if hap1_unique < hap2_unique:
         this_order = "m\tM"
      elif hap1_unique > hap2_unique:
         this_order = "M\tm"
      else:
         # Tie: both haplotypes are labelled as the smaller one.
         this_order = "m\tm"
      fields += [phase_id, this_order]

   print('\t'.join(fields))

#-------------------------- end of loop -------------------------#

