#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# SK

import sys
import os
import re
import time
import pysam
from Bio import SeqIO
from statistics import median
import matplotlib
matplotlib.use('Agg')  # do not require X window
import matplotlib.pyplot as plt

def error(string, error_type=1):
    """Report a fatal problem on stderr and terminate the process.

    The message is written as ``ERROR: <string>``.

    Args:
        string: text describing the problem.
        error_type: exit status passed on to the interpreter (default 1).
    """
    sys.stderr.write('ERROR: {}\n'.format(string))
    raise SystemExit(error_type)


def log(string, newline_before=False):
    """Write a status message to stderr as ``LOG: <string>``.

    Args:
        string: message text.
        newline_before: when true, emit an empty line first to separate
            output sections visually.
    """
    leading = '\n' if newline_before else ''
    sys.stderr.write('{}LOG: {}\n'.format(leading, string))

#####

# Positional command-line arguments: tombo wiggle (variableStep) file,
# reference FASTA, and the SAM/BAM/CRAM mapping to evaluate.
if len(sys.argv) < 4:
    # fail with a usage line instead of an IndexError traceback
    error(f'Usage: {os.path.basename(sys.argv[0])} <tombo.wig> <reference.fasta> <mapping.sam|bam|cram>')

tombofile = sys.argv[1]
referencefile = sys.argv[2]
mappingfile = sys.argv[3]

#####

def read_reference(reference):
    """Load all records of a FASTA file into memory.

    Args:
        reference: path to the FASTA file.

    Returns:
        dict mapping each record id to its sequence as a plain string. If an
        id occurs more than once, the last record with that id is kept.
    """
    return {record.id: str(record.seq) for record in SeqIO.parse(reference, 'fasta')}


def read_mapping(mapping):
    '''
    Iterate over every alignment record of a SAM/BAM/CRAM file.

    The format is chosen from the file extension (case-insensitive). Each
    record is yielded together with the total record count reported by
    `samtools view -c`, which the caller uses for progress output.

    Yields:
        (read, all_reads) tuples: a pysam alignment record and the int total.

    Exits via error() when the extension is not sam/bam/cram.
    '''
    modes = (('bam', 'rb'), ('sam', 'r'), ('cram', 'rc'))
    lowered = mapping.lower()
    mode = next((m for suffix, m in modes if lowered.endswith(suffix)), None)
    if mode is None:
        # validate before running samtools, which would otherwise fail first
        # with a less helpful message
        error("No valid mapping format detected. File has to be SAM/BAM/CRAM.")

    all_reads = int(pysam.view('-c', mapping))

    # the with-block closes the file when iteration finishes or the
    # generator is closed
    with pysam.AlignmentFile(mapping, mode) as in_file:
        for read in in_file.fetch(until_eof=True):
            yield (read, all_reads)


def read_tombo_result(tombo):
    '''
    Parse a tombo wiggle file (variableStep sections) into position maps.

    'track'/'browser' header lines, '#' comment lines and blank lines are
    skipped. Each 'variableStep chrom=<name> ...' line opens a section for
    <name>; several sections for the same chromosome are merged.

    Args:
        tombo: path to the wiggle file.

    Returns:
        dict chromosome name -> {0-based position: float from the 2nd column}

    Raises:
        ValueError: variableStep line without chrom=, fixedStep section,
            data line before any declaration, or a position listed twice.
    '''
    meth = {}
    chromo = None
    current = None
    with open(tombo) as infh:
        for line_number, line in enumerate(infh, start=1):
            fields = line.split()
            if not fields or line.startswith(('track', 'browser', '#')):
                # header / comment / blank line
                continue

            if line.startswith('variableStep'):
                # key=value settings, e.g. chrom=S span=1
                settings = dict(field.partition('=')[::2] for field in fields[1:])
                chromo = settings.get('chrom', '')
                if not chromo:
                    raise ValueError(f'{tombo}:{line_number}: variableStep line without chrom=')
                current = meth.setdefault(chromo, {})
                continue

            if line.startswith('fixedStep'):
                raise ValueError(f'{tombo}:{line_number}: fixedStep sections are not supported')

            if current is None:
                raise ValueError(f'{tombo}:{line_number}: data line before any variableStep declaration')

            # normal line; 1-based (wig) to 0-based (python)
            pos, methrate = int(fields[0]) - 1, float(fields[1])
            if pos in current:
                raise ValueError(f'{tombo}:{line_number}: position {pos + 1} listed twice for {chromo}')
            current[pos] = methrate

    return meth


