import os
import pandas as pd
import numpy as np
from scipy.stats import wilcoxon
import scipy.stats as stats 
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from scipy.stats import mannwhitneyu

# Directory that is scanned for RNAfold output files (names ending in
# '.rnafold.fa'). This is the only input the script requires.
rnafold_output_fasta_path = 'rnafold_fasta'
# Directory the box plots are written to.
plots_path = '/home'
# File extension (and therefore format) of the saved plots.
plot_format = 'pdf'

# Build one metadata row per RNAfold output file. The metadata is taken from
# the underscore-separated file name:
#   <platform>_<plotting_method>_<dataset>_..._<bin>.rnafold.fa
meta_list = []
for file in os.listdir(rnafold_output_fasta_path):
    # Only RNAfold outputs are of interest; names starting with '._' are
    # skipped (presumably macOS sidecar files - confirm if relevant).
    if not file.endswith('.rnafold.fa') or file.startswith('._'):
        continue
    file_id = file.split('.rnafold.fa')[0]
    name_parts = file.split('_')
    platform = name_parts[0]
    plotting_method = name_parts[1]
    dataset = name_parts[2]
    plot_name = f"{platform}_{plotting_method}"
    # The bin label is the last underscore-separated field, up to its first
    # dot. It is not called `bin` so the builtin of that name stays usable.
    bin_label = name_parts[-1].split('.')[0]
    meta_list.append([file_id, platform, plotting_method, dataset, plot_name, bin_label])

# Passing the column names to the constructor keeps this working when no file
# matched: assigning six names to a column-less empty frame raises ValueError.
meta_df = pd.DataFrame(
    meta_list,
    columns=['file_id', 'platform', 'plotting_method', 'dataset', 'plot_name', 'bin'],
)
meta_df = meta_df.sort_values(by='file_id')

# Parse every RNAfold output file into rows of (file_id, record_name, energy).
# The accumulator has to exist before the loop; without it the first append
# raises NameError.
results_list = []
for file in os.listdir(rnafold_output_fasta_path):
    if not file.endswith('.rnafold.fa') or file.startswith('._'):
        continue
    # Derived the same way as in the metadata table, because file_id is the
    # key the two tables are merged on.
    file_id = file.split('.rnafold.fa')[0]
    file_path = f'{rnafold_output_fasta_path}/{file}'
    # The context manager closes the handle even if parsing fails below.
    with open(file_path, 'r') as handle:
        file_txt = handle.read()
    for record in file_txt.split('>'):
        # Splitting on '>' yields an empty first chunk; blank chunks carry no
        # record and would otherwise fail on the line lookup below.
        if not record.strip():
            continue
        record_lines = record.split('\n')
        record_index = str(record_lines[0])
        record_name = f"record_{record_index}"
        # The energy is the text after the last '(' on the record's third
        # line, with ')' removed (assumes the header / sequence /
        # "structure (energy)" layout - confirm against the RNAfold output).
        energy = float(record_lines[2].split('(')[-1].replace(')', ''))
        results_list.append([file_id, record_name, energy])

# Column names go to the constructor so an empty result still yields a frame
# with the expected columns instead of raising on assignment.
scores_df = pd.DataFrame(results_list, columns=['file_id', 'record_name', 'energy'])

# Attach the per-file metadata to every energy record (left join on the
# columns the two frames share) so the result can be used for plotting.
df = pd.merge(scores_df, meta_df, how='left')

pval_list = []
# NOTE(review): not used in the visible code; kept in case code elsewhere
# refers to it.
pval_list2 = []

# For every plot (rows sharing a plot_name) and every bin inside it, compare
# the energies of the 'pure' and 'overlapped' datasets with a two-sided
# Mann-Whitney U test (also known as the Wilcoxon rank-sum test).
for plot in df['plot_name'].unique():
    plot_df = df[df['plot_name'] == plot]
    # `bin_name` rather than `bin`, so the builtin is not shadowed.
    for bin_name in plot_df['bin'].unique():
        bin_df = plot_df[plot_df['bin'] == bin_name]
        overlapped_vals = bin_df.loc[bin_df['dataset'] == 'overlapped', 'energy']
        pure_vals = bin_df.loc[bin_df['dataset'] == 'pure', 'energy']
        if overlapped_vals.empty or pure_vals.empty:
            # A bin present in only one dataset cannot be tested; record NaN
            # instead of handing an empty sample to the test.
            pval = np.nan
        else:
            _, pval = mannwhitneyu(pure_vals, overlapped_vals)
        pval_list.append([plot, bin_name, pval])

