#!/usr/bin/env python3
import pdb
import sys
#import subprocess
import os
import glob
import argparse
import re
from collections import OrderedDict

from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq


def filter_filln_og(folder_path: str, species_list: str) -> None:
    """Pad orthogroup alignments with gap-only records for missing species.

    Every regular file in ``folder_path`` whose name matches ``*fa*`` and
    starts with ``OG<digits>`` is parsed as FASTA.  The species prefix of a
    record is the part of its ID before the first ``|``.

    * If every listed species is present exactly once, the file is left
      untouched and only logged.
    * If a species prefix occurs more than once, the file is moved to
      ``<folder_path>/outgroup_not_SCO``.
    * Otherwise a record made of ``-`` (same length as the first record of
      the file) is added for each missing species, the records are written
      sorted by ID to ``<name>_filln.fa``, the original file is moved to
      ``<folder_path>/fillN_origin_fastas`` and the sequences are appended
      to the per-species concatenation.

    Side effects: writes ``filter_filln.log`` (tab separated: path, number
    of records, number of distinct prefixes, outcome) and ``concat.fa`` in
    ``folder_path``.

    Args:
        folder_path: Directory holding the orthogroup FASTA files.
        species_list: Text file with one species prefix per line.
    """
    og_pattern = re.compile(r"OG\d+")
    not_sco_dir = os.path.join(folder_path, "outgroup_not_SCO")
    origin_dir = os.path.join(folder_path, "fillN_origin_fastas")
    # species prefix -> list of sequence chunks, joined once at the end
    concat_parts = {}

    # Blank lines are dropped so they are not treated as a species named "".
    with open(species_list, "r") as f:
        prefix_ls = [line.strip() for line in f if line.strip()]

    with open(os.path.join(folder_path, "filter_filln.log"), "w") as filter_log:
        # glob.glob returns a full list up front, so files created or moved
        # below do not affect this iteration; sorting makes the column order
        # of concat.fa reproducible between runs.
        for fa in sorted(glob.glob(os.path.join(folder_path, "*fa*"))):
            base = os.path.basename(fa)
            og_match = og_pattern.match(base)
            # "*fa*" also matches concat.fa and the fillN_origin_fastas
            # directory left behind by a previous run; skip anything that is
            # not an orthogroup file.
            if og_match is None or not os.path.isfile(fa):
                continue
            og_id = og_match[0]

            fa_records = SeqIO.to_dict(SeqIO.parse(fa, "fasta"))
            if not fa_records:
                print(f"{fa} contains no sequence, skipped", file=sys.stderr)
                filter_log.write(f"{fa}\t0\t0\tskipped_empty\n")
                continue

            fa_id = list(fa_records)
            # Padding length is taken from the first record of the file.
            alignment_len = len(fa_records[fa_id[0]].seq)
            sp_in_fa = [x.split("|")[0] for x in fa_id]
            sp_seen = set(sp_in_fa)
            n_seq, n_sp = len(sp_in_fa), len(sp_seen)

            # Perfect case: nothing to fill in.
            # NOTE(review): these alignments are not added to concat.fa --
            # confirm that this is intended.
            if n_sp == n_seq == len(prefix_ls):
                filter_log.write(f"{fa}\t{n_seq}\t{n_sp}\t.\n")
                continue

            # A duplicated prefix can hide behind missing species (e.g. two
            # species missing and one duplicated), so check it explicitly.
            if n_sp != n_seq:
                print(f"{fa} has less then {len(prefix_ls)} but some ID is/are duplicated. Move to outgroup_not_SCO", file=sys.stderr)
                filter_log.write(f"{fa}\t{n_seq}\t{n_sp}\toutgroup_not_SCO\n")
                os.makedirs(not_sco_dir, exist_ok=True)
                os.rename(fa, os.path.join(not_sco_dir, base))
                continue

            # Build a gap-only record for every missing species.
            for sp in prefix_ls:
                if sp in sp_seen:
                    continue
                rec_id = f"{sp}|{og_id}"
                fa_records[rec_id] = SeqRecord(Seq("-" * alignment_len), id=rec_id, description=rec_id)

            # Write the records sorted by ID to the new fasta.
            sorted_ids = sorted(fa_records)
            path_without_ext, _ = os.path.splitext(fa)
            filled_path = f"{path_without_ext}_filln.fa"
            with open(filled_path, "w") as fahandle:
                SeqIO.write([fa_records[k] for k in sorted_ids], fahandle, "fasta")

            # Collect the sequences for the concatenated alignment.
            for k in sorted_ids:
                concat_parts.setdefault(k.split("|")[0], []).append(str(fa_records[k].seq))

            # Move the original fasta to a sub-directory.
            os.makedirs(origin_dir, exist_ok=True)
            os.rename(fa, os.path.join(origin_dir, base))

            filter_log.write(f"{fa}\t{n_seq}\t{n_sp}\t{filled_path}\n")

    with open(os.path.join(folder_path, "concat.fa"), "w") as concatfa_handle:
        for sp, chunks in concat_parts.items():
            concatfa_handle.write(f">{sp}\n")
            concatfa_handle.write(f"{''.join(chunks)}\n")

if __name__ == "__main__":
    # Command-line entry point: parse the fasta directory and the species
    # prefix list, then run the fill-in step on that directory.
    parser = argparse.ArgumentParser(
        description="Fill missing species in orthogroup alignments with gap-only sequences."
    )
    # Both options are mandatory: without them the paths below would be built from None.
    parser.add_argument("-d", type=str, dest="dir", required=True, help="Fastas dir")
    parser.add_argument("-id", type=str, dest="id", required=True, help="Species prefix ID")

    # Called without any argument: show the real option syntax and exit with an error status.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    # `args` stays bound at module scope (not inside a helper function) so
    # code that reads it as a global keeps working.
    args = parser.parse_args()

    filter_filln_og(folder_path=args.dir, species_list=args.id)

