#!/usr/bin/env python
# coding: utf-8

# In[9]:


# By Jeonghwan Youk
# for ID signature
# python2.7
# require pysam module

import os,sys,pysam

def ID_repeat(input_line,ref_fasta):
    """Classify one indel record by repeat context and microhomology.

    Parameters
    ----------
    input_line : str
        Tab-separated record. Column 0 is used as the contig name, column 8
        as the REF allele, column 9 as the ALT allele and column 10 as the
        position. NOTE(review): the position is presumably the 1-based
        coordinate of the anchor base REF[0], so that the 0-based offset
        ``pos`` handed to ``fetch`` is the first base after the anchor --
        confirm against the upstream file format.
    ref_fasta : object
        Anything exposing ``fetch(contig, start, end)`` that returns the
        reference sequence of that interval (e.g. ``pysam.FastaFile``).
        Sequence comparison is case-sensitive.

    Returns
    -------
    list of str
        ``[type, repeat_unit, repeat_count, microhomology]``:

        * REF of length 1 -> ``'insertion'``; unit is ALT[1:], count is the
          number of consecutive copies of the unit found in the reference
          right after the anchor (capped at 5), microhomology ``'noMH'``.
        * otherwise -> ``'deletion'``; unit is REF[1:], count is the deleted
          copy plus the consecutive copies that follow it (capped at 6).
        * deletion of more than one base with no following copy: the flanks
          are scanned for microhomology. If any is found the count is
          ``'0'`` and the last field is the longer of the right/left match
          lengths; otherwise count ``'1'`` and ``'noMH'``.
    """
    input_split = input_line.split('\t')
    input_chr = input_split[0]
    # int() promotes to long automatically on Python 2 and also exists on
    # Python 3, unlike long().
    input_pos = int(input_split[10])

    input_ref = input_split[8]
    input_alt = input_split[9]

    if len(input_ref) == 1: # for small insertion
        ru = input_alt[1:]
        unit_len = len(ru)
        rc = 0
        while rc < 5 and ref_fasta.fetch(
                input_chr, input_pos + unit_len*rc, input_pos + unit_len*(rc+1)) == ru:
            rc += 1
        return ['insertion', str(ru), str(rc), 'noMH']

    # for deletion: count copies of the deleted unit that follow the deleted copy
    ru = input_ref[1:]
    unit_len = len(ru)
    rc = 0
    while rc < 5 and ref_fasta.fetch(
            input_chr, input_pos + unit_len*(rc+1), input_pos + unit_len*(rc+2)) == ru:
        rc += 1

    # 1 bp deletions are always reported by repeat count; longer deletions
    # only when at least one extra copy follows.
    if unit_len == 1 or rc > 0:
        return ['deletion', str(ru), str(rc+1), 'noMH']

    # microhomology check!
    # right side: prefix of the deleted unit vs. bases just after the deletion
    rt_mh = 0
    for i in range(unit_len):
        if ref_fasta.fetch(input_chr, input_pos + unit_len, input_pos + unit_len + 1 + i) == ru[:1+i]:
            rt_mh += 1
        else:
            break

    # left side: suffix of the deleted unit vs. bases ending at the anchor
    lt_mh = 0
    for j in range(unit_len):
        start = input_pos - 1 - j
        if start < 0:
            # ran off the start of the contig; nothing further to compare
            break
        if ref_fasta.fetch(input_chr, start, input_pos) == input_ref[-1-j:]:
            lt_mh += 1
        else:
            break

    # The original used `rt_mh == 0 & lt_mh == 0`; `&` binds tighter than
    # `==`, so left-only microhomology was misreported as 'noMH'.
    if rt_mh == 0 and lt_mh == 0:
        return ['deletion', str(ru), str(rc+1), 'noMH']
    #print input_line
    return ['deletion', str(ru), '0', str(max(rt_mh, lt_mh))]
                  

# ---- Part 1: annotate every indel of each listed file with its repeat context ----
info_fn = sys.argv[1] # path for vcf list
ref_fa = sys.argv[2] # path for reference fasta ex> /home/users/jhyouk/99_reference/human/GRCh37/human_g1k_v37.fasta

