#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import h5py
import matplotlib
matplotlib.use('Agg')  # do not require X window
import matplotlib.pyplot as plt

#####


def error(string: str, code: int = 1):
    """Report a fatal error on stderr and terminate the script.

    Args:
        string: Message to print; it is prefixed with ``ERROR: ``.
        code: Exit status passed to ``sys.exit`` (default 1).

    Raises:
        SystemExit: always.
    """
    sys.stderr.write('ERROR: ' + string + '\n')
    # Use sys.exit rather than the builtin exit(): the latter is injected by
    # the `site` module for interactive use and is not guaranteed to exist
    # (e.g. under `python -S` or in frozen executables).
    sys.exit(code)


def log(string):
    """Emit an informational message on stderr, prefixed with ``LOG: ``."""
    message = ''.join(('LOG: ', string, '\n'))
    sys.stderr.write(message)

#####


# Command line: <in.fast5> [out.pdf]. Without an explicit output path the
# plot is written next to the input, with the extension replaced by .pdf.
if len(sys.argv) < 2:
    error(f'Usage: {sys.argv[0]} <in.fast5> [out.pdf]')
in_fast5 = sys.argv[1]
if not in_fast5.endswith('.fast5'):
    error(f'Input file does not seem to be .fast5 file: {in_fast5}')
out = sys.argv[2] if len(sys.argv) > 2 else in_fast5.rsplit('.', 1)[0] + '.pdf'

#####
# get data

# Everything needed for plotting is copied into memory inside this `with`, so
# the HDF5 file is closed again (also when error() exits early).
with h5py.File(in_fast5, 'r') as fh:
    # Raw signal of the read stored in this fast5 (expected to be the only one).
    try:
        read = next(iter(fh['Raw/Reads/']))  # should be only one
        # [()] reads the whole dataset in one call instead of one HDF5 read
        # per sample; tolist() yields plain Python ints.
        signal = fh[f'Raw/Reads/{read}/Signal'][()].tolist()
    except (KeyError, StopIteration):
        # KeyError: group/dataset missing; StopIteration: Raw/Reads is empty.
        error(
            f'No raw signal ("Raw/Reads/<read>/Signal") found in file: {in_fast5}')

    # tombo resquiggle results for the template strand
    events_path = '/Analyses/RawGenomeCorrected_000/BaseCalled_template/Events'
    if events_path not in fh:
        error(f'No tombo resquiggle events ("{events_path}") found in file: {in_fast5}')
    read_start = fh[events_path].attrs['read_start_rel_to_raw']
    event_data = fh[events_path][()]  # load all event records at once

    # channel data
    channel_attrs = fh['/UniqueGlobalKey/channel_id'].attrs
    sampling_rate = int(channel_attrs['sampling_rate'])
    # offset and range are kept as floats: int() would truncate a fractional
    # digitisable range and skew every pA value.
    offset = float(channel_attrs['offset'])
    digi_range = float(channel_attrs['range'])
    digitisation = int(channel_attrs['digitisation'])

signal_len = len(signal)
# read start measured from the other end of the raw signal
read_start_rev = signal_len - read_start

# pA converted signal
# [current in pA] = (<signal_value> + <channel's 0 pA ADC offset>) * <digitisable range in pA> / <digitisation>
scale = digi_range / digitisation
log(f'offset: {offset}, scale: {scale}')
signal_pA = [(x + offset) * scale for x in signal]

# Event record fields used here (by position): 2 = start relative to the read
# start, 3 = length, 4 = called base (bytes).
event_starts_rev = [read_start_rev - e[2] for e in event_data]
event_lengths = [e[3] for e in event_data]
event_bases = [e[4].decode('UTF-8') for e in event_data]

# Resquiggled base sequence; printed on stdout (e.g. to BLAST it).
bases_string = ''.join(event_bases)
print(bases_string)

# Marker sequences locating region boundaries in the resquiggled bases.
# NOTE: `start_5pUTR` is a historical misnomer - it holds the index of the
# first base of the 3' UTR marker. The name is kept because the plotting code
# refers to it.
str_of_3pUTR = 'ACATGCCACTGTGTTGTTTGAAATTCAGGCTTTAGTTGGAATTTTGCTTTTGTTCTTTCTTTTATTATCTTTCTTTTGCCTGTTTTTAGAGAGATTTGGCGCCTTGGTGCCGTAGATGAATACATTG'
start_5pUTR = bases_string.find(str_of_3pUTR)
if start_5pUTR < 0:
    # find() returns -1, which would silently index the last event
    error(f"3' UTR marker sequence not found in resquiggled bases of: {in_fast5}")
log(f'Start of 3pUTR: {start_5pUTR}')

