#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import os
import re
import pysam
import matplotlib
matplotlib.use('Agg')  # do not require X window
import matplotlib.pyplot as plt

def error(string, error_type=1):
    """Report a fatal error on stderr and terminate the script.

    Args:
        string: Message to print after the 'ERROR: ' prefix.
        error_type: Exit status handed to ``sys.exit`` (default 1).

    Raises:
        SystemExit: always, carrying ``error_type`` as the exit code.
    """
    sys.stderr.write('ERROR: ' + string + '\n')
    # sys.exit rather than the interactive `exit` helper: the latter is only
    # injected by the `site` module and is not meant for use in programs.
    sys.exit(error_type)


# Input mapping files (SAM/BAM/CRAM) are taken from the command line.
filelist = sys.argv[1:]

# Fail early on a missing argument list or a path that is not a regular file.
# (`not filelist` replaces the former `len(filelist) is 0`, which compared
# identity with an int literal and triggers a SyntaxWarning on Python 3.8+.)
if not filelist:
    error('No input files given.')
for path in filelist:
    if not os.path.isfile(path):
        error(path + ' does not appear to be a file.')

# Per-file results, in the same order as filelist.
lenlistlist = []  # alignment length of every mapped read
nreadslist = []   # number of mapped reads
maxlenlist = []   # longest alignment length

# Matches CIGAR 'N' operations (skipped reference regions) with their length.
regex = re.compile(r'(\d+)(N)')

# 'N' runs longer than this are treated as splice gaps and subtracted from the
# alignment length; shorter ones are only collected in short_Ns for reporting.
max_short_N = 300

# Short 'N' runs collected over ALL input files; reported once after the loop.
short_Ns = []


# main loop over files
for mapping in filelist:

    if mapping.endswith('bam'):
        mode = 'rb'
    elif mapping.endswith('sam'):
        mode = 'r'
    elif mapping.endswith('cram'):
        mode = 'rc'
    else:
        error("No valid mapping format detected. File has to be SAM/BAM/CRAM.")

    # data to gather
    nreads = 0
    lenlist = []
    maxlen = 0

    # `with` guarantees the alignment file is closed after reading.
    with pysam.AlignmentFile(mapping, mode) as in_file:
        for read in in_file.fetch(until_eof=True):

            # Unmapped reads (or reads without a CIGAR) have no reference
            # span: reference_end and cigarstring are None for them.
            if read.is_unmapped or read.reference_end is None:
                continue

            # reference_end already points one past the last aligned base,
            # so the reference span is end - start (no +1).
            alen = read.reference_end - read.reference_start

            # subtract splice gaps
            for match in regex.finditer(read.cigarstring):
                length = int(match.group(1))
                if length > max_short_N:
                    alen -= length
                else:
                    short_Ns.append(length)

            # report
            nreads += 1
            lenlist.append(alen)
            maxlen = max(maxlen, alen)

    # save data
    lenlistlist.append(lenlist)
    nreadslist.append(nreads)
    maxlenlist.append(maxlen)

# Guard against division by zero when no short 'N' runs were seen.
if short_Ns:
    mean_short_N = sum(short_Ns) / len(short_Ns)
else:
    mean_short_N = float('nan')
sys.stderr.write(
    f'short_Ns <= {max_short_N}: {len(short_Ns)} - mean length: {mean_short_N}\n')

#####
# annot
# Reference genome annotation: genome length, gene names, gene borders as
# (start, end) pairs, and one plot colour per gene.

total_length = 27317

# Previous annotation, kept for reference:
# gene_names = ['1a', '1ab', 'S', '4a', '4b', 'E', 'M', 'N']
# gene_borders = [(293, 12550), (293, 20568), (20570, 24091), (24091, 24492),
#                 (24482, 24748), (24750, 24983), (24995, 25672), (25686, 26855)]
# gene_colors = ['#ABDDA4FF', '#66C2A5FF', '#3288BDFF',
#                '#D53E4FFF', '#F46D43FF', '#FDAE61FF', '#FEE08BFF', '#E6F598FF']
#
# Changes relative to it: '1ab' renamed to '1b'; '4b' removed and '4'
# (formerly '4a') now ends at 24587; E now starts at 24588.
# NOTE(review): '1b' still spans (293, 20568), i.e. the old 1ab extent —
# confirm whether it should start after 1a instead.

gene_names = ['1a', '1b', 'S', '4', 'E', 'M', 'N']
gene_borders = [
    (293, 12550),
    (293, 20568),
    (20570, 24091),
    (24091, 24587),
    (24588, 24983),
    (24995, 25672),
    (25686, 26855),
]
gene_colors = [
    '#ABDDA4FF', '#66C2A5FF', '#3288BDFF', '#D53E4FFF',
    '#F46D43FF', '#FEE08BFF', '#E6F598FF',
]

# Split the (start, end) pairs into two parallel lists.
gene_starts = [start for start, _end in gene_borders]
gene_ends = [end for _start, end in gene_borders]

