#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# fiveprime_errors.py - Sebastian Krautwurst


import sys
import os
import re
import time
import pysam
from Bio import SeqIO
from statistics import median
import matplotlib
matplotlib.use('Agg')  # do not require X window
import matplotlib.pyplot as plt

# Positional command-line arguments, read at module level:
#   argv[1] -> mapping file, argv[2] -> reference FASTA.
# A missing argument raises IndexError right here.
args = sys.argv

input_mapping, input_reference = args[1], args[2]


def log(logmsg):
    """Write one timestamped log line to standard error.

    The line has the form ``LOG (<time.asctime()>) :: <logmsg>`` followed by
    a newline. ``logmsg`` must be a string.
    """
    pieces = ('LOG (', time.asctime(), ') :: ', logmsg, '\n')
    sys.stderr.write(''.join(pieces))

def error(string, code=1):
    """Report a fatal error on standard error and terminate the program.

    Args:
        string: Error description; written after the ``ERROR: `` prefix.
        code: Process exit status (default 1).

    Raises:
        SystemExit: always, carrying ``code``.
    """
    sys.stderr.write('ERROR: ' + string + '\nExiting.\n')
    # sys.exit is always available; the bare ``exit`` builtin is injected by
    # the ``site`` module for interactive use and may be missing in programs.
    sys.exit(code)

###############

def read_reference(reference):
    """Load a FASTA file into a dictionary of sequences.

    Args:
        reference: FASTA input handed to ``SeqIO.parse``.

    Returns:
        dict mapping each record's ``id`` to its sequence as a plain ``str``.
        If an id occurs more than once, the last record wins.
    """
    return {
        record.id: str(record.seq)
        for record in SeqIO.parse(reference, 'fasta')
    }


def read_mapping(mapping):
    """Iterate over all records of a SAM/BAM/CRAM file.

    This is a generator: nothing is opened or counted until the first item
    is requested.

    Args:
        mapping: Path to the mapping file. The format is chosen from the
            file name suffix (``bam``, ``sam`` or ``cram``).

    Yields:
        ``(read, all_reads)`` tuples, where ``read`` is the record returned
        by pysam and ``all_reads`` is the record count reported by
        ``pysam.view('-c', mapping)`` (the same value for every item).

    Calls ``error`` (which terminates the program) when the suffix is not
    recognised.
    """

    # TODO: handle multiple alignments

    # Pick the pysam open mode from the suffix before doing any real work.
    if mapping.endswith('bam'):
        mode = "rb"
    elif mapping.endswith('sam'):
        mode = "r"
    elif mapping.endswith('cram'):
        mode = "rc"
    else:
        error("No valid mapping format detected. File has to be SAM/BAM/CRAM.")
        return

    # Separate counting pass, used by the caller for progress reporting.
    all_reads = int(pysam.view('-c', mapping))

    in_file = pysam.Samfile(mapping, mode)
    try:
        # until_eof=True walks the file in order without needing an index.
        for read in in_file.fetch(until_eof=True):
            yield (read, all_reads)
    finally:
        # Release the file handle even if the consumer stops early or fails.
        in_file.close()