str_of_N = 'ATGGCTACAGTCAAATGGGCTGATGCATCTGAACCACAACGTGGTCGTCAGGGTAGAATACCTTATTCTCTTTATAGCCCTTTGCTTGTTGATAGTGAACAACCTTGGAAGGTGATACCTCGTA'
start_of_N = bases_string.find(str_of_N)
if start_of_N < 0:
    error(f'N marker sequence not found in resquiggled bases of: {in_fast5}')
log(f'Start of N: {start_of_N}')

# last event (3p) of tombo resquiggle; used as a guide when choosing the
# manual boundary positions for the plot
tombo_last_3p = event_starts_rev[-1] - event_lengths[-1]

#####
# plot it

# Figure layout: a tall main panel for the squiggle above an almost
# zero-height panel (height ratio 10 : 0.01) that only carries the time axis.
fig, axes = plt.subplots(
    2, figsize=(16, 6), gridspec_kw={'height_ratios': [10, 0.01]})
ax, ax2 = axes

# Visible current window in pA; region labels are positioned relative to it.
ymin, ymax = 40, 185
ax.set_ylim([ymin, ymax])
ax.set_xlim([-10, len(signal) + 10])

ax.set_title('Raw signal squiggle plot of one read')
ax.set_ylabel('Raw current [pA]')
ax.set_xlabel(f'Raw data point - sampling rate {int(sampling_rate)} Hz')

# Frameless main panel.
for side in ('top', 'left', 'right', 'bottom'):
    ax.spines[side].set_visible(False)

# Leave the main panel as the current axes for the pyplot-level calls below.
plt.sca(ax)


# positions
# Region boundaries on the raw-signal x axis (sample index), left to right.
# The plot runs from the poly(A) tail to the 5' UTR, i.e. 3' -> 5' along the
# molecule, so each "pos_<region>_end" (the region's 3' end) is that region's
# left edge in the plot. The first span, starting at x=0, holds the adapters.
pos_polyA_end = 8540  # approximate, from looking at plot
pos_3pUTR_end = 10160  # approximate, adapted from tombo_last_3p and looking at plot
# Boundaries taken from the resquiggled events: the first base of the 3' UTR
# marker (index `start_5pUTR`) and of the N marker (index `start_of_N`).
pos_N_end = event_starts_rev[start_5pUTR]
pos_leader_end = event_starts_rev[start_of_N]
pos_5p_start = len(signal)

positions = [pos_polyA_end, pos_3pUTR_end,
             pos_N_end, pos_leader_end, pos_5p_start]
# Fill colours of the spans after the adapters: poly(A), 3' UTR, N, 5' UTR.
annot_colors = ['#A4C8E4FF', '#7F7F7FFF', '#E6F598FF', '#7F7F7FFF']

# Shaded background spans: adapters (green, from x=0), then one span per
# consecutive pair of boundaries. plt.axvspan draws on the current axes (ax).
span_starts = [0] + positions[:-1]
span_colors = ['green'] + annot_colors
for span_start, span_stop, span_color in zip(span_starts, positions, span_colors):
    plt.axvspan(span_start, span_stop, color=span_color, alpha=0.4, lw=0)

# main squiggle
ax.plot(signal_pA, color='k', linewidth=1)

# text: overall labels near the top, per-region labels near the bottom
ytop = ymax - 5
ax.text(positions[0] / 2, ytop, 'Adapters', ha='center', va='center')
ax.text((positions[-1] + positions[0]) / 2, ytop, 'RNA molecule', ha='center', va='center')
ybot = ymin + 5
region_labels = ['poly(A)', "3' UTR", 'N genomic region', "5' UTR"]
for left, right, label in zip(positions, positions[1:], region_labels):
    ax.text((left + right) / 2, ybot, label, ha='center', va='center')


# time scale
# The thin bottom panel carries the x axis: one tick per second of signal
# (every `sampling_rate` samples). Tick k sits at k seconds, and only every
# 5th tick gets a number to avoid clutter.
plt.sca(ax2)
time_ticks = list(range(0, signal_len, sampling_rate))
# Exactly one label per tick: matplotlib (>= 3.5) raises ValueError when the
# number of labels differs from the number of fixed tick locations.
time_tlabs = [str(sec) if sec % 5 == 0 else '' for sec in range(len(time_ticks))]

plt.xlim([-10, len(signal) + 10])
plt.xticks(time_ticks, time_tlabs)
ax2.set_xlabel(f'Raw data point (sampling rate {int(sampling_rate)} Hz)   /   Time [s]')

ax2.yaxis.set_visible(False)
for side in ('top', 'left', 'right'):
    ax2.spines[side].set_visible(False)

# Twin x axis on top of the time panel with default (raw sample index) ticks.
# Set its limits explicitly rather than relying on twiny() changing the
# current axes.
ax3 = ax2.twiny()
ax3.set_xlim([-10, len(signal) + 10])

# The main panel's own x axis is hidden; the time panel replaces it.
ax.xaxis.set_visible(False)

plt.savefig(out, bbox_inches='tight')
