import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyfastx as px

# Genome assembly in FASTA format; opened with pyfastx to get sequence lengths.
GENOME = "./Neff_assembly.fa"
# Tab-separated table of motif hits; the `seqID` and `start` columns are used.
REPEAT = "./Neff_repeats.tsv"
# Number of longest sequences from the assembly kept for the analysis/plots.
N_CHROMS = 35

# Compute chromsizes: map every sequence name in the assembly to its length
chroms = {chrom.name: len(chrom) for chrom in px.Fasta(GENOME)}
repeats = pd.read_csv(REPEAT, sep="\t")

# Look up the length of the sequence carrying each motif in a single
# vectorized pass instead of a Python-level lambda per row.
chrom_len = repeats["seqID"].map(chroms)
unknown = repeats.loc[chrom_len.isna(), "seqID"].unique()
if len(unknown):
    # Fail loudly, as a direct dict lookup would, but name the offenders.
    raise KeyError(
        f"seqID not found in {GENOME}: {', '.join(map(str, unknown))}"
    )

# Relative position of each TTAGGG motif, measured from the end of its
# sequence and expressed as a fraction of the sequence length
# (start == length -> 0.0, start == 0 -> 1.0).
repeats["dist_telo"] = (chrom_len - repeats["start"]) / chrom_len

# Only retain the N longest scaffolds. Sorting ascending and then reversing
# (rather than reverse=True) keeps the existing order among equal lengths.
by_length = sorted(chroms, key=chroms.get)
long_chroms = by_length[::-1][:N_CHROMS]
# Take an explicit copy: new columns are added to the filtered table below,
# and assigning onto a slice of the full table raises SettingWithCopyWarning.
repeats = repeats.loc[repeats.seqID.isin(long_chroms)].copy()

# Convert chrom names to numeric id (longest is smallest). A dict lookup
# avoids rescanning the list of names for every row.
chrom_rank = {name: rank for rank, name in enumerate(long_chroms)}
repeats["seqnum"] = repeats.seqID.map(chrom_rank)

# Discretize relative TTAGGG position into 1% buckets. With right=True,
# bucket k (k >= 1) holds values in (bins[k-1], bins[k]], bucket 0 holds
# values <= bins[0] and the last bucket holds values above bins[-1].
bins = np.arange(0, 1.0, 0.01)
# np.digitize can return any index in 0..len(bins), hence one extra column.
digitized = np.zeros((len(long_chroms), len(bins) + 1))
# Bucket all motifs in one vectorized call; the result is an integer array
# even when there are no rows, so it is always usable as an index below.
repeats["digitized"] = np.digitize(
    repeats.dist_telo.to_numpy(dtype=float), bins, right=True
)
# For each chromosome, count the number of motifs in each bucket (%)
# Heatmap layout: row is chrom, col is percentage bucket, value is # motifs
for i, chrom in enumerate(long_chroms):
    gr_idx, gr_counts = np.unique(
        repeats.loc[repeats.seqID == chrom, "digitized"],
        return_counts=True,
    )
    digitized[i, gr_idx] = gr_counts

# Two stacked panels sharing the x axis:
#   top    - histogram of motif positions pooled over all retained chroms
#   bottom - heatmap of motif counts, one row per chrom, one column per bucket
fig, axes = plt.subplots(2, 1, sharex=True)
# Scale the 0-1 fraction to percent so it matches the heatmap's bucket index
(repeats.dist_telo * 100).hist(bins=100, ax=axes[0])
fig.suptitle("Chromosomal distribution of TTAGGG motifs")
axes[0].set_ylabel("Occurrences of TTAGGG \n(Combined chroms.)")
im = axes[1].imshow(digitized, cmap="Blues")
axes[1].set_xlabel("Relative genomic position [%]")
axes[1].set_ylabel("Chromosome")
cb = fig.colorbar(
    im,
    ax=axes[1],
    shrink=0.6,
    location="bottom",
)
# The motif is TTAGGG (the label previously dropped the final G)
cb.set_label("Occurrences of TTAGGG")
fig.tight_layout()
plt.show()
