# Convert CDS records for transcripts that have stop codons so that the CDS no longer includes the stop codon.
# To run the script: python3 ConvertCDS.py <input.gtf> <output.gtf>

import sys


#read gtf
def read_gtf(gtf_file=None):
    """Parse a GTF file, separating CDS/stop_codon lines from all other lines.

    Each data line is split on tabs. The last field keeps its trailing
    newline so the line can later be re-joined with "\\t" and written back
    verbatim.

    Args:
        gtf_file: path of the GTF file to read. Defaults to sys.argv[1], so
            the original command-line usage keeps working.

    Returns:
        (other_list, CDS_stop_codon_dict):
        other_list -- every line that is not a CDS or stop_codon feature, in
            file order. '#' comment/header lines are kept as one-element
            lists so they re-join to the original text.
        CDS_stop_codon_dict -- transcript_id -> list of that transcript's
            CDS and stop_codon lines, in file order.

    Raises:
        ValueError: if a CDS or stop_codon line has no transcript_id
            attribute (it could not be grouped correctly).
    """
    if gtf_file is None:
        gtf_file = sys.argv[1]
    other_list = []
    CDS_stop_codon_dict = {}
    with open(gtf_file, 'r') as gtf:
        for line in gtf:
            if not line.strip():
                # Blank lines carry no record.
                continue
            if line.startswith("#"):
                # Header/comment lines have no tab-separated columns; pass
                # them through unchanged instead of failing on column 3.
                other_list.append([line])
                continue
            new_line = line.split("\t")
            feature = new_line[2]
            if feature not in ("CDS", "stop_codon"):
                other_list.append(new_line)
                continue
            transcript_id = _transcript_id(new_line[8])
            if transcript_id is None:
                raise ValueError(
                    "%s line without transcript_id: %r" % (feature, line))
            CDS_stop_codon_dict.setdefault(transcript_id, []).append(new_line)
    return other_list, CDS_stop_codon_dict


def _transcript_id(attributes):
    """Return the transcript_id value from a GTF attribute column, or None."""
    # Split on ';' (not '; ') so a trailing 'transcript_id "X";\n' attribute
    # is parsed as X rather than 'X";\n'.
    for field in attributes.split(";"):
        key, _, value = field.strip().partition(" ")
        if key == "transcript_id":
            return value.strip().strip('"')
    return None


#convert CDS positions
def convert_CDS(cds_dict=None):
    """Trim stop codons out of CDS records and return all CDS/stop_codon lines.

    For each transcript that has both CDS and stop_codon lines, every CDS
    segment whose 3' boundary coincides with a stop_codon segment's 3'
    boundary is shortened by that stop_codon segment's length. On the "+"
    strand the 3' boundary is the end column; on the "-" strand it is the
    start column. For an ordinary 3-nt stop codon this removes 3 bases. A
    stop codon split over two exons (two stop_codon lines) is handled piece
    by piece. A CDS segment that consisted only of stop codon is dropped.
    Stop_codon lines are always kept, even if no CDS boundary matches them.

    Args:
        cds_dict: optional mapping transcript_id -> list of tab-split GTF
            lines (CDS and stop_codon). Defaults to the dict returned by
            read_gtf(), which reads the file named in sys.argv[1].

    Returns:
        List of tab-split GTF lines, grouped per transcript. Within a
        transcript, a matched stop_codon line precedes its trimmed CDS line;
        unmatched stop_codon lines come last. Trimmed CDS lines are new lists;
        other lines are the original objects.
    """
    if cds_dict is None:
        _, cds_dict = read_gtf()
    final_cds_list = []
    for isoform_lines in cds_dict.values():
        final_cds_list.extend(_trim_stop_codons(isoform_lines))
    return final_cds_list


def _stop_codon_matches(cds_line, stop_line):
    """Return True if stop_line sits at the 3' boundary of cds_line."""
    strand = stop_line[6]
    if strand == "+":
        return int(cds_line[4]) == int(stop_line[4])
    if strand == "-":
        return int(cds_line[3]) == int(stop_line[3])
    # Unknown strand: cannot tell which end is 3', so never trim.
    return False


def _trim_stop_codons(isoform_lines):
    """Return one transcript's lines with stop codons trimmed out of its CDS."""
    stop_lines = [line for line in isoform_lines if line[2] == "stop_codon"]
    if not stop_lines or len(stop_lines) == len(isoform_lines):
        # Only one feature type present: nothing to trim.
        return list(isoform_lines)
    result = []
    emitted = set()  # id() of stop_codon lines already written
    for line in isoform_lines:
        if line[2] == "stop_codon":
            continue
        stop = next(
            (s for s in stop_lines if _stop_codon_matches(line, s)), None)
        if stop is None:
            result.append(line)
            continue
        if id(stop) not in emitted:
            result.append(stop)
            emitted.add(id(stop))
        stop_len = int(stop[4]) - int(stop[3]) + 1
        start, end = int(line[3]), int(line[4])
        if stop[6] == "+":
            end -= stop_len
        else:
            start += stop_len
        if start <= end:
            trimmed = list(line)
            trimmed[3], trimmed[4] = str(start), str(end)
            result.append(trimmed)
    # Keep stop codons that matched no CDS boundary instead of dropping them.
    result.extend(s for s in stop_lines if id(s) not in emitted)
    return result


#write all lines to output file
#will sort later
def write():
    """Write non-CDS lines, then converted CDS/stop_codon lines, to sys.argv[2].

    The output file is truncated first ('w'), so re-running the script
    replaces the previous output instead of appending a second copy of every
    record. Each record is re-joined with tabs. A newline is added only when
    the record lacks one (e.g. the input's final line), so records are never
    fused together.
    """
    cds_list = convert_CDS()
    other_lines, _ = read_gtf()
    output = sys.argv[2]

    def as_text(fields):
        # The last column normally still carries the input line's "\n".
        text = "\t".join(fields)
        return text if text.endswith("\n") else text + "\n"

    with open(output, 'w') as out:
        out.writelines(as_text(fields) for fields in other_lines)
        out.writelines(as_text(fields) for fields in cds_list)


if __name__ == "__main__":
    # Expect exactly two arguments: the input GTF and the output path.
    if len(sys.argv) != 3:
        sys.exit("usage: python3 ConvertCDS.py <input.gtf> <output.gtf>")
    write()
