import csv
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import json

from padmet.classes import PadmetSpec
from cobra.io.sbml import read_sbml_model

from math import sqrt

# Global seaborn styling for every figure drawn below: the 'poster' context
# with a 15x15 figure and a line width of 10, then the "white" style with the
# axes grid switched off.
sns.set(context='poster', rc={'lines.linewidth': 10, 'figure.figsize': (15, 15)})
sns.set_style(style="white", rc={'axes.grid': False})

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def compute_f_measure(input_rxns, ecocyc_rxns, uninferable_ecocyc_rxns):
    """Score a predicted collection of identifiers against a reference set.

    Args:
        input_rxns: predicted identifiers (re-iterable collection of hashable
            values).
        ecocyc_rxns: reference identifiers; must provide ``intersection``
            (e.g. a ``set``).
        uninferable_ecocyc_rxns: reference identifiers that are not counted
            as false negatives in the adjusted false-negative count.

    Returns:
        tuple: (true_positive, true_negative, false_positive, false_negative,
        false_negative_uninferable_ecocyc_rxns, precision, recall, fpr, mcc,
        f_measure, f_measure_uninferable_ecocyc_rxns). ``true_negative`` is
        always 0. Any ratio whose denominator is zero (for instance precision
        with an empty prediction) is reported as 0 instead of raising
        ``ZeroDivisionError``.
    """
    # Hash-based lookups keep the membership tests below O(1) even when the
    # callers pass lists.
    predicted = set(input_rxns)
    uninferable = set(uninferable_ecocyc_rxns)

    true_positive = len(ecocyc_rxns.intersection(input_rxns))
    true_negative = 0
    false_positive = sum(1 for rxn_id in input_rxns if rxn_id not in ecocyc_rxns)
    missed = [rxn_id for rxn_id in ecocyc_rxns if rxn_id not in predicted]
    false_negative = len(missed)
    false_negative_uninferable_ecocyc_rxns = sum(1 for rxn_id in missed if rxn_id not in uninferable)

    precision = _safe_ratio(true_positive, true_positive + false_positive)
    recall = _safe_ratio(true_positive, true_positive + false_negative)
    if false_positive != 0:
        fpr = false_positive / (false_positive + true_negative)
    else:
        fpr = 0

    # A non-zero numerator implies every factor under the square root is
    # non-zero, so the division below cannot fail.
    mcc_numerator = (true_positive * true_negative) - (false_positive * false_negative)
    if mcc_numerator != 0:
        mcc = mcc_numerator / sqrt(
            (true_positive + false_positive)
            * (true_positive + false_negative)
            * (true_negative + false_positive)
            * (true_negative + false_negative)
        )
    else:
        mcc = 0

    f_measure = _safe_ratio(true_positive, true_positive + (false_positive + false_negative) / 2)
    f_measure_uninferable_ecocyc_rxns = _safe_ratio(
        true_positive,
        true_positive + (false_positive + false_negative_uninferable_ecocyc_rxns) / 2,
    )
    return (
        true_positive,
        true_negative,
        false_positive,
        false_negative,
        false_negative_uninferable_ecocyc_rxns,
        precision,
        recall,
        fpr,
        mcc,
        f_measure,
        f_measure_uninferable_ecocyc_rxns,
    )

def parse_padmet(padmet_path):
    """Summarise the reaction nodes of a PADMet network.

    Args:
        padmet_path: path handed to ``PadmetSpec``.

    Returns:
        list: [EC values read from the 'EC-NUMBER' entries of reaction nodes
        (with repeats), the same values as a set, number of reaction nodes,
        number of reactions having at least one "is_linked_to" relation
        (counted as gene-associated), number of reactions with no such
        relation and no 'SPONTANEOUS' entry, number of reactions with a
        'SPONTANEOUS' entry].
    """
    padmet_spec = PadmetSpec(padmet_path)

    ecs = []
    reaction_count = 0
    with_genes_count = 0
    without_genes_count = 0
    spontaneous_count = 0

    # One pass over the nodes gathers every statistic that is returned.
    for node in padmet_spec.dicOfNode.values():
        if node.type != "reaction":
            continue
        reaction_count += 1

        if 'EC-NUMBER' in node.misc:
            ecs.extend(node.misc['EC-NUMBER'])

        is_spontaneous = 'SPONTANEOUS' in node.misc
        if is_spontaneous:
            spontaneous_count += 1

        # A reaction absent from dicOfRelationIn has no relation at all, so it
        # is treated as having no gene association instead of raising KeyError.
        relations = padmet_spec.dicOfRelationIn.get(node.id, [])
        if any(rlt.type == "is_linked_to" for rlt in relations):
            with_genes_count += 1
        elif not is_spontaneous:
            without_genes_count += 1

    return [ecs, set(ecs), reaction_count, with_genes_count, without_genes_count, spontaneous_count]

