import os
from collections import defaultdict

import pandas as pd
import polars as pl
import pyarrow as pa
from gtfparse import read_gtf
from tqdm import tqdm


class GTFProcessor:
    """
    Reorder a GTF annotation so every transcript is followed by its exons.

    The annotation is loaded once in the constructor; `group_by_transcript`
    builds the reordered table and `convert_to_gtf` writes a table back to
    disk in GTF text form.
    """

    # The eight fixed GTF fields, in file order. Every other column of the
    # table is written into the trailing attributes field.
    _STANDARD_COLUMNS = (
        'seqname', 'source', 'feature', 'start', 'end', 'score', 'strand', 'frame',
    )

    def __init__(self, gtf_file):
        """
        Parse `gtf_file` with gtfparse and keep the result as a pandas
        DataFrame in `self.df_original`.
        """
        self.df_original = read_gtf(gtf_file).to_pandas()

    @staticmethod
    def order_exons_by_start(exon_rows):
        """
        Order the exons based on the start position.

        `exon_rows` is a list of row-like objects supporting `row['start']`.
        The list is sorted in place (stable, ascending by `int(start)`) and
        the same list object is returned.
        """
        # Start positions may arrive as strings, so compare them as integers.
        exon_rows.sort(key=lambda x: int(x['start']))
        return exon_rows

    @staticmethod
    def _ordered_row_positions(features, transcript_ids, starts, progress=None):
        """
        Compute the output row order as a list of 0-based row positions.

        `features`, `transcript_ids` and `starts` are parallel sequences, one
        entry per input row. `progress`, when given, is a callable that wraps
        an iterable (for example a progress bar) and yields its items.

        Gene and transcript rows keep their relative file order. Each
        transcript row is immediately followed by the exon rows carrying its
        transcript_id, sorted by start. Rows of any other feature type, and
        exons whose transcript_id has no transcript row, are left out.
        """
        if progress is None:
            def progress(iterable):
                return iterable

        exon_groups = defaultdict(list)  # exons found right under their transcript
        misplaced = defaultdict(list)    # exons found away from their transcript
        current_id = ""
        current_exons = []

        for position, (feature, transcript_id) in enumerate(
                progress(zip(features, transcript_ids))):
            if feature == 'transcript':
                if current_id:
                    exon_groups[current_id].extend(current_exons)
                current_id = transcript_id
                current_exons = []
            if feature == 'exon':
                if transcript_id == current_id:
                    current_exons.append(position)
                else:
                    misplaced[transcript_id].append(position)

        # Register the final transcript even when no exon directly follows
        # it; otherwise its misplaced exons would fail the membership check
        # below and be dropped.
        if current_id or current_exons:
            exon_groups[current_id].extend(current_exons)

        # Attach misplaced exons to transcripts that have a transcript row.
        for transcript_id, positions in misplaced.items():
            if transcript_id in exon_groups:
                exon_groups[transcript_id].extend(positions)

        def start_of(position):
            return int(starts[position])

        order = []
        for position, (feature, transcript_id) in enumerate(
                progress(zip(features, transcript_ids))):
            if feature == 'gene':
                order.append(position)
            elif feature == 'transcript':
                order.append(position)
                # sorted() is stable, so exons with equal starts keep the
                # order in which they were collected.
                order.extend(sorted(exon_groups.get(transcript_id, ()), key=start_of))
        return order

    def group_by_transcript(self):
        """
        Return a reordered copy of `self.df_original`.

        The result contains the gene and transcript rows in their original
        relative order, each transcript row followed by its exon rows sorted
        by start. Rows of other feature types are not included. The returned
        DataFrame has a fresh 0..n-1 index and the same columns as the input.
        """
        df = self.df_original
        order = self._ordered_row_positions(
            df['feature'].tolist(),
            df['transcript_id'].tolist(),
            df['start'].tolist(),
            progress=tqdm,
        )
        # Positional selection keeps column dtypes and does not depend on
        # the index labels of the input table.
        return df.iloc[order].reset_index(drop=True)

    @staticmethod
    def _sorted_output_name(name):
        """
        Derive the output path for `name` by inserting `_sorted` before the
        last `.gtf` of the file name (directories are left untouched).

        When the file name contains no `.gtf`, `_sorted.gtf` is appended so
        the result can never equal the input path.
        """
        base = os.path.basename(name)
        prefix = name[:len(name) - len(base)]
        stem, marker, tail = base.rpartition('.gtf')
        if marker:
            return f'{prefix}{stem}_sorted.gtf{tail}'
        return f'{name}_sorted.gtf'

    def convert_to_gtf(self, df_reordered, name):
        """
        Write `df_reordered` as GTF text and return the path written.

        `name` is the path of the source annotation; the output path is
        derived from it (`x.gtf` becomes `x_sorted.gtf`). The eight standard
        columns become the tab-separated fixed fields, with missing values
        written as `.`. Every other column whose value is neither null nor
        the empty string is written as `key "value"` in the attributes field.

        Raises KeyError if a non-empty table lacks one of the standard columns.
        """
        new_name = self._sorted_output_name(name)
        gtf_lines = []

        if len(df_reordered):
            column_position = {col: i for i, col in enumerate(df_reordered.columns)}
            standard_positions = [column_position[col] for col in self._STANDARD_COLUMNS]
            # Attribute columns are the same for every row; resolve them once.
            attribute_columns = [
                (col, i) for i, col in enumerate(df_reordered.columns)
                if col not in self._STANDARD_COLUMNS
            ]

            for values in df_reordered.itertuples(index=False, name=None):
                fixed_fields = [
                    '.' if pd.isnull(values[i]) else str(values[i])
                    for i in standard_positions
                ]
                gtf_line = '\t'.join(fixed_fields)
                attributes = '; '.join(
                    f'{col} "{values[i]}"' for col, i in attribute_columns
                    if pd.notnull(values[i]) and values[i] != ""
                )
                if attributes:  # Only add the attributes field if it's not empty
                    gtf_line += '\t' + attributes + ';'
                gtf_lines.append(gtf_line)

        gtf_data = '\n'.join(gtf_lines) + '\n'

        # Write the file
        with open(new_name, 'w', encoding='utf-8') as gtf_file:
            gtf_file.write(gtf_data)

        return new_name