def get_error_rates_fiveprime(read, reference_sequence, regex, fiveprime_bases):
    """Collect 5'-end and error statistics for one aligned read.

    The CIGAR string of ``read`` is walked operation by operation while the
    current position in the read and in the reference is tracked.

    Args:
        read: Alignment record providing ``query_sequence``,
            ``query_alignment_length``, ``query_qualities``,
            ``reference_start``, ``cigarstring`` and ``query_name``.
        reference_sequence: Sequence (str) of the reference the read is
            aligned to.
        regex: Compiled pattern yielding ``(length, operation)`` groups for
            each CIGAR element, e.g. ``re.compile(r'(\\d+)([A-Z])')``.
            Compile it once and reuse it for all reads.
        fiveprime_bases: Number of leading base qualities to return.

    Returns:
        A 5-tuple ``(mapping_position, qualities, stats, softclip, sc_ref)``:

        * ``mapping_position`` - ``read.reference_start``.
        * ``qualities`` - the first ``fiveprime_bases`` quality values
          (zeros if the read carries no qualities).
        * ``stats`` - ``[len_ins, len_del, len_sub, alignment_length]``;
          inserted bases, deleted bases (deletions shorter than 50 only),
          mismatching aligned bases, and ``query_alignment_length`` plus the
          counted deleted bases.
        * ``softclip`` - the leading soft-clipped read bases ('' if none).
        * ``sc_ref`` - the reference bases directly upstream of the mapping
          position, same length as ``softclip``, left-padded with ``.``
          where the reference start is reached.

    NOTE(review): the leading soft clip and the first qualities are taken
    from the start of ``query_sequence``; this presumably assumes
    forward-strand alignments - confirm for reverse-strand reads.
    """

    # Deletions at least this long are assumed to be mapping artefacts and
    # are not counted as errors.
    long_deletion = 50

    read_sequence = str(read.query_sequence)

    # alignment length: like from blast identity - total bases in the
    # alignment (including gaps (indels)). Start with
    # read.query_alignment_length (which is qend - qstart) and add the
    # deletions while walking the CIGAR; splice junctions are thus ignored.
    alignment_length = read.query_alignment_length

    if read.query_qualities is not None:
        quality_values = list(read.query_qualities)
    else:
        log(f'{read}')
        log(f'No quality for read {read.query_name}')
        quality_values = [0] * len(read_sequence)

    mapping_position = read.reference_start

    pos_in_read = 0
    pos_in_ref = mapping_position

    # error totals (in bases)
    len_ins = 0
    len_del = 0
    len_sub = 0

    # leading softclipped part and the corresponding reference part
    softclip = ''
    sc_ref = ''

    for match in regex.finditer(read.cigarstring):
        length, cigar_operation = match.groups()
        length = int(length)

        if cigar_operation in ('H', 'P'):
            # hard clip / padding: consumes neither read nor reference
            continue

        if cigar_operation == 'S':
            # soft clip: only the one at the very start of the read is kept
            if pos_in_read == 0:
                softclip = read_sequence[:length]
                sc_ref = reference_sequence[max(0, pos_in_ref - length):pos_in_ref]
                # pad on front when the clip reaches past the reference start
                sc_ref = '.' * (len(softclip) - len(sc_ref)) + sc_ref
            pos_in_read += length

        elif cigar_operation in ('M', '=', 'X'):
            # aligned bases (match or mismatch): count differing positions
            aligned = read_sequence[pos_in_read:pos_in_read + length]
            len_sub += sum(
                1 for offset, base in enumerate(aligned)
                if base != reference_sequence[pos_in_ref + offset]
            )
            pos_in_read += length
            pos_in_ref += length

        elif cigar_operation == 'I':
            # insertion: present in the read only
            pos_in_read += length
            len_ins += length

        elif cigar_operation == 'D':
            # deletion: present in the reference only
            pos_in_ref += length
            if length < long_deletion:
                len_del += length
                # deleted bases are part of the alignment
                alignment_length += length

        elif cigar_operation == 'N':
            # splice junction, no correction
            pos_in_ref += length

        else:
            error(f'Unknown cigar operation: {cigar_operation}')

    stats = [len_ins, len_del, len_sub, alignment_length]
    return mapping_position, quality_values[:fiveprime_bases], stats, softclip, sc_ref


##################