def parse_gapseq_model(gapseq_sbml, gapseq_reactions_file, gapseq_pathways_file):
    """Summarise the reactions of a gapseq reconstruction.

    Args:
        gapseq_sbml: path of the SBML model, read with ``read_sbml_model``.
        gapseq_reactions_file: path of the tab-separated gapseq reactions
            table (lines starting with '#' are comments); its 'dbhit',
            'status' and 'rxn' columns are used.
        gapseq_pathways_file: kept for interface compatibility. Nothing
            derived from this file was ever part of the result, so it is no
            longer read.

    Returns:
        list: [EC values from the 'ec-code' annotations (with repeats), the
        same values as a set, total number of reactions in the model, number
        of reactions with genes none of which is 'spontaneous', number of
        gene-less reactions that are not exchanges, size of the union of
        spontaneous and exchange reactions].
    """
    model = read_sbml_model(gapseq_sbml)
    reactions_with_genes = []
    reactions_without_genes = []
    all_reactions = []
    ecs = []
    for reaction in model.reactions:
        # The original guard compared the literal 'reaction.id' with 'biol'
        # and was therefore always true; compare the actual identifier.
        # NOTE(review): the excluded id is spelled 'biol' (letter l) here;
        # confirm it matches the biomass reaction id used by the model.
        if reaction.id == 'biol':
            continue
        gene_ids = [gene.id for gene in reaction.genes]
        if gene_ids and 'spontaneous' not in gene_ids:
            reactions_with_genes.append(reaction.id)
        if not gene_ids:
            reactions_without_genes.append(reaction.id)
        if 'ec-code' in reaction.annotation:
            ec_annotation = reaction.annotation['ec-code']
            if isinstance(ec_annotation, list):
                ecs.extend(ec_annotation)
            else:
                ecs.append(ec_annotation)
        all_reactions.append(reaction.id)

    # Any identifier containing 'EX' is counted as an exchange; the others are
    # reduced to the part before the first '_' for matching against 'dbhit'.
    exchange_rxns = {rxn for rxn in all_reactions if 'EX' in rxn}
    enzymatic_ids = {rxn.split('_')[0] for rxn in all_reactions if 'EX' not in rxn}

    reactions_table = pd.read_csv(gapseq_reactions_file, sep='\t', comment='#')
    # Keep the rows whose space-separated 'dbhit' identifiers mention a model
    # reaction. Missing 'dbhit' values count as "no hit" instead of crashing.
    dbhits = reactions_table['dbhit'].fillna('').astype(str)
    hit_mask = dbhits.apply(lambda hits: not enzymatic_ids.isdisjoint(hits.split(' '))).astype(bool)
    reactions_table = reactions_table[hit_mask]

    spontaneous_rows = reactions_table[reactions_table['status'] == 'spontaneous']
    spontaneous_reactions = set(spontaneous_rows['rxn'].tolist()).intersection(all_reactions)

    gene_less_reactions = set(reactions_without_genes) - exchange_rxns

    return [
        ecs,
        set(ecs),
        len(model.reactions),
        len(reactions_with_genes),
        len(gene_less_reactions),
        len(spontaneous_reactions.union(exchange_rxns)),
    ]

def parse_carveme_model(sbml_path):
    """Summarise the reactions of a CarveMe SBML model.

    Args:
        sbml_path: path of the SBML file, read with ``read_sbml_model``.

    Returns:
        list: [EC values from the 'ec-code' annotations (with repeats), the
        same values as a set, total number of reactions, number of reactions
        with genes none of which is 'spontaneous', number of reactions
        without genes, number of reactions whose genes include
        'spontaneous'].
    """
    model = read_sbml_model(sbml_path)
    ec_codes = []
    n_with_genes = 0
    n_without_genes = 0
    n_spontaneous = 0

    for rxn in model.reactions:
        gene_ids = {gene.id for gene in rxn.genes}
        # The three categories are mutually exclusive.
        if not gene_ids:
            n_without_genes += 1
        elif 'spontaneous' in gene_ids:
            n_spontaneous += 1
        else:
            n_with_genes += 1

        if 'ec-code' not in rxn.annotation:
            continue
        ec_value = rxn.annotation['ec-code']
        # Values that are neither a string nor a list are ignored.
        if isinstance(ec_value, str):
            ec_codes.append(ec_value)
        elif isinstance(ec_value, list):
            ec_codes.extend(ec_value)

    return [ec_codes, set(ec_codes), len(model.reactions), n_with_genes, n_without_genes, n_spontaneous]

