#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import re
import sys

def error(string, error_type=1):
    '''
    report a fatal problem on stderr and terminate the program.

    string: message text, written after an "ERROR: " prefix.
    error_type: exit status passed to sys.exit (default 1).
    '''
    print('ERROR: {}'.format(string), file=sys.stderr)
    sys.exit(error_type)


def log(string, newline_before=False):
    '''
    write a progress message to stderr, prefixed with "LOG: ".

    newline_before: when true, an empty line is written first to visually
    separate this message from earlier output.
    '''
    lead = '\n' if newline_before else ''
    sys.stderr.write(lead + 'LOG: {}\n'.format(string))

#####

class aligned_length:
    '''
    helpers for extracting lengths from SAM CIGAR strings.
    '''

    # one CIGAR operation: a run length followed by an operation code;
    # '=' must be allowed explicitly, it is not in [A-Z]
    regex = re.compile(r'(\d+)([A-Z=])')
    # operations that consume reference bases (SAM spec: M, D, N, =, X)
    consume_ref_set = {'M', 'D', 'N', '=', 'X'}
    # valid operations that do not consume reference bases (S, I, H, P)
    other_cigops = {'S', 'I', 'H', 'P'}

    def get_aligned_length_from_cigar(self, cigar_string):
        '''
        parse cigar_string and return the sum of length of cigar operations that consume reference bases.

        An unavailable CIGAR ('*') contains no operations and yields 0.
        Calls error() (which exits the program) on an unknown operation code.
        '''
        aligned_length = 0

        for match in self.regex.finditer(cigar_string):
            length, cigar_operation = match.groups()

            if cigar_operation in self.consume_ref_set:
                aligned_length += int(length)
            elif cigar_operation not in self.other_cigops:
                error(f'Unknown cigar operation: {cigar_operation}')

        return aligned_length

    def get_softclip_length_from_cigar(self, cigar_string):
        '''
        return the length of the leading softclip (S) in cigar_string, or 0 if there is none.

        Hard clips (H) may sit between a softclip and the start of the CIGAR
        string, so leading H operations are skipped before looking for S.
        An empty or unavailable CIGAR ('*') yields 0 instead of crashing.
        '''
        for match in self.regex.finditer(cigar_string):
            length, cigar_operation = match.groups()
            if cigar_operation == 'H':
                continue
            # first non-hardclip operation decides: softclip or nothing
            return int(length) if cigar_operation == 'S' else 0

        return 0

#####

# read in alignment
#
# Parse the SAM file given as the first command-line argument. @SQ header
# lines define the references (mRNAs) and their lengths; every other
# non-header line is an alignment record, grouped per reference in `data`.

align_len = aligned_length()

# reference name -> list of (0-based pos, readname, target, seq, cigar, aligned length)
data = {}
# reference name -> reference length taken from the @SQ LN tag
reflens = {}

with open(sys.argv[1]) as infh:
    for line in infh:
        if line.startswith('@'):
            if line.startswith('@SQ'):
                # look tags up by name rather than by column, and split only
                # on the first ':' so reference names containing ':' survive
                tags = dict(
                    field.split(':', 1)
                    for field in line.strip().split('\t')[1:]
                    if ':' in field
                )
                if 'SN' not in tags or 'LN' not in tags:
                    error(f'@SQ header line without SN or LN tag: {line.strip()}')
                refname = tags['SN']
                if refname in reflens:
                    error(f'Duplicate @SQ entry for reference: {refname}')
                reflens[refname] = int(tags['LN'])
                data[refname] = []
            # all other header lines are skipped
            continue

        # tolerate blank lines (e.g. a trailing empty line)
        if not line.strip():
            continue

        # handle normal alignment line
        # assume one alignment per read
        fields = line.strip().split('\t')
        if len(fields) < 10:
            error(f'Malformed alignment line (fewer than 10 fields): {line.strip()}')
        readname, flags, target, pos, _, cigar, _, _, _, seq = fields[:10]

        # only records with FLAG 0 (no flag bits set) are supported
        if flags != '0':
            error(f'Unsupported alignment flag {flags} for read {readname}')
        if target not in reflens:
            error(f'Alignment of read {readname} to unknown reference: {target}')

        alen = align_len.get_aligned_length_from_cigar(cigar)

        # position: 1-based to 0-based
        data[target].append((int(pos)-1, readname, target, seq, cigar, alen))

log('Read in alignment.')
for ref in reflens:
    log(f'{ref}:\t{len(data[ref])}')

#####

# parse alignments

# per-reference statistics, collected over complete alignments only
complete = {}
missings = {}
frontmiss = {}
endmiss = {}
softclip = {}
longsclip = {}

# how many bases can miss from a read before it is considered incomplete
tolerance = 40

# a complete read counts as a softclip overhang when more than this many of
# its leading softclipped bases lie upstream of the mRNA start
overhang_threshold = 5


def _mean(values):
    '''return the arithmetic mean of values, or 0.0 when values is empty.'''
    return sum(values) / max(1, len(values))


print(f'# HCoV mRNA analysis: {sys.argv[1]}')
print(f'# tolerance to be considered complete: <= {tolerance} missing bases')
print(f'# report reads that have >{overhang_threshold} softclipped bases overhang upstream of mRNA start')
print('# mRNA_name\tmRNA_length\tcomplete_alignments\ttotal_alignments\tmean_missing_bases\tmean_missing_front\tmean_missing_end\tmean_front_softclip_length\tcount_softclip_overhang')

for ref, reflen in reflens.items():

    total = len(data[ref])
    complete[ref] = 0
    missings[ref] = []
    frontmiss[ref] = []
    endmiss[ref] = []
    softclip[ref] = []
    longsclip[ref] = 0

    for read in data[ref]:
        pos, _, _, _, cigar, alen = read

        # completeness: reference bases not covered by the alignment
        missing = reflen - alen
        if missing < 0:
            error(f'Alignment longer than reference {ref} ({reflen} bases): {read}')

        if missing > tolerance:
            continue

        # softclipped bases beyond those explained by the alignment start
        # overhang upstream of mRNA position 0
        sclip = align_len.get_softclip_length_from_cigar(cigar)
        if sclip - pos > overhang_threshold:
            longsclip[ref] += 1

        softclip[ref].append(sclip)

        complete[ref] += 1
        missings[ref].append(missing)

        # missing on front: alignment position on this mRNA
        frontmiss[ref].append(pos)

        # missing on end: reflen - (pos + alen)
        endmiss[ref].append(reflen - (pos + alen))

    mean_missing = _mean(missings[ref])
    mean_frontmiss = _mean(frontmiss[ref])
    mean_endmiss = _mean(endmiss[ref])
    mean_softclip = _mean(softclip[ref])

    # output for this reference sequence
    print(f'{ref}\t{reflen}\t{complete[ref]}\t{total}\t{mean_missing:.2f}\t{mean_frontmiss:.2f}\t{mean_endmiss:.2f}\t{mean_softclip:.2f}\t{longsclip[ref]}')


log('Done.')