# One row per (plot, bin). Column names are given to the constructor so that
# an empty result still produces a frame with the expected columns.
pval_df = pd.DataFrame(pval_list, columns=['plot', 'bin', 'p-val'])

# Order of the `dataset` hue levels in the box plots: 'overlapped' first,
# then 'pure'.
dataset_order = ["overlapped", "pure"]

def blend_with_white(color, alpha):
    """Mix an RGB colour with white and return it on the 0-1 scale.

    Every channel is computed as ``(1 - alpha) * 255 + alpha * channel``,
    truncated to an integer and then divided by 255.

    Args:
        color: Three channel values; the callers in this file pass 0-255
            integers.
        alpha: Weight of ``color`` in the mix; white receives ``1 - alpha``.

    Returns:
        A tuple with one float per channel.
    """
    rgb = np.array(color)
    white_share = (1 - alpha) * np.full(3, 255)
    mixed = white_share + alpha * rgb
    # Truncate (not round) to whole channel values before rescaling.
    whole_channels = mixed.astype(int)
    return tuple(whole_channels / 255)
    
# Colours pre-blended with white, as 0-1 RGB tuples.
dodgerblue_approx = blend_with_white((30, 144, 255), 0.7)
slategrey_approx = blend_with_white((112, 128, 144), 0.3)
# NOTE(review): tomato_approx is not referenced in the visible code.
tomato_approx = blend_with_white((255, 99, 71), 0.8)

# Seaborn palette keyed by the `dataset` value. Each blended colour is turned
# into a '#rrggbb' string by scaling its channels back to 0-255.
custom_palette = {
    dataset_label: '#' + ''.join(f'{int(channel * 255):02x}' for channel in rgb)
    for dataset_label, rgb in (
        ('overlapped', dodgerblue_approx),
        ('pure', slategrey_approx),
    )
}


def box_plot(data, plot_subset):
    """Draw, save and show the per-bin energy box plot of one plot group.

    Args:
        data: DataFrame of energy scores with the ``plot_name``, ``bin``,
            ``energy`` and ``dataset`` columns. When ``None``, the
            module-level ``df`` is used instead.
        plot_subset: ``plot_name`` value selecting the rows to draw; it also
            names the output file.

    Side effects:
        Writes ``{plots_path}/{plot_subset}.{plot_format}`` and calls
        ``plt.show()``.
    """
    # Use the frame the caller passed in; it used to be discarded in favour
    # of the global `df`. The fallback keeps calls that pass None working.
    source = df if data is None else data
    data = source[source['plot_name'] == plot_subset]
    # p-values of this plot, sorted by bin so that the x-axis order and the
    # annotation order agree.
    data_pval = pval_df[pval_df['plot'] == plot_subset].sort_values(by="bin")
    bins = list(data_pval["bin"])
    pvals = list(data_pval["p-val"])

    plt.figure(figsize=(10, 6))
    ax = sns.boxplot(
        data=data,
        x="bin",
        y="energy",
        hue="dataset",
        order=bins,
        hue_order=dataset_order,
        showfliers=False,
        palette=custom_palette,
    )

    # One annotation per bin at y=7, inside the headroom left by the upper
    # y-limit of 9 set below.
    for bin_index, pval in enumerate(pvals):
        ax.text(
            bin_index,
            7,
            f"p-value = {pval:.2e}",
            horizontalalignment='center',
            verticalalignment='center',
        )

    # Hand-built legend: fixed labels, patches coloured with the blended colours.
    custom_legend = [
        Patch(facecolor=dodgerblue_approx, label='ONT+Illumina', edgecolor='k', linewidth=0.6),
        Patch(facecolor=slategrey_approx, label='Illumina only', edgecolor='k', linewidth=0.6),
    ]
    ax.legend(handles=custom_legend, loc='lower right', fontsize=12)

    ax.set_ylim(top=9)
    ax.set_axisbelow(True)  # keep the grid behind the boxes
    plt.xlabel("Distance from 3' end", fontsize=23)
    plt.ylabel("Minimum free energy (kcal/mol)", fontsize=19)
    plt.xticks(fontsize=20, rotation=30)
    plt.yticks(fontsize=18)
    # A single title call: an earlier plt.title(plot_subset) was always
    # overridden by this one, so the displayed title is unchanged.
    plt.title("Minimum free energy between short and long-read", fontsize=23)
    plt.grid(True)
    plt.tight_layout()
    output_plot = f"{plots_path}/{plot_subset}.{plot_format}"
    plt.savefig(output_plot, dpi=300)
    plt.show()