def parse_modelseed_model(sbml_path, tsv_reactions_file, modelseed_mapping_ecs):
    """Summarise the reactions of a ModelSEED reconstruction.

    Args:
        sbml_path: path of the SBML model, read with ``read_sbml_model``.
        tsv_reactions_file: path of the tab-separated reactions table; only
            its 'enzyme' column ('|'-separated values) is used.
        modelseed_mapping_ecs: mapping from a reaction identifier prefix (the
            part of the SBML reaction id before the first '_') to an EC
            value. See
            https://github.com/ModelSEED/ModelSEEDDatabase/blob/master/Biochemistry/Aliases/Unique_ModelSEED_Reaction_ECs.txt

    Returns:
        list: [EC values from the table followed by the mapped EC values
        (with repeats), the same values as a set, total number of reactions,
        number of reactions with genes none of which is 'spontaneous', number
        of reactions without genes, '' (no spontaneous count is computed for
        this tool)].
    """
    model = read_sbml_model(sbml_path)
    n_with_genes = 0
    n_without_genes = 0
    for reaction in model.reactions:
        gene_ids = [gene.id for gene in reaction.genes]
        if not gene_ids:
            n_without_genes += 1
        elif 'spontaneous' not in gene_ids:
            n_with_genes += 1

    df_modelseed = pd.read_csv(tsv_reactions_file, sep='\t')
    df_modelseed = df_modelseed.replace(np.nan, '')
    modelseed_ecs = [ec for enzymes in df_modelseed['enzyme'] if enzymes != '' for ec in enzymes.split('|')]

    # Add the EC mapped to each model reaction whose id prefix is known.
    for reaction in model.reactions:
        base_id = reaction.id.split('_')[0]
        if base_id in modelseed_mapping_ecs:
            modelseed_ecs.append(modelseed_mapping_ecs[base_id])

    return [modelseed_ecs, set(modelseed_ecs), len(model.reactions), n_with_genes, n_without_genes, '']

# Input networks produced by each tool for the K_12_MG1655 genome.
padmet_path = 'networks_aucome/K_12_MG1655.padmet'
gapseq_sbml_path = 'networks_gapseq/K_12_MG1655/K_12_MG1655.xml'
gapseq_rxn_tsv_path = 'networks_gapseq/K_12_MG1655/K_12_MG1655-all-Reactions.tbl'
gapseq_pathway_tsv_path = 'networks_gapseq/K_12_MG1655/K_12_MG1655-all-Pathways.tbl'
carveme_path = 'networks_carveme/K_12_MG1655.sbml'
modelseed_path = 'networks_modeseed/sbml/genomeset__K_12_MG1655.gbk_genome.fbamodel.xml'
modelseed_tsv_path = 'networks_modeseed/tsv/genomeset__K_12_MG1655.gbk_genome.fbamodel-reactions.tsv'

# Mapping 'ModelSEED ID' -> 'External ID', used below as the EC associated
# with a ModelSEED reaction identifier.
df_modelseed_mapping = pd.read_csv('mapping_modelseed_ec.tsv', sep='\t').set_index('ModelSEED ID')
modelseed_mapping_ecs = df_modelseed_mapping['External ID'].to_dict()

# Create the output folder if needed; exist_ok avoids the check-then-create
# race of testing os.path.exists first.
output_folder = 'Figure_S6_output'
os.makedirs(output_folder, exist_ok=True)

# Parse the reconstruction of each tool. Index 1 of every summary holds the
# set of distinct EC values found in that reconstruction.
aucome_met = parse_padmet(padmet_path)
# Drop every 'EC-' substring from the PADMet values (presumably so they are
# comparable with the bare EC numbers of the other sources).
aucome_ecs = {ec.replace('EC-', '') for ec in aucome_met[1]}