def get_error_rates_and_subs_at_methyl_positions(read, reference_sequence, regex, methyl_positions, methyl_threshold):
    '''
    Walk the CIGAR of one alignment and count insertions, deletions and
    substitutions. Also check whether substituted reference C's lie on
    positions whose tombo value reaches methyl_threshold.

    regex = re.compile(r'(\d+)([A-Z=])') must be compiled once before and
    passed in. It has to tokenize every CIGAR operation in the data: an
    incompletely parsed CIGAR (e.g. '=' with an [A-Z]-only pattern) is
    reported via error() instead of being skipped silently.

    Args:
        read: alignment record (uses query_name, query_sequence,
            query_alignment_length, reference_start, cigarstring).
        reference_sequence: sequence of the reference the read maps to.
        regex: compiled CIGAR tokenizer yielding (length, operation) groups.
        methyl_positions: dict 0-based reference position -> tombo value.
        methyl_threshold: value at/above which a position counts as methylated.

    Returns:
        (mapping_position,
         [len_ins, len_del, len_sub, alignment_length],
         [len_Csub, len_methsub, num_pos_not_in_tombo_results, subcounts])
        where subcounts maps 0-based reference position -> substitution count.
        Deletions of >= 50 bases are treated like N (not counted).
    '''
    if read.query_sequence is None:
        # SEQ '*': there are no read bases to compare against the reference
        error(f'Alignment of read {read.query_name} has no stored sequence.')
    read_sequence = read.query_sequence

    # alignment length: like from blast identity - total bases in the alignment
    # (including gaps (indels)); start with read.query_alignment_length
    # (qend - qstart) and add the (short) deletions below. This way splice
    # junctions are ignored.
    alignment_length = read.query_alignment_length

    mapping_position = read.reference_start
    cigar_string = read.cigarstring

    pos_in_read = 0
    pos_in_ref = mapping_position

    # indel / substitution stats (in bases)
    len_ins = 0
    len_del = 0
    len_sub = 0

    # meth stats
    len_Csub = 0
    len_methsub = 0
    num_pos_not_in_tombo_results = 0

    # 0-based reference position -> number of substitutions
    subcounts = {}

    # CIGAR characters consumed by the regex, to detect untokenized operations
    parsed_chars = 0

    for match in regex.finditer(cigar_string):
        length, cigar_operation = match.groups()
        length = int(length)
        parsed_chars += len(match.group(0))

        if cigar_operation in ('H', 'P'):
            # hardclipped / padding: consumes neither read nor reference
            pass

        elif cigar_operation == 'S':
            # softclipped: present in the read sequence but not aligned
            pos_in_read += length

        elif cigar_operation in ('M', '=', 'X'):
            # aligned bases (match or substitution). Compare case-insensitively
            # so soft-masked (lowercase) reference bases are not counted as subs.
            read_block = read_sequence[pos_in_read:pos_in_read + length].upper()
            ref_block = reference_sequence[pos_in_ref:pos_in_ref + length].upper()
            for offset, (read_base, ref_base) in enumerate(zip(read_block, ref_block)):
                if read_base == ref_base:
                    continue
                ref_pos = pos_in_ref + offset
                len_sub += 1
                subcounts[ref_pos] = subcounts.get(ref_pos, 0) + 1

                if ref_base == 'C':
                    len_Csub += 1
                    if ref_pos not in methyl_positions:
                        num_pos_not_in_tombo_results += 1
                    elif methyl_positions[ref_pos] >= methyl_threshold:
                        len_methsub += 1

            pos_in_read += length
            pos_in_ref += length

        elif cigar_operation == 'I':
            # insertion: read bases without reference counterpart
            pos_in_read += length
            len_ins += length

        elif cigar_operation == 'D':
            pos_in_ref += length
            if length < 50:
                len_del += length
                # deleted bases belong to the alignment length
                alignment_length += length
            # else: LONG DELETION, assume mapping error and treat as N

        elif cigar_operation == 'N':
            # splice junction, no correction
            pos_in_ref += length

        else:
            error(f'Unknown cigar operation: {cigar_operation}')

    if parsed_chars != len(cigar_string):
        error(f'CIGAR string could not be parsed completely: {cigar_string}')

    return mapping_position, [len_ins, len_del, len_sub, alignment_length], [len_Csub, len_methsub, num_pos_not_in_tombo_results, subcounts]


###############


def percent_of(numerator, denominator):
    '''Return numerator / denominator * 100, or NaN when the denominator is 0.

    Keeps the summary printable for empty inputs (no reads, no substitutions)
    instead of raising ZeroDivisionError.
    '''
    if not denominator:
        return float('nan')
    return numerator / denominator * 100


