import csv
import math
import os
import random
import sys

from Bio import SeqIO
from padmet.classes.padmetSpec import PadmetSpec

# Not read anywhere at module level in this script; kept so the module
# namespace stays unchanged.
records = []

# Create output/ and output/degraded/ in a single call.  exist_ok=True makes
# this a no-op when the directories are already there and avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs('output/degraded', exist_ok=True)

# Collect the identifiers of every node of type "gene" in the padmet network;
# these are compared against CDS locus tags in degrade_genbank.
eco_padmet = PadmetSpec('input/K_12_MG1655.padmet')
all_genes = [node.id for node in eco_padmet.dicOfNode.values() if node.type == "gene"]

# Module-level counters; they are rebound by the driver loop at the bottom of
# the script with the values returned by degrade_genbank.
removed = 0
cleaned = 0
unmodified = 0

def find_same_locus_tag(locus_tag, record):
    """Return the positions of the features of *record* carrying *locus_tag*.

    A feature matches when it has a ``locus_tag`` qualifier whose first value
    equals *locus_tag*.  Features without that qualifier are ignored.

    Args:
        locus_tag: Locus tag to look for.
        record: Object exposing a ``features`` list whose items have a
            ``qualifiers`` mapping.

    Returns:
        List of indices into ``record.features``, in ascending order.
    """
    matching_positions = []
    for position, candidate in enumerate(record.features):
        if 'locus_tag' not in candidate.qualifiers:
            continue
        if candidate.qualifiers['locus_tag'][0] == locus_tag:
            matching_positions.append(position)
    return matching_positions

def remove_feature_annotation(feature):
    """Strip every qualifier of *feature* except locus tag and translation.

    The ``feature.qualifiers`` mapping is modified in place; nothing is
    returned.

    Args:
        feature: Object exposing a mutable ``qualifiers`` mapping.
    """
    kept_qualifiers = ('locus_tag', 'translation')
    # Snapshot the keys to drop first: a mapping must not change size while
    # it is being iterated.
    discarded = [name for name in feature.qualifiers if name not in kept_qualifiers]
    for name in discarded:
        del feature.qualifiers[name]

def degrade_genbank(genbank, new_genbank, all_genes, thresholds):
    """Write a randomly degraded copy of a GenBank file.

    Every ``source`` feature gets its first ``db_xref`` qualifier set to
    ``taxon:131567``.  Each CDS feature whose locus tag is in *all_genes*
    then receives one random draw in [0, 1) deciding what happens to it:

    * removed: every feature of the record sharing its locus tag is deleted;
    * cleaned: every feature sharing its locus tag loses all qualifiers
      except ``locus_tag`` and ``translation``;
    * unmodified: nothing is changed.

    Args:
        genbank: GenBank input, passed to ``SeqIO.parse``.
        new_genbank: Output destination, passed to ``SeqIO.write``.
        all_genes: Iterable of locus tags eligible for degradation.
        thresholds: Probabilities ``[removed, cleaned, unmodified]``; they
            must sum to 1 (within floating-point tolerance).

    Returns:
        Tuple ``(unmodified, cleaned, removed)`` holding, as strings, the
        number of eligible CDS features that fell in each category.

    Raises:
        SystemExit: If *thresholds* do not sum to 1.
    """
    # Validate before any I/O.  A tolerance is required because decimal
    # fractions are not exact in binary floating point, so an exact
    # comparison with 1 can reject perfectly valid inputs.
    if not math.isclose(sum(thresholds), 1.0):
        sys.exit('Sum of thresholds must be 1.')

    removed_threshold = thresholds[0]
    cleaned_threshold = thresholds[0] + thresholds[1]

    # Hash-based membership test instead of scanning a list per feature.
    gene_ids = set(all_genes)

    records = []
    removed = 0
    cleaned = 0
    unmodified = 0

    for record in SeqIO.parse(genbank, 'genbank'):
        # Sets, not lists: when several CDS share a locus tag the same index
        # can be collected more than once, and popping a duplicated index
        # would delete an unrelated neighbouring feature.
        indices_to_remove = set()
        indices_to_clean = set()

        for feature in record.features:
            if feature.type == 'source':
                # Degrade taxonomic annotation
                db_xrefs = feature.qualifiers.get('db_xref')
                if db_xrefs:
                    db_xrefs[0] = 'taxon:131567'
                else:
                    # No db_xref to overwrite: create it rather than crash.
                    feature.qualifiers['db_xref'] = ['taxon:131567']
                continue

            if feature.type != 'CDS':
                continue

            locus_tags = feature.qualifiers.get('locus_tag')
            if not locus_tags or locus_tags[0] not in gene_ids:
                continue
            locus_tag = locus_tags[0]

            # Exactly one draw per eligible CDS.
            random_float = random.random()
            if random_float < removed_threshold:
                indices_to_remove.update(find_same_locus_tag(locus_tag, record))
                removed += 1
            elif random_float < cleaned_threshold:
                indices_to_clean.update(find_same_locus_tag(locus_tag, record))
                cleaned += 1
            else:
                unmodified += 1

        # Clean first: the indices were computed on the unmodified feature
        # list and would be stale once features have been deleted.
        for feature_index in indices_to_clean:
            remove_feature_annotation(record.features[feature_index])

        # Delete from the highest index down so the remaining indices stay
        # valid while the list shrinks.
        for feature_index in sorted(indices_to_remove, reverse=True):
            del record.features[feature_index]

        records.append(record)

    print('Genes unmodified: ' + str(unmodified))
    print('Genes cleaned: ' + str(cleaned))
    print('Genes removed: ' + str(removed))

    SeqIO.write(records, new_genbank, 'genbank')

    return str(unmodified), str(cleaned), str(removed)

# Driver: run degrade_genbank once per data row of input/threshold.tsv and
# record the thresholds used together with the resulting gene counts.
# Both files are opened with newline='' as the csv module requires; without
# it the writer emits '\r\r\n' line endings on platforms that translate '\n'.
with open('output/threshold_degraded_data.tsv', 'w', newline='') as output_file:
    csvwriter = csv.writer(output_file, delimiter='\t')
    csvwriter.writerow(['named_run', 'Conserved', 'Remove annotation', 'Remove structural', 'conserved_genes', 'cleaned_genes', 'removed_genes'])
    with open('input/threshold.tsv', 'r', newline='') as input_file:
        csvreader = csv.reader(input_file, delimiter='\t')
        # Skip the header row; the default keeps an empty file from raising
        # StopIteration.
        next(csvreader, None)
        for line in csvreader:
            # csv.reader yields an empty list for blank lines; skip them.
            if not line:
                continue
            print(line)
            # Columns after the run name are read as conserved,
            # annotation-removal and structural-removal fractions;
            # degrade_genbank expects them as [removed, cleaned, unmodified].
            run_thresholds = [float(line[3]), float(line[2]), float(line[1])]
            degraded_path = 'output/degraded/' + line[0] + '_degraded_K_12_MG1655.gbk'
            unmodified, cleaned, removed = degrade_genbank('input/K_12_MG1655.gbk', degraded_path, all_genes, run_thresholds)
            csvwriter.writerow([line[0], line[1], line[2], line[3], unmodified, cleaned, removed])