gapseq_met = parse_gapseq_model(
    gapseq_sbml_path,
    gapseq_rxn_tsv_path,
    gapseq_pathway_tsv_path,
)
gapseq_ecs = gapseq_met[1]

carveme_met = parse_carveme_model(carveme_path)
carveme_ecs = carveme_met[1]

modelseed_met = parse_modelseed_model(
    modelseed_path,
    modelseed_tsv_path,
    modelseed_mapping_ecs,
)
modelseed_ecs = modelseed_met[1]

output_dir = 'output_comparison'

# Reference EC values: pooled over every source in all_ecs, and also kept per
# source.
all_ecs = []

# Get ECs from Bigg: the 'ec-code' annotation of every reaction in each model.
bigg_ref_ecs = []
for bigg_filename in os.listdir('jsons_bigg'):
    with open(os.path.join('jsons_bigg', bigg_filename)) as bigg_handle:
        bigg_model = json.load(bigg_handle)

    for reaction in bigg_model['reactions']:
        annotation = reaction['annotation']
        if 'ec-code' in annotation:
            bigg_ref_ecs.extend(annotation['ec-code'])
all_ecs.extend(bigg_ref_ecs)

# Get ECs from modelseed: reaction ids are cut at the first '_' and looked up
# in the ModelSEED -> EC mapping; unmapped reactions are skipped.
modelseed_ref_ecs = []
for seed_filename in os.listdir('jsons_modelseed'):
    with open(os.path.join('jsons_modelseed', seed_filename)) as seed_handle:
        seed_model = json.load(seed_handle)

    for reaction in seed_model['modelreactions']:
        reaction_id = reaction['id'].split('_')[0]
        if reaction_id in modelseed_mapping_ecs:
            modelseed_ref_ecs.append(modelseed_mapping_ecs[reaction_id])
all_ecs.extend(modelseed_ref_ecs)

# Get ECs from KEGG: first tab-separated field of each line, without 'ec:'.
with open('kegg_ecs.txt', 'r') as input_file:
    kegg_ecs = [line.split('\t')[0].replace('ec:', '') for line in input_file.read().splitlines()]
all_ecs.extend(kegg_ecs)

# Get ECs from ecocyc.
ecocyc_met = parse_padmet('ecocyc.padmet')
ecocyc_ecs = {ec.replace('EC-', '') for ec in ecocyc_met[1]}
all_ecs.extend(ecocyc_ecs)

sets = [set(kegg_ecs), set(ecocyc_ecs), set(modelseed_ref_ecs), set(bigg_ref_ecs)]
all_ecs = set(all_ecs)


print(f'Union of ECs: {len(all_ecs)}')
# 1868 ECs

# Score the EC set of each tool against the pooled reference; no EC is
# treated as uninferable here.
aucome_data = compute_f_measure(aucome_ecs, all_ecs, [])
gapseq_data = compute_f_measure(gapseq_ecs, all_ecs, [])
carveme_data = compute_f_measure(carveme_ecs, all_ecs, [])
modelseed_data = compute_f_measure(modelseed_ecs, all_ecs, [])

output_tsv = os.path.join(output_folder, 'Figure_S6_comparison_all.tsv')
# newline='' leaves line endings to the csv module; without it every row is
# followed by an empty line on platforms that translate '\n'.
with open(output_tsv, 'w', newline='') as output_file:
    csvwriter = csv.writer(output_file, delimiter='\t')
    csvwriter.writerow(['confusion_matrix', 'values', 'tools'])

    # Indices 0, 2 and 3 of a compute_f_measure result are the true positive,
    # false positive and false negative counts.
    for tool_name, tool_data in (
        ('CarveMe', carveme_data),
        ('AuCoMe', aucome_data),
        ('gapseq', gapseq_data),
        ('ModelSEED', modelseed_data),
    ):
        csvwriter.writerow(['True positives', tool_data[0], tool_name])
        csvwriter.writerow(['False positives', tool_data[2], tool_name])
        csvwriter.writerow(['False negatives', tool_data[3], tool_name])

# Grouped bar chart of the confusion counts, one bar colour per tool.
df = pd.read_csv(output_tsv, sep='\t')
g = sns.barplot(x="confusion_matrix", y="values", hue="tools", data=df)
g.set(xlabel='', ylabel='Values', ylim=(0, 1300))

# Despite its name, output_png points to an SVG file.
output_png = os.path.join(output_folder, 'Figure_S6_comparison_all.svg')
plt.savefig(output_png)
plt.clf()