# Genome length without the polyA tail (27317 - 27276 = 41 nt omitted).
total_len_noA = 27276
# For every gene from S onwards: distance from the gene start to the
# polyA-free genome end, plus a 68 nt leader sequence. Presumably the expected
# transcript length per gene — TODO confirm; not used further in this script.
cumulative_types_lens = [total_len_noA + 68 - start for start in gene_starts[2:]]



#####
# Transcript-level annotation: these mRNA borders are used for plotting
# instead of the annotated CDS regions above.

mRNA_names = [
    'mRNA 2', 'mRNA 3', 'mRNA 4', 'mRNA 5', 'mRNA 6', 'mRNA 7',
]
# Genome position of each mRNA's leader junction.
mRNA_juncends = [20555, 22150, 24048, 24586, 24977, 25665]
# Full mRNA length = genome length - junction + 60 nt leader,
# e.g. mRNA 3: 27317 - 22150 + 60 = 5227.
# NOTE(review): mRNA 2, 4 and 7 differ from this formula by +2, +3 and -1 nt,
# and the gene annotation above uses a 68 nt leader — confirm which is right.
mRNA_lengths = [6824, 5227, 3332, 2791, 2400, 1711]
# Same lengths with a 40 nt polyA tail removed.
# NOTE(review): the gene annotation above omits 41 nt of polyA — confirm.
mRNA_lengths_noA = [length - 40 for length in mRNA_lengths]
# One label colour per mRNA.
mRNA_colors = [
    '#3288BDFF', '#BEBEBEFF', '#D53E4FFF',
    '#F46D43FF', '#FEE08BFF', '#E6F598FF',
]





# plot stuff
# Output prefix: 'readlength_' followed by every input file name (directory
# and last extension stripped), joined with underscores.
outn = 'readlength_' + '_'.join(
    path.rsplit('/', 1)[-1].rsplit('.', 1)[0] for path in filelist)

# Longest alignment seen in any input file; sets the histogram range.
totmax = max(maxlenlist)

# All alignment lengths of all input files flattened into one list.
combolenlist = [alen for lengths in lenlistlist for alen in lengths]


####################
# distrib
# Two-panel histogram of read alignment lengths sharing the x axis: the first
# input file on top (labelled 'WT'), the second below (labelled 'SL2').
# Dotted vertical lines mark the expected mRNA lengths (mRNA_lengths_noA).
if len(lenlistlist) < 2:
    error('The distribution plot needs two input files (WT and SL2).')
if len(lenlistlist) > 2:
    sys.stderr.write('WARNING: only the first two input files are plotted.\n')

bins = list(range(0, totmax + 1000, 10))
tick_positions = list(range(0, totmax + 1000, 1000))
# Label only multiples of 5,000 nt plus 8,000 (the limit of the zoomed plot).
ticks = [x if x % 5000 == 0 or x == 8000 else '' for x in tick_positions]


def _mark_mRNA_lengths():
    """Draw dotted lines at the expected mRNA lengths on the current axes.

    The y-limits are restored after drawing so the lines do not rescale the
    axes. Returns the (bottom, top) y-limits.
    """
    ylims = plt.ylim()
    plt.vlines(mRNA_lengths_noA, ylims[0], ylims[1],
               linestyles='dotted', linewidth=0.7)
    plt.ylim(ylims)
    return ylims


fig, axes = plt.subplots(figsize=(12, 6), nrows=2, sharex=True)
plt.subplots_adjust(hspace=0.08)
ax, ax2 = axes

ax.grid(alpha=0.4)
ax2.grid(alpha=0.4)

# WT (top panel)
plt.sca(ax)
plt.title('Read alignment length histogram')
plt.hist(lenlistlist[0], bins, label='WT', log=True)
plt.xticks(tick_positions, ticks)
plt.xlim([-100, 27100])
_mark_mRNA_lengths()
plt.legend(loc=1, framealpha=1)

# SL2 (bottom panel), drawn in the second colour of the default colour cycle.
color2 = plt.rcParams['axes.prop_cycle'].by_key()['color'][1]
plt.sca(ax2)
plt.hist(lenlistlist[1], bins, label='SL2', log=True, color=color2)
plt.xticks(tick_positions, ticks)
plt.xlim([-100, 27100])
plt.ylabel('Occurrences')
plt.xlabel('Read alignment length - bars represent bins of size 10')
plt.legend(loc=1, framealpha=1)

# annot: coloured mRNA name boxes at the top of the bottom panel's lines.
ylims = _mark_mRNA_lengths()
for name, pos, color in zip(mRNA_names, mRNA_lengths_noA, mRNA_colors):
    plt.text(pos, ylims[1], name, ha='center', va='bottom',
             bbox=dict(facecolor=color, alpha=1))

out = 'distrib_unstack_' + outn + '.pdf'
plt.savefig(out, bbox_inches='tight')

# shrink: the x axis is shared, so limiting it on `ax` zooms both panels.
plt.sca(ax)
plt.xlim([-100, 8000])
plt.title('Read alignment length histogram - up to 8,000 nt')
plt.xticks(list(range(0, 9000, 1000)))
# Build the name explicitly; str.replace would also rewrite any '.pdf'
# occurring inside the sample-derived prefix.
plt.savefig('distrib_unstack_' + outn + '_upto8k.pdf', bbox_inches='tight')

plt.close(fig)
