#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Plot read length and mean quality for sets of reads in two samples
# Usage:   restplot.py <reads1.fq> <set1a.ids> [<set1b.ids>]... <reads2.fq> <set2a.ids> [<set2b.ids>]...

import sys
from Bio import SeqIO
import matplotlib
matplotlib.use('Agg')  # do not require X window
import matplotlib.pyplot as plt

def error(string, error_type=1):
    """Write *string* to stderr with an ``ERROR:`` prefix and terminate.

    The script is ended by raising SystemExit with *error_type* as the
    exit status (1 unless the caller says otherwise).
    """
    message = f'ERROR: {string}\n'
    sys.stderr.write(message)
    raise SystemExit(error_type)


def log(string, newline_before=False):
    """Write *string* to stderr as a ``LOG:`` line.

    With *newline_before* set, an empty line is emitted ahead of the
    message.
    """
    lead = '\n' if newline_before else ''
    sys.stderr.write(f'{lead}LOG: {string}\n')

#####

# Split the command line into reads files and the read-ID lists ('.ids'
# files) that follow each of them.
readfiles = []  # reads files (every argument not ending in '.ids'), in order
set1 = []       # .ids files listed after the first reads file
set2 = []       # .ids files listed after the second (or any later) reads file
allfiles = []   # every argument seen so far, used to reject duplicates


for arg in sys.argv[1:]:
    # Bad input is reported through error() rather than assert: asserts are
    # stripped under `python -O` and carry no usable message.
    if arg in allfiles:
        error(f'file given more than once: {arg}')
    allfiles.append(arg)

    if not arg.endswith('.ids'):
        readfiles.append(arg)
    elif not readfiles:
        # Without this check the list would silently be filed under the
        # second sample.
        error(f'read set {arg} must be listed after the reads file it belongs to')
    elif len(readfiles) == 1:
        set1.append(arg)
    else:
        set2.append(arg)

# The rest of the script pairs set1 with set2, so both samples are required.
if len(readfiles) < 2 or not set1 or not set2:
    error('usage: restplot.py <reads1.fq> <set1a.ids> [<set1b.ids>]... '
          '<reads2.fq> <set2a.ids> [<set2b.ids>]...')


log(f'Loading in read data: {readfiles}')
# read id -> (read length, mean Phred quality), shared by all reads files
readdata = {}

for file in readfiles:
    for record in SeqIO.parse(file, "fastq"):
        readid = record.id
        rlen = len(record.seq)
        if rlen == 0:
            # A read without bases has no per-base qualities to average;
            # report it instead of failing with a bare ZeroDivisionError.
            error(f'read {readid} in {file} has zero length')
        meanq = sum(record.letter_annotations["phred_quality"]) / rlen

        # disallow duplicates (within and across the reads files); checked
        # with error() because asserts vanish under `python -O`
        if readid in readdata:
            error(f'duplicate read id {readid} found in {file}')
        readdata[readid] = (rlen, meanq)


log('Gathering data for read sets ...')
lens = {}   # .ids file -> read lengths of the reads it lists
quals = {}  # .ids file -> mean qualities of the reads it lists

for rset in [set1, set2]:
    for file in rset:
        # Each read set may be collected only once; reported via error()
        # because asserts vanish under `python -O`.
        if file in lens or file in quals:
            error(f'read set given more than once: {file}')

        with open(file) as infh:

            log(f'Read set: {file}')
            lens[file] = []
            quals[file] = []

            for line in infh:
                # The read id is the first whitespace-separated field.
                # Lines without any field (empty, whitespace-only or a bare
                # '\r\n') are skipped instead of raising IndexError.
                fields = line.split()
                if not fields:
                    continue
                readid = fields[0]

                # every listed read must have been loaded from a reads file
                if readid not in readdata:
                    error(f'read {readid} listed in {file} was not found in the read files')
                rlen, qual = readdata[readid]
                lens[file].append(rlen)
                quals[file].append(qual)


log('Plotting ...')


# Boxes are drawn as sample-1/sample-2 pairs: the i-th read set of the first
# sample sits next to the i-th read set of the second sample.
if not set1:
    error('no read sets to plot for the first sample')
if len(set2) < len(set1):
    # Pairing used to fail here with a bare IndexError.
    error(f'cannot pair read sets: {len(set1)} given for the first sample '
          f'but only {len(set2)} for the second')
if len(set2) > len(set1):
    # Unpaired trailing sets are left out of the plot, as before; say so.
    log(f'Ignoring unpaired read sets of the second sample: {set2[len(set1):]}')

order = []
for first, second in zip(set1, set2):
    order += [first, second]
print(order)
# label = file name without directory and without its last extension
labs = [f.rsplit('/', 1)[-1].rsplit('.', 1)[0] for f in order]
print(labs)

ldata = [lens[file] for file in order]
qdata = [quals[file] for file in order]

# x positions 0,1, 3,4, 6,7, ...: every third slot stays empty so that
# consecutive pairs are separated by a gap
posis = [i for i in range(len(order) * 3 // 2) if i % 3 != 2]
medianprops = dict(color='black')

fig, axes = plt.subplots(figsize=(12, 6), ncols=2)

# Both panels share every boxplot setting; only data and title differ.
# NOTE: `labels=` is kept for compatibility with older matplotlib; recent
# releases prefer the name `tick_labels` for the same argument.
panels = [
    (axes[0], ldata, 'Raw read lengths'),
    (axes[1], qdata, 'Mean quality (by Albacore basecaller)'),
]
bplots = []
for ax, data, title in panels:
    plt.sca(ax)
    bplots.append(plt.boxplot(data, labels=labs, showfliers=False,
                              patch_artist=True, positions=posis, widths=0.9,
                              medianprops=medianprops, whis=[5, 95]))
    plt.xticks(rotation=90)
    plt.title(title)
    plt.grid(alpha=0.4)
p1, p2 = bplots

# Alternate the first two colours of the default cycle so that the two
# samples of each pair can be told apart; zip() stops at the last box.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
colors = colors[:2] * len(order)
for bplot in (p1, p2):
    for patch, color in zip(bplot['boxes'], colors):
        patch.set_facecolor(color)


plt.savefig('restplot.pdf', bbox_inches='tight')
plt.savefig('restplot.png', bbox_inches='tight')