if __name__ == '__main__':
    # Script body: read the reference, walk over all alignments collecting
    # 5'-end statistics, then draw three panels (base quality per read
    # position, soft-clip length histogram, mapping position histogram) and
    # save them as PDF and PNG.

    log('Started fiveprime_errors.py')
    log('Reading in reference.')
    chromos = read_reference(input_reference)

    # One CIGAR element = length followed by a single upper-case letter.
    # NOTE: the '=' operation is not matched by this pattern.
    regex = re.compile(r'(\d+)([A-Z])')

    # number of leading read bases / reference positions to look at
    fiveprime_bases = 30

    # error totals over all reads (accumulated, currently not reported)
    ins_total = 0
    del_total = 0
    sub_total = 0

    # alignment length sum
    read_len_sum = 0

    # mapping positions
    pos_list = []

    # fiveprime qualities, one list per read position
    quals_front_lists = [[] for _ in range(fiveprime_bases)]

    # softclips
    softclip_lens = []

    log('Iterating over alignments/reads.')

    # Reference of the most recently processed read; used for the tick
    # labels of the mapping position panel. Stays None if no read is used.
    ref_sequence = None

    ctr = 0
    readset = set()
    for read, all_reads in read_mapping(input_mapping):

        ctr += 1
        if ctr == int(all_reads) or ctr % 1000 == 0:
            log(str(ctr)+' of '+str(all_reads)+' alignments parsed.')

        # unmapped records have no reference and no CIGAR to evaluate
        if read.is_unmapped:
            continue

        read_name = read.query_name

        # skip multiple alignments: only the first record per read name counts
        if read_name in readset:
            continue
        readset.add(read_name)

        # get reference sequence
        ref_sequence = chromos[read.reference_name]

        pos, quals_front, stats, softclip, sc_ref = get_error_rates_fiveprime(read, ref_sequence, regex, fiveprime_bases=fiveprime_bases)

        pos_list.append(pos)

        for i, q in enumerate(quals_front):
            quals_front_lists[i].append(q)

        softclip_lens.append(len(softclip))

        ins_total += stats[0]
        del_total += stats[1]
        sub_total += stats[2]
        read_len_sum += stats[3]

    log('Plotting ...')

    fig, axes = plt.subplots(figsize=(16, 6), ncols=3)

    # sample name (and with it colour and output file name) is derived from
    # the mapping file name
    sample = 'WT' if 'new1' in input_mapping else 'SL2'
    plt.suptitle(f"5' end error statistics ({sample} sample)")

    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']
    color = colors[0] if sample=='WT' else colors[1]

    # --- panel 1: base quality per read position ---
    plt.sca(axes[0])
    medianprops = dict(color='black')
    labs = [p+1 for p in range(fiveprime_bases)]
    # No labels= argument here: the tick labels are set by plt.xticks right
    # below, and boxplot's ``labels`` parameter is deprecated in recent
    # matplotlib releases.
    bp = plt.boxplot(quals_front_lists, showfliers=False, patch_artist=True, positions=labs, widths=0.9, medianprops=medianprops, whis=[5, 95])
    plt.xticks(ticks=labs, labels=[f'{b if b==1 or b%5==0 else ""}' for b in labs])
    plt.xlim((0, fiveprime_bases+1))
    plt.title('Base quality (by Albacore basecaller)')
    plt.ylabel('Base quality')
    plt.xlabel('Position in read')
    plt.grid(alpha=0.4)

    for patch in bp['boxes']:
        patch.set_facecolor(color)

    # --- panel 2: soft-clip length histogram ---
    plt.sca(axes[1])
    softclip_lens = [s for s in softclip_lens if s < fiveprime_bases+1]
    bins = list(range(fiveprime_bases+2))
    plt.hist(softclip_lens, bins=bins, align='left', color=color)
    plt.xticks(ticks=bins[:-1], labels=[f'{b if b==0 or b%5==0 else ""}' for b in bins[:-1]])
    plt.xlim((-1, fiveprime_bases+1))
    plt.xlabel('Number of softclipped bases')
    plt.ylabel('Occurences')
    plt.title(f'Softclipped bases histogram (up to {fiveprime_bases})')
    plt.grid(alpha=0.4)

    # --- panel 3: mapping position histogram ---
    plt.sca(axes[2])
    ticks = list(range(fiveprime_bases))

    # Label each position with its reference base (T shown as U). Fall back
    # to the first reference entry when no read was processed, and to an
    # empty base label past the end of a short reference.
    label_sequence = ref_sequence
    if label_sequence is None:
        label_sequence = next(iter(chromos.values()), '')
    labs = []
    for b in ticks:
        base = label_sequence[b].replace("T", "U") if b < len(label_sequence) else ''
        labs.append(f'{base}\n{b+1 if b==0 or b%5==4 else ""}')

    # One bin per position 0 .. fiveprime_bases-1. The last histogram bin is
    # closed on both sides, so positions beyond the range are filtered out
    # and one extra edge is added; otherwise the last two positions would be
    # merged into a single bar.
    front_positions = [p for p in pos_list if p < fiveprime_bases]
    plt.hist(front_positions, bins=list(range(fiveprime_bases+1)), align='left', color=color)
    plt.xticks(ticks=ticks, labels=labs)
    plt.xlim((-1, fiveprime_bases))
    plt.xlabel('Mapping position on reference genome')
    plt.ylabel('Occurences')
    plt.title(f'Mapping position on reference (up to {fiveprime_bases})')
    plt.grid(alpha=0.4)

    plt.savefig(f'fiveprime_errors_{sample}.pdf', bbox_inches='tight')
    plt.savefig(f'fiveprime_errors_{sample}.png', bbox_inches='tight')

    log('All done!')