# Read the list of input paths (one per line). Blank lines are skipped
# rather than ending the list, so a stray empty line cannot silently drop
# the files that follow it.
file_list = []
print('List of target files')
with open(info_fn) as info_file:
    for info_line in info_file:
        info_line = info_line.strip()
        if info_line:
            file_list.append(info_line)

# Open the reference once for the whole run instead of once per record,
# and make sure it is closed even if a file fails to parse.
r_file = pysam.FastaFile(ref_fa)
try:
    for input_fn in file_list:
        print(input_fn)
        # NOTE: if input_fn contains no ".vcf" the output path equals the
        # input path (unchanged from the original behaviour).
        with open(input_fn) as input_file, \
                open(input_fn.replace(".vcf",".context_anno.vcf"),"w") as output_file:
            # Skip the '##' meta lines; the first other line is taken as the
            # column header. An empty header is used if the file has none.
            input_line = ''
            for raw_line in input_file:
                input_line = raw_line.strip()
                if input_line[0:2] != '##':
                    break
            else:
                input_line = ''
            output_file.write(input_line+'\tID_signature_type;repeat_unit;repat_count;microhomology\n')

            for raw_line in input_file:
                input_line = raw_line.strip()
                if not input_line:
                    continue
                fields = input_line.split('\t')
                # Columns 3 and 4 are used as the allele strings for
                # filtering (presumably REF/ALT of a standard VCF).
                if len(fields[3]) > 11 or len(fields[4]) > 11:
                    # alleles longer than 11 characters are not annotated
                    continue
                if len(fields[3]) == 1 and len(fields[4]) == 1:
                    # single-base substitution, not an indel
                    continue
                ID_list = ID_repeat(input_line,r_file)
                output_file.write(input_line + '\t' + ';'.join(ID_list) + '\n')
finally:
    r_file.close()

print('Finish making all .context_anno.vcf')

# 2nd part of indel_signature: count indels per context (83 contexts)
# context groups:
#'1bp_del'
#'1bp_ins'
#'2bp_del_repeat'
#'2bp_ins'
#'2bp_del_microhomology'

# ---- Part 2: turn the context annotations into an 83-column count matrix ----
# One row per input file is written to the matrix file, and every record is
# also copied to <name>.indel_type_anno.vcf with its category label appended.
matrix_fn = info_fn.replace(".txt","_IDSig_matrix.txt")
if matrix_fn == info_fn:
    # The list path contains no ".txt": do not overwrite the list file itself.
    matrix_fn = info_fn + "_IDSig_matrix.txt"

# NOTE(review): some header labels disagree with the per-record labels
# written below ('d1t0'..'d1t5' vs 'd1t1'..'d1t6', 'i1c1'..'i1c6' vs
# 'i1c0'..'i1c5'). Column positions are consistent; the text is left
# unchanged because downstream consumers may rely on it.
head_list = [
    'SampleID',
    'd1c1', 'd1c2', 'd1c3', 'd1c4', 'd1c5', 'd1c6',
    'd1t0', 'd1t1', 'd1t2', 'd1t3', 'd1t4', 'd1t5',
    'i1c1', 'i1c2', 'i1c3', 'i1c4', 'i1c5', 'i1c6',
    'i1t0', 'i1t1', 'i1t2', 'i1t3', 'i1t4', 'i1t5',
    'd22_1', 'd22_2', 'd22_3', 'd22_4', 'd22_5', 'd22_6',
    'd23_1', 'd23_2', 'd23_3', 'd23_4', 'd23_5', 'd23_6',
    'd24_1', 'd24_2', 'd24_3', 'd24_4', 'd24_5', 'd24_6',
    'd25_1', 'd25_2', 'd25_3', 'd25_4', 'd25_5', 'd25_6',
    'i22_0', 'i22_1', 'i22_2', 'i22_3', 'i22_4', 'i22_5',
    'i23_0', 'i23_1', 'i23_2', 'i23_3', 'i23_4', 'i23_5',
    'i24_0', 'i24_1', 'i24_2', 'i24_3', 'i24_4', 'i24_5',
    'i25_0', 'i25_1', 'i25_2', 'i25_3', 'i25_4', 'i25_5',
    'mh2_1', 'mh3_1', 'mh3_2', 'mh4_1', 'mh4_2', 'mh4_3',
    'mh5_1', 'mh5_2', 'mh5_3', 'mh5_4', 'mh5_5',
]

