#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys, os
import numpy as np
import matplotlib
matplotlib.use('Agg')  # do not require X window
import matplotlib.pyplot as plt

# error reporting function
def error(str, type=1):
    """Write ``ERROR: <str>`` to stderr and terminate the script.

    Parameters
    ----------
    str : str
        Message to report. (The name shadows the builtin ``str``; it is kept
        so that any keyword callers keep working.)
    type : int, optional
        Exit status of the process, 1 by default.
    """
    sys.stderr.write('ERROR: ' + str + '\n')
    # sys.exit instead of the site-provided ``exit`` helper, which is meant
    # for the interactive prompt and is absent under ``python -S`` or in
    # frozen/embedded interpreters.
    sys.exit(type)

# get input files: the first one is plotted as WT, the second one as SL2
filelist = sys.argv[1:]

# input file checks
if not filelist:
    error('No input files given.')
# the plots below index the first two files directly, so fail early with a
# clear message instead of an IndexError after all files have been parsed
if len(filelist) < 2:
    error('Two input files (WT and SL2) are required.')
for file in filelist:
    if not os.path.isfile(file):
        error(file + ' does not appear to be a file.')


# Per-file results, in the same order as filelist.
lenlistlist = []  # length of every read, one list per file
nreadslist = []   # number of reads per file
maxlenlist = []   # longest read per file (0 for a file without reads)


def _read_lengths(path):
    """Return the length of every read sequence in the FASTQ file *path*.

    Records are parsed with a small state machine ('@' header, sequence
    line(s), '+' separator, quality line(s)) instead of classifying each
    line by its first character: quality strings may themselves begin with
    '@' or '+', and sequence/quality may be wrapped over several lines.
    A quality block ends once it is at least as long as its sequence.
    Lines are stripped of surrounding whitespace before being measured.
    Calls ``error`` (which exits) if the file cannot be opened.
    """
    lengths = []
    state = 'header'
    seq_len = qual_len = 0
    try:
        fhandle = open(path)
    except OSError:
        error('Failed to open file: ' + path)
    with fhandle:
        for line in fhandle:
            text = line.strip()
            if state == 'header':
                # anything outside a record (e.g. blank lines) is ignored
                if text[:1] == '@':
                    state, seq_len = 'seq', 0
            elif state == 'seq':
                if text[:1] == '+':
                    # separator reached: the (possibly wrapped) sequence is complete
                    lengths.append(seq_len)
                    state, qual_len = 'qual', 0
                else:
                    seq_len += len(text)
            else:
                # consume quality lines until they cover the whole sequence;
                # their content is never inspected, so '@'/'+' starts are safe
                qual_len += len(text)
                if qual_len >= seq_len:
                    state = 'header'
    return lengths


# main loop over files
for file in filelist:
    lenlist = _read_lengths(file)
    lenlistlist.append(lenlist)
    nreadslist.append(len(lenlist))
    maxlenlist.append(max(lenlist) if lenlist else 0)


####################
# plot stuff

# Output file stem: 'raw_readlength_' followed by each input file name with
# its directory and last extension removed, joined with '_'.
# os.path.basename matches the old rsplit('/') on POSIX and also handles
# Windows path separators.
outn = 'raw_readlength_' + '_'.join(
    os.path.basename(f).rsplit('.', 1)[0] for f in filelist)
# longest read over all input files; bounds the histogram bins and ticks
totmax = max(maxlenlist)


#####
# Mark the expected mRNA sizes (mRNA borders) on the plots rather than the
# annotated CDS regions.

# label text for each mRNA; all lists below use this order
mRNA_names = ['mRNA 2', 'mRNA 3', 'mRNA 4', 'mRNA 5', 'mRNA 6', 'mRNA 7']
# junction coordinate of each mRNA
mRNA_juncends = [20555, 22150, 24048, 24586, 24977, 25665]
# full mRNA lengths in nt, e.g. mRNA 3 with its junction at 22150:
# 27317 - 22150 + 60 (leader) = 5227
# NOTE(review): the same formula gives 6822, 3329 and 1712 for mRNA 2, 4
# and 7 rather than the 6824, 3332 and 1711 listed here - confirm which
# coordinates/lengths are intended.
mRNA_lengths = [6824, 5227, 3332, 2791, 2400, 1711]
# the same lengths with the 40 nt poly(A) tail left out
mRNA_lengths_noA = [size - 40 for size in mRNA_lengths]
# fill colour (RGBA hex) of each mRNA's label box
mRNA_colors = [
    '#3288BDFF',
    '#BEBEBEFF',
    '#D53E4FFF',
    '#F46D43FF',
    '#FEE08BFF',
    '#E6F598FF',
]




# Two stacked panels sharing the x axis: WT on top, SL2 below.
fig, axes = plt.subplots(figsize=(12, 6), nrows=2, sharex=True)
plt.subplots_adjust(hspace=0.08)
ax, ax2 = axes

ax.grid(alpha=0.4)
ax2.grid(alpha=0.4)

# 10 nt histogram bins and 1,000 nt tick positions, reaching past the
# longest read in any input file.
bins = list(range(0, totmax + 1000, 10))
tickpos = list(range(0, totmax + 1000, 1000))
# Label only 0, multiples of 5,000 and 8,000 (presumably so that 8,000 is
# labelled in the zoomed '_upto8k' plot below - confirm); other ticks stay blank.
ticks = [x if x % 5000 == 0 or x == 8000 else '' for x in tickpos]


def _draw_panel(lengths, label, color=None):
    """Draw one read-length histogram on the current pyplot axes.

    Plots *lengths* as a log-scale histogram, applies the shared ticks and
    x-range, marks the poly(A)-less mRNA lengths with dotted vertical lines
    and adds the legend. Returns the panel's y-limits.
    """
    plt.hist(lengths, bins, label=label, log=True, color=color)
    plt.xticks(tickpos, ticks)
    plt.xlim([-100, 27100])
    # keep the histogram's y-limits: read them before adding the lines and
    # restore them afterwards
    ylims = plt.ylim()
    plt.vlines(mRNA_lengths_noA,
               ylims[0], ylims[1], linestyles='dotted', linewidth=0.7)
    plt.ylim(ylims)
    plt.legend(loc=1, framealpha=1)
    return ylims


# WT (top panel, default first colour)
plt.sca(ax)
plt.title('Raw read length histogram')
_draw_panel(lenlistlist[0], 'WT')

# SL2 (bottom panel), drawn in the second colour of the default property cycle
color2 = plt.rcParams['axes.prop_cycle'].by_key()['color'][1]
plt.sca(ax2)
ylims = _draw_panel(lenlistlist[1], 'SL2', color=color2)
plt.ylabel('Occurrences')
plt.xlabel('Raw read length - bars represent bins of size 10')

# name each dotted line in a coloured box placed at the top of the SL2 panel
for name, pos, col in zip(mRNA_names, mRNA_lengths_noA, mRNA_colors):
    plt.text(pos, ylims[1], name, ha='center', va='bottom',
             bbox=dict(facecolor=col, alpha=1))

out = outn + '.pdf'
plt.savefig(out, bbox_inches='tight')

# Zoomed copy up to 8,000 nt; ax2 follows the new x-range through sharex.
plt.sca(ax)
plt.xlim([-100, 8000])
plt.title('Raw read length histogram - up to 8,000 nt')
plt.xticks(list(range(0, 9000, 1000)))
# append the suffix to the stem instead of out.replace('.pdf', ...), which
# would also rewrite any '.pdf' inside the input-derived file names
plt.savefig(outn + '_upto8k.pdf', bbox_inches='tight')

plt.close()