if __name__ == '__main__':

    # mRNA2 (S protein): the tombo file names it 'S', the reference FASTA '2'.
    # NOTE(review): that both names denote the same sequence is only assumed
    # by this script - confirm against the input files.
    tombo_chromosome = 'S'
    target_reference = '2'
    methyl_threshold = 0.9

    meth = read_tombo_result(tombofile)
    methyl_pos_mRNA2 = meth[tombo_chromosome]
    # positions whose tombo value reaches the threshold; a set because it is
    # queried once per substituted reference position below
    sign_meth_pos = {pos for pos, rate in methyl_pos_mRNA2.items() if rate >= methyl_threshold}

    # fraction (0..1) of tombo-reported positions that pass the threshold
    percent_C_meth = len(sign_meth_pos) / len(methyl_pos_mRNA2) if methyl_pos_mRNA2 else float('nan')

    chromos = read_reference(referencefile)

    # CIGAR tokenizer; '=' (sequence match) is not in [A-Z] and must be listed
    regex = re.compile(r'(\d+)([A-Z=])')

    ins_total = 0
    del_total = 0
    sub_total = 0

    # alignment length sum
    read_len_sum = 0

    # meth stats
    Csubs_total = 0
    methsubs_total = 0
    pos_not_in_tombo_total = 0

    # 0-based reference position -> substitutions over all reads
    subcounts_total = {}

    ctr = 0
    readset = set()
    for read, all_reads in read_mapping(mappingfile):

        ctr += 1
        if ctr == all_reads or ctr % 1000 == 0:
            log(f'{ctr} of {all_reads} alignments parsed.')

        # secondary records may be stored without SEQ and must not shadow the
        # read's primary alignment (first record per read wins via readset)
        if read.is_unmapped or read.is_secondary:
            continue

        read_name = read.query_name

        # skip multiple alignments
        if read_name in readset:
            continue

        # only handle mappings to mRNA2
        if read.reference_name != target_reference:
            continue

        readset.add(read_name)

        ref_sequence = chromos[read.reference_name]

        _, stats, methsub = get_error_rates_and_subs_at_methyl_positions(
            read, ref_sequence, regex, methyl_pos_mRNA2, methyl_threshold)

        ins_total += stats[0]
        del_total += stats[1]
        sub_total += stats[2]
        read_len_sum += stats[3]

        Csubs_total += methsub[0]
        methsubs_total += methsub[1]
        pos_not_in_tombo_total += methsub[2]

        for pos, count in methsub[3].items():
            subcounts_total[pos] = subcounts_total.get(pos, 0) + count

    errors_total = ins_total + del_total + sub_total

    log(f'Number of reads:\t{len(readset)}')
    log(f'Total nucleotides:\t{read_len_sum}')
    log('')
    log(f'Insertions:\t{percent_of(ins_total, read_len_sum):.6f}%\ttotal: {ins_total}')
    log(f'Deletions: \t{percent_of(del_total, read_len_sum):.6f}%\ttotal: {del_total}')
    log(f'Substitut-s:\t{percent_of(sub_total, read_len_sum):.6f}%\ttotal: {sub_total}')
    log(f'Total errors:\t{percent_of(errors_total, read_len_sum):.6f}%\ttotal: {errors_total}')
    log('')
    log(f'%-of-subs is C:\t{percent_of(Csubs_total, sub_total):.6f}%\ttotal: {Csubs_total}')
    log('')
    log(f'%-of-Csub are meth:\t{percent_of(methsubs_total, Csubs_total):.6f}%\ttotal: {methsubs_total}')
    log(f'Tombo: %-of-C meth:\t{percent_C_meth*100:.6f}%')
    log('')
    log(f'%Csub not in tombo res:\t{percent_of(pos_not_in_tombo_total, Csubs_total):.6f}%\ttotal: {pos_not_in_tombo_total}')
    log('')

    subs_on_sign_meth = sum(count for pos, count in subcounts_total.items() if pos in sign_meth_pos)
    allsubs = sum(subcounts_total.values())
    ref_len = len(chromos[target_reference])

    log(f'% all pos is meth:\t{percent_of(len(sign_meth_pos), ref_len):.6f}%\ttotal: {len(sign_meth_pos)} / {ref_len}')
    log(f'% sub on meth:\t{percent_of(subs_on_sign_meth, allsubs):.6f}%\ttotal: {subs_on_sign_meth} / {allsubs}')

    log('Done.')