#!/usr/bin/env python3
# Roger Volden

'''
Usage:
python3 mmSAM_to_identity.py minimap_output.sam >out
Takes a SAM file from minimap2 and prints the identity (as a decimal fraction, not a percentage) for each read.

Updated usage:
changed to take a directory that contains sam files
python3 mmSAM_to_identity.py path/to/sam/files >out

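Output format (tab-separated, one line per read, written to stdout):
readName    identity (decimal)  # matches    # mismatches   # indels
A final summary line (median identity, mean identity, median aligned length,
mean aligned length) is written to stderr.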
'''

import sys
import re
import os
import numpy as np

def readSAM(inFile):
    '''
    Reads a minimap2 SAM file and prints per-read identity statistics.

    Only the first aligned record seen for each read name is used; later
    records with the same name and unaligned records (CIGAR '*') are skipped.
    For each used record a tab-separated line is printed to stdout:
        readName  identity  matches  mismatches  indels
    where the aligned columns (M) and indels come from the CIGAR string
    (via parseCIGAR) and
        mismatches = NM - indels - nn
        matches    = M - mismatches
        identity   = matches / (M + indels)
    NM is the edit-distance tag and nn the minimap2 ambiguous-base tag
    (treated as 0 when absent).

    After all records, the median and mean identity followed by the median
    and mean aligned length (matches + mismatches) are written to stderr.

    Records that are truncated, lack an NM tag, or have no M/I/D operations
    are skipped with a warning on stderr instead of crashing or silently
    reusing values from a previous record.
    '''
    used = set()
    identity_list = []
    length_list = []
    with open(inFile) as sam:
        for line in sam:
            if line.startswith('@'):
                continue
            fields = line.rstrip().split('\t')
            # A SAM alignment line has 11 mandatory columns; anything shorter
            # is a blank or truncated line and cannot be parsed.
            if len(fields) < 11:
                if line.strip():
                    print('Skipping malformed line in ' + inFile,
                          file=sys.stderr)
                continue
            name = fields[0]
            if fields[5] == '*' or name in used:  # unaligned or already reported
                continue
            used.add(name)
            # Reset for every record so a missing tag never reuses the
            # previous read's value.
            NM, ambig = None, 0
            for column in fields[11:]:  # optional TAG:TYPE:VALUE fields
                if column.startswith('NM:i:'):
                    NM = int(column[5:])
                elif column.startswith('nn:i:'):
                    ambig = int(column[5:])
            if NM is None:
                print('Skipping ' + name + ': no NM tag', file=sys.stderr)
                continue
            denominator, M, indel = parseCIGAR(fields[5])
            if denominator == 0:
                print('Skipping ' + name + ': no M/I/D operations in CIGAR',
                      file=sys.stderr)
                continue
            mismatch = NM - indel - ambig
            matches = M - mismatch
            identity = matches / denominator
            length_list.append(matches + mismatch)
            identity_list.append(identity)
            print(name + '\t' + str(identity) + '\t' + str(matches)
                  + '\t' + str(mismatch) + '\t' + str(indel))
    if not identity_list:
        # np.median/np.average on an empty list only yield nan + warnings.
        print('No aligned reads found in ' + inFile, file=sys.stderr)
        return
    print(np.median(identity_list), np.average(identity_list),
          np.median(length_list), np.average(length_list), file=sys.stderr)

def parseCIGAR(cstr):
    '''
    Tallies alignment columns from a CIGAR string.

    Each <length><op> pair is classified as:
      * M, =, X -> aligned columns (match or mismatch)
      * I, D    -> indels
    N, S, H and P operations are ignored. '=' and 'X' appear when minimap2
    is run with --eqx; they are counted exactly like 'M'.

    Returns (total, M, indel) where M is the number of aligned columns,
    indel the number of inserted plus deleted bases, and total = M + indel.
    A CIGAR of '*' (unavailable) or an empty string yields (0, 0, 0).
    '''
    aligned, indel = 0, 0
    for length, op in re.findall(r'(\d+)([MIDNSHP=X])', cstr):
        if op in 'M=X':
            aligned += int(length)
        elif op in 'ID':
            indel += int(length)
    return aligned + indel, aligned, indel

def main():
    '''
    Command-line entry point.

    The first argument may be a single SAM file, which is processed
    directly, or a directory. For a directory, every file ending in '.sam'
    is processed in sorted order, and its path is echoed to stderr before
    its results. Exits with a usage message when no path is given.
    '''
    if len(sys.argv) < 2:
        sys.exit('Usage: python3 ' + os.path.basename(sys.argv[0])
                 + ' <file.sam | directory_of_sam_files> >out')
    inPath = sys.argv[1]
    if os.path.isdir(inPath):
        for fileName in sorted(os.listdir(inPath)):
            if fileName.endswith('.sam'):
                inFile = os.path.join(inPath, fileName)
                sys.stderr.write(inFile + '\n')
                readSAM(inFile)
    else:
        readSAM(inPath)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    main()
