#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# SK

import sys
import numpy as np
import matplotlib
matplotlib.use('Agg')  # do not require X window
import matplotlib.pyplot as plt

def error(string, error_type=1):
    """Report a fatal problem on stderr and stop the program.

    Writes ``ERROR: <string>`` followed by a newline to standard error and
    then exits with ``error_type`` as the exit status (1 by default).
    """
    message = f'ERROR: {string}\n'
    sys.stderr.write(message)
    # Equivalent to sys.exit(error_type): both raise SystemExit.
    raise SystemExit(error_type)


def log(string, newline_before=False):
    """Write a ``LOG: <string>`` line to stderr.

    When ``newline_before`` is true, an empty line is emitted first so the
    message is visually separated from whatever was printed before it.
    """
    lead = '\n' if newline_before else ''
    sys.stderr.write(lead + f'LOG: {string}\n')

#####


log('Reading in counts ...')

# Every command-line argument is a tab-separated counts file.
files = sys.argv[1:]
if not files:
    error(f'no input files given (usage: {sys.argv[0]} COUNTS_FILE [COUNTS_FILE ...])')

# data[file][mrna] = {'count': int, 'mean_miss': float}
data = {}

for file in files:

    # Explicit checks instead of `assert`: assertions are removed when Python
    # runs with -O, which would let a repeated file silently overwrite data.
    if file in data:
        error(f'input file given more than once: {file}')
    data[file] = {}

    with open(file) as infh:
        for lineno, line in enumerate(infh, start=1):
            # Skip comment lines and empty lines.
            if line.startswith('#') or not line.strip():
                continue

            lt = line.strip().split('\t')
            # Columns used: 1 = mRNA identifier, 3 = count, 5 = mean_miss.
            if len(lt) < 5:
                error(f'{file}:{lineno}: expected at least 5 tab-separated '
                      f'columns, found {len(lt)}')
            mrna = lt[0]
            try:
                count = int(lt[2])
                mmiss = float(lt[4])
            except ValueError:
                error(f'{file}:{lineno}: cannot parse count {lt[2]!r} / '
                      f'mean_miss {lt[4]!r} as numbers')

            if mrna in data[file]:
                error(f'{file}:{lineno}: duplicate entry for mRNA {mrna!r}')
            data[file][mrna] = {'count': count, 'mean_miss': mmiss}



log('Plotting ...')

# mRNAs to show (in plotting order) and their tick labels.
mrnas = ['genome', '2', '4', '5', '6', '7']
mnames = {m: ('genome' if m == 'genome' else 'mRNA ' + m) for m in mrnas}

# Per-sample totals used as the denominator of the percentage printed above
# each bar. Presumably the number of unique read names in each SAM file,
# obtained with:
# grep '^[^@]' new1_mRNAs_noA.sam | cut -f1 | sort | uniq | wc -l
# grep '^[^@]' new2_mRNAs_noA.sam | cut -f1 | sort | uniq | wc -l
aligned_WT = 74834
aligned_SL2 = 64988

# (file-name marker, legend label, total aligned reads); first match wins.
samples = (
    ('new1', 'WT', aligned_WT),
    ('new2', 'SL2', aligned_SL2),
)

# Every mRNA forms a group of one bar per input file, followed by one empty
# slot separating it from the next group. Deriving the stride from the number
# of files keeps tick positions and labels the same length for any file count
# (with two files this gives the stride of 3 that used to be hard-coded).
stride = len(files) + 1
labs = [mnames[mrna] for mrna in mrnas for _ in files]
posis = [i for i in range(len(mrnas) * stride) if i % stride != stride - 1]


fig, axes = plt.subplots(figsize=(12, 6), ncols=1)
# fig.suptitle('mRNA counts for WT/SL2 samples')

# plt.sca(axes[0])
for fn, file in enumerate(files):
    # Identify the sample from the file name; stop with a clear message rather
    # than failing later when dividing by a placeholder string.
    for marker, label, all_aligned in samples:
        if marker in file:
            break
    else:
        error(f"cannot determine the sample for '{file}': "
              "file name must contain 'new1' (WT) or 'new2' (SL2)")

    missing = [mrna for mrna in mrnas if mrna not in data[file]]
    if missing:
        error(f"'{file}' has no entry for: {', '.join(missing)}")

    # The fn-th file occupies the fn-th slot of every group.
    pos = [group * stride + fn for group in range(len(mrnas))]
    pdata = [data[file][mrna]['count'] for mrna in mrnas]
    plt.bar(pos, pdata, width=0.8, label=label)

    # Annotate each bar with its raw count and its share of all aligned reads.
    for x, count in zip(pos, pdata):
        percent = count / all_aligned * 100
        plt.text(x, count, f'{count}\n{percent:.2f}%', va='bottom', ha='center')

plt.yscale('log')
# Raise the upper limit so the two-line labels above the tallest bars fit.
ymin, ymax = plt.ylim()
plt.ylim((ymin, ymax * 1.4))
plt.ylabel('count')
plt.xticks(posis, labs, rotation=90)
plt.title('mRNA counts (max 40 missing nt)')
plt.legend(framealpha=1)
plt.grid(alpha=0.4)


# plt.sca(axes[1])
# for fn, file in enumerate(files):
#     pos = [i for i in range(len(mrnas)*3) if i%3==fn]
#     mdata = [data[file][mrna]['mean_miss'] for mrna in mrnas]
#     label = 'WT' if 'new1' in file else ('SL2' if 'new2' in file else 'ERROR')
#     plt.bar(pos, mdata, width=0.8, label=label)
# plt.ylabel('number of nucleotides')
# plt.xticks(posis, labs, rotation=90)
# plt.title('Mean number of missing nucleotides')
# plt.legend(framealpha=1)
# plt.grid(alpha=0.4)


# Write the current figure to the working directory as both PDF and SVG.
for out_path in ('mrna_counts.pdf', 'mrna_counts.svg'):
    plt.savefig(out_path, bbox_inches='tight')