output_file_total = open(matrix_fn,"w")
n=0
try:
    for input_fn in file_list:
        input_fn = input_fn.replace(".vcf",".context_anno.vcf")

        # Per-file counters. Layout of each list:
        #   del_1bp / ins_1bp: 6 slots for C/G units, then 6 for the rest
        #   del_2bp / ins_2bp: 6 slots per unit length 2, 3, 4, 5+
        #   del_mh: 1 slot for length 2, 2 for length 3, 3 for 4, 5 for 5+
        del_1bp = [0]*12
        ins_1bp = [0]*12
        del_2bp = [0]*24
        ins_2bp = [0]*24
        del_mh = [0]*11

        with open(input_fn) as input_file, \
                open(input_fn.replace(".vcf",".indel_type_anno.vcf"),"w") as out_file:
            # The matrix header is written once, together with the first file.
            if n==0:
                output_file_total.write('\t'.join(head_list) + "\n")
                n+=1

            # First line of the annotated file is its column header.
            out_file.write(next(input_file, '').strip()+"\ttype\n")

            for raw_line in input_file:
                input_line = raw_line.strip()
                if not input_line:
                    continue
                # Last column holds 'type;repeat_unit;repeat_count;microhomology'.
                info_split = input_line.split('\t')[-1].split(';')
                is_cg = info_split[1] in ('C', 'G')

                if len(info_split[1]) == 1:
                    if info_split[0] == 'deletion':
                        homopolymer = min(int(info_split[2]),6)
                        if is_cg:
                            del_1bp[0+homopolymer-1] += 1
                            label = "d1c"+str(homopolymer)
                        else:
                            del_1bp[6+homopolymer-1] += 1
                            label = "d1t"+str(homopolymer)
                    else:
                        homopolymer = min(int(info_split[2]),5)
                        if is_cg:
                            ins_1bp[0+homopolymer] += 1
                            label = "i1c"+str(homopolymer)
                        else:
                            ins_1bp[6+homopolymer] += 1
                            label = "i1t"+str(homopolymer)
                elif info_split[0] == 'insertion':
                    ins_length = min(len(info_split[1]),5)
                    homopolymer = min(int(info_split[2]),5)
                    ins_2bp[6*(ins_length-2)+homopolymer] += 1
                    label = "i2"+str(ins_length)+"_"+str(homopolymer)
                elif info_split[3] == 'noMH':
                    del_length = min(len(info_split[1]),5)
                    homopolymer = min(int(info_split[2]),6)
                    del_2bp[6*(del_length-2) + homopolymer-1] += 1
                    label = "d2"+str(del_length)+"_"+str(homopolymer)
                else:
                    del_length = min(len(info_split[1]),5)
                    mh = min(int(info_split[3]),5)
                    if del_length == 2:
                        del_mh[0] += 1
                    elif del_length == 3:
                        del_mh[mh] += 1
                    elif del_length == 4:
                        del_mh[3-1+mh] += 1
                    else:
                        del_mh[6-1+mh] += 1
                    label = "mh"+str(del_length)+"_"+str(mh)

                out_file.write(input_line+"\t"+label+"\n")

        # One matrix row: sample id followed by the 83 counts in header order.
        counts = del_1bp + ins_1bp + del_2bp + ins_2bp + del_mh
        input_id = input_fn.split('/')[-1].replace(".context_anno.vcf","")
        output_file_total.write('\t'.join([input_id] + [str(c) for c in counts]) + '\n')
finally:
    output_file_total.close()

print('Finish making ID_sig_matrix')
print('Remove intermediate files')
# NOTE: the cleanup below is disabled, so the .context_anno.vcf files are
# kept even though the message above is printed.
# for input_fn in file_list:
#     input_fn = input_fn.replace(".vcf",".context_anno.vcf")
#     os.unlink(input_fn)
print('THE END')


# In[ ]:



