#!/usr/bin/env python3
# Roger Volden

'''
Usage:
python3 mmSAM_to_identity.py minimap_output.sam >out
Takes a SAM file from minimap2 and returns percent identity for each read.

Updated usage:
changed to take a directory that contains sam files
python3 mmSAM_to_identity.py path/to/sam/files >out

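Output format (stdout, one tab-separated line per SAM record):
readName  identity (decimal)  # matches  # mismatches  # indels  readLength  SAM flag  ts:A strand  group (always 'All')
Unaligned records show '-' in the four alignment columns and '*' in the
last three. When more than 100 reads have fewer than 5 mismatches, summary
statistics are written to stderr and an insertion-count histogram plus the
minimum identity are printed to stdout; the number of pooled reads is
always written to stderr last.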
'''

import sys
import re
import os
import numpy as np

def readSAM(inFile, *, max_counted=500000, max_mismatch=5, min_reads=100):
    '''
    Read a minimap2 SAM file and report alignment identity for each read.

    For every record a tab-separated line is printed to stdout:
    read name, identity (matches / aligned columns), matches, mismatches,
    indels, read length, SAM flag, ts:A strand tag ('*' if absent) and the
    group name. Unaligned records (flag 0x4) print '-' in the four
    alignment columns and '*' in the last three.

    Aligned reads with fewer than ``max_mismatch`` mismatches are pooled.
    For each group with more than ``min_reads`` pooled reads, summary
    statistics are written to stderr, and a histogram of per-read insertion
    counts (0-49) followed by the minimum identity is printed to stdout.
    The number of pooled reads is written to stderr last.

    Parameters
    ----------
    inFile : str
        Path to the SAM file.
    max_counted : int, keyword-only
        Stop at the first aligned record seen after more than this many
        reads have been pooled.
    max_mismatch : int, keyword-only
        Reads with this many or more mismatches are printed but not pooled.
    min_reads : int, keyword-only
        A group needs strictly more than this many pooled reads to be
        summarised.

    Raises
    ------
    ValueError
        If an aligned record has no NM:i: tag (mismatches cannot be derived).
    '''
    chrDict = {}     # group name -> accumulator built by _new_stats()
    notAligned = 0   # count of unaligned (flag 0x4) records
    total = 0        # reads pooled into the statistics (all groups)
    perfect = 0      # pooled reads with identity exactly 1

    with open(inFile) as samFile:
        for rawLine in samFile:
            if rawLine.startswith('@'):
                continue  # SAM header line
            line = rawLine.rstrip().split('\t')
            name = line[0]
            readLength = len(line[9]) if len(line) > 9 else 1
            # All reads share one group; grouping by reference name
            # (line[2]) is disabled.
            chromosome = 'All'
            stats = chrDict.setdefault(chromosome, _new_stats())
            flag = line[1]

            if int(flag) & 0x4:  # FLAG bit 0x4: segment unmapped
                notAligned += 1
                print(name + '\t-\t-\t-\t-\t' + str(readLength) + '\t*\t*\t*')
                continue

            # The cap is only checked when an aligned record is reached.
            if total > max_counted:
                break

            NM, ambig, alignDirection = _parse_tags(line)
            if NM is None:
                raise ValueError('SAM record %r has no NM:i: tag' % name)
            denominator, M, insertion, deletion = parseCIGAR(line[5])
            indel = insertion + deletion
            # Remove indel bases and ambiguous bases (nn:i:) from the edit
            # distance NM; what remains is counted as mismatches.
            mismatch = NM - indel - ambig
            matches = M - mismatch
            identity = matches / denominator
            print('\t'.join((name, str(identity), str(matches), str(mismatch),
                             str(indel), str(readLength), flag,
                             alignDirection, chromosome)))

            if mismatch >= max_mismatch:
                continue  # reported above, but kept out of the statistics
            total += 1
            if identity == 1:
                perfect += 1
            stats['matchAll'] += matches
            stats['denomAll'] += denominator
            stats['mismatchAll'] += mismatch
            stats['InsertionAll'] += insertion
            stats['DeletionAll'] += deletion
            hist = stats['insertionHist']
            hist[insertion] = hist.get(insertion, 0) + 1
            stats['identities'].append(identity)
            stats['readLengths'].append(readLength)

    for chromosome, stats in chrDict.items():
        identities = stats['identities']
        if len(identities) <= min_reads:
            continue
        denomAll = stats['denomAll']
        fields = (
            chromosome,
            np.mean(identities),
            np.median(identities),
            np.percentile(identities, 75),
            np.percentile(identities, 90),
            np.median(stats['readLengths']),
            len(identities),
            notAligned,
            stats['matchAll'] / denomAll,
            stats['mismatchAll'] / denomAll,
            stats['InsertionAll'] / denomAll,
            stats['DeletionAll'] / denomAll,
            len(identities),
        )
        # total >= len(identities) > min_reads >= 0 here, so no zero division.
        sys.stderr.write(' '.join(str(f) for f in fields)
                         + '\t' + str(perfect / total) + '\n')
        hist = stats['insertionHist']
        print(''.join(str(hist.get(i, 0)) + ' ' for i in range(50)),
              min(identities))
    sys.stderr.write(str(total) + '\n')


def _new_stats():
    '''Return an empty accumulator for one group of pooled reads.'''
    return {
        'matchAll': 0,
        'denomAll': 0,
        'mismatchAll': 0,
        'InsertionAll': 0,
        'DeletionAll': 0,
        'identities': [],
        'readLengths': [],
        'insertionHist': {},  # inserted bases per read -> number of reads
    }


def _parse_tags(fields):
    '''
    Extract the NM:i:, nn:i: and ts:A: tags from a split SAM record.

    Only the optional fields (column 12 onwards) are searched, so read
    names, sequences and quality strings are never mistaken for tags.
    Returns (NM, nn, ts): NM is None when absent, nn defaults to 0 and
    ts defaults to '*'. If a tag repeats, the last occurrence wins.
    '''
    NM, ambig, strand = None, 0, '*'
    for tag in fields[11:]:
        if tag.startswith('NM:i:'):
            NM = int(tag[5:])
        elif tag.startswith('nn:i:'):
            ambig = int(tag[5:])
        elif tag.startswith('ts:A:'):
            strand = tag[5:]
    return NM, ambig, strand

def parseCIGAR(cstr):
    '''
    Tally the operations of a CIGAR string.

    Aligned columns are counted from ``M`` and also from ``=`` (sequence
    match) and ``X`` (sequence mismatch), which minimap2 emits with
    ``--eqx``. Clips (``S``/``H``), skips (``N``) and padding (``P``) are
    ignored. An unavailable CIGAR (``*``) yields all zeros.

    Returns
    -------
    tuple of int
        (total, M, insertion, deletion): ``M`` is the number of aligned
        (match or mismatch) columns, ``insertion``/``deletion`` the inserted
        and deleted bases, and ``total`` = M + insertion + deletion.
    '''
    M = insertion = deletion = 0
    for length, op in re.findall(r'(\d+)([MIDNSHP=X])', cstr):
        n = int(length)
        if op in 'M=X':
            M += n
        elif op == 'I':
            insertion += n
        elif op == 'D':
            deletion += n
    return M + insertion + deletion, M, insertion, deletion


def main():
    '''
    Command-line entry point.

    Takes one argument: a SAM file, or a directory. For a directory, every
    file ending in ``.sam`` is processed in sorted order, and each file's
    path is written to stderr before its results. Exits with a usage
    message when no argument is given.
    '''
    if len(sys.argv) < 2:
        sys.exit('Usage: python3 mmSAM_to_identity.py '
                 '<file.sam | directory_of_sam_files> >out')
    inPath = sys.argv[1]
    if os.path.isdir(inPath):
        for fileName in sorted(os.listdir(inPath)):
            if fileName.endswith('.sam'):
                inFile = os.path.join(inPath, fileName)
                sys.stderr.write(inFile + '\n')
                readSAM(inFile)
    else:
        readSAM(inPath)


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
