#!/usr/bin/python
# -*- coding: utf-8 -*-
import argparse
import logging
import os, re
import sys
from functools import reduce

import numpy as np
import pandas as pd

from Bio import AlignIO
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord


#===================================================================================================
# inputs
#===================================================================================================

### Option defining
# General options first; the mandatory inputs are collected in their own
# argument group so that --help lists them under a dedicated heading.
parser = argparse.ArgumentParser(prog="trimming_fasta.py", description='')
parser.add_argument('--version', action='version', version='%(prog)s 1.0')
parser.add_argument('--debug', action="store_true", default=False,
                    help="debug mode")

requiredOptions = parser.add_argument_group('REQUIRED OPTIONS')
requiredOptions.add_argument('-t', "--attribut_tsv", required=True,
                             type=argparse.FileType('r'),
                             help='atribut table filename')
requiredOptions.add_argument("--outputR_dir", required=True, type=str,
                             help='outputR_dir')
requiredOptions.add_argument("--pep_dir", required=True, type=str,
                             help='pep_dir')
requiredOptions.add_argument('-o', '--output_dir', required=True, type=str,
                             help="Output dir name")
requiredOptions.add_argument('-l', '--letter', required=True, type=str,
                             help="witch letter is running")

### Option parsing
args = parser.parse_args()

# Expose each option under the module-level name used by the rest of the script.
(attribut_tsv, outputR_dir, pep_dir, output_dir, letter) = (
    args.attribut_tsv,
    args.outputR_dir,
    args.pep_dir,
    args.output_dir,
    args.letter,
)



#===================================================================================================
# Set up output directory and logger
#===================================================================================================
### Set up the logger
# The logger itself lets every record through; the console handler decides
# what is displayed: DEBUG and above with --debug, INFO and above otherwise.
logger = logging.getLogger("trimming_fasta")
logger.setLevel(logging.DEBUG)

ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG if args.debug else logging.INFO)
formatter_ch = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter_ch)
logger.addHandler(ch)

# Record the exact command line for debugging.
logger.debug(sys.argv)

### Set up the output directory
# FASTA files of the MGIs that pass the "at least 3 species" filter are written
# to "<output_dir>_min3/". Only trailing slashes are stripped, so "out/" gives
# "out_min3/" while a nested path such as "res/run1" gives "res/run1_min3/"
# (removing every "/" would flatten it into "resrun1_min3/").
output_dir_min3 = output_dir.rstrip("/") + "_min3/"
if not os.path.isdir(output_dir_min3):
    os.makedirs(output_dir_min3)
    logger.debug("mkdir %s", output_dir_min3)


#===================================================================================================
# Read the attribute table
#===================================================================================================

logger.info("Read input attributes table")

# The table is read without a header row; its three columns give, per line,
# the species name, its peptide FASTA file (relative to --pep_dir) and its
# R output table (relative to --outputR_dir).
attribute_columns = ["sp", "pep", "tableR"]
try:
    df_attribut_tsv = pd.read_csv(attribut_tsv, sep="\t", names=attribute_columns)
except Exception as exc:
    logger.error(str(exc))
    sys.exit(1)

logger.info(df_attribut_tsv)

# Debug runs only process the first five rows of the table.
if args.debug:
    df_attribut_tsv = df_attribut_tsv.iloc[:5]

#===================================================================================================
# Build map dict (MGI -> Seq per sp)
#===================================================================================================

logger.info("Build map dict (MGI -> Seq per sp)")

# One two-column frame per species: the R table's ContigID column is renamed to
# the species name, so the outer merge below produces one column per species.
df_table_R_sp_l = []
for sp in df_attribut_tsv["sp"]:
    logger.info("Start %s", sp)
    table_R_sp = df_attribut_tsv["tableR"][df_attribut_tsv["sp"] == sp].values[0]
    logger.info("table_R_sp %s", table_R_sp)
    try:
        df_table_R_sp = pd.read_csv(outputR_dir + "/" + table_R_sp, sep="\t")
    except Exception as exc:
        logger.error(str(exc))
        sys.exit(1)
    df_table_R_sp.rename(columns={"ContigID": sp}, inplace=True)
    df_table_R_sp_l.append(df_table_R_sp[["MGI_ClusterID", sp]])

logger.info(df_table_R_sp_l)

# Stop before merging when there is nothing to merge. This also covers an empty
# attribute table: reduce() raises TypeError on an empty list, which used to
# crash here instead of reaching this exit code.
no_empty_df = [df for df in df_table_R_sp_l if not df.empty]
if not no_empty_df:
    logger.warning("All output R are empty")
    sys.exit(3)

# Wide table: one row per MGI_ClusterID, one contig column per species
# (NaN where a species has no contig for that MGI).
df_total_transcrit = reduce(
    lambda x, y: pd.merge(x, y, on='MGI_ClusterID', how='outer'), df_table_R_sp_l)

logger.info(df_total_transcrit.head())

#===================================================================================================
# Parse fasta per Sp
#===================================================================================================
logger.info("Parse fasta per Sp")

# species -> {FASTA record id -> SeqRecord}
dict_seq_sp_per_sp = {}
for sp in df_attribut_tsv["sp"]:
    logger.info("Start %s", sp)
    pep_f = df_attribut_tsv["pep"][df_attribut_tsv["sp"] == sp].values[0]
    try:
        # Failures here (unreadable file, duplicate record ids rejected by
        # SeqIO.to_dict) are reported and stop the run, as the R-table reads do.
        dict_seq_sp = SeqIO.to_dict(SeqIO.parse(pep_dir + "/" + pep_f, "fasta"))
    except Exception as exc:
        logger.error(str(exc))
        sys.exit(1)
    dict_seq_sp_per_sp[sp] = dict_seq_sp


#===================================================================================================
# create table
#===================================================================================================
logger.info("Parse fasta per Sp for table ")

# Columns of each R table that get a ".<species>" suffix before the wide merge.
SUFFIXED_COLUMNS = ("ClusterID", "ContigID", "qlen", "lenghtAA")

list_table4conca = []
for sp in df_attribut_tsv["sp"]:
    logger.info("Start %s", sp)

    table_R_sp = df_attribut_tsv["tableR"][df_attribut_tsv["sp"] == sp].values[0]
    try:
        df_table_R_sp = pd.read_csv(outputR_dir + "/" + table_R_sp, sep="\t")
    except Exception as exc:
        logger.error(str(exc))
        sys.exit(1)

    # Peptide length of every record in this species' FASTA. Keys and values of
    # the same unmodified dict iterate in matching order.
    seqs_sp = dict_seq_sp_per_sp[sp]
    lenght_AA = pd.DataFrame(data={
        'ContigID': list(seqs_sp),
        'lenghtAA': [len(record.seq) for record in seqs_sp.values()],
    })
    table_conca = pd.merge(df_table_R_sp, lenght_AA, on="ContigID", how='left')
    table_rename = table_conca.rename(
        index=str,
        columns={col: '{0}.{1}'.format(col, sp) for col in SUFFIXED_COLUMNS})

    list_table4conca.append(table_rename)

# One row per MGI_ClusterID, with the suffixed columns of every species.
table_conca = reduce(
    lambda x, y: pd.merge(x, y, how='outer', on='MGI_ClusterID'), list_table4conca)
table_conca.to_csv("%s_table_longeur_transcrit_AA.tsv" % (letter), sep="\t")


#===================================================================================================
# Check contig presence in pep
#===================================================================================================

logger.info("Check contig presence in pep")

def check_contig_in_pep(x, seq_dict=None):
    """Blank out contigs of one MGI row that are missing from their species' FASTA.

    Intended for row-wise use: ``df.apply(check_contig_in_pep, axis=1)``.

    Parameters
    ----------
    x : pandas.Series
        One row of the wide MGI table: an ``MGI_ClusterID`` entry plus one entry
        per species holding a contig ID, or NaN when that species has none.
    seq_dict : dict, optional
        ``{species: {contig_id: record}}`` lookup. Defaults to the module-level
        ``dict_seq_sp_per_sp`` built from the peptide FASTA files.

    Returns
    -------
    pandas.Series
        A copy of ``x`` in which every contig ID absent from the lookup of its
        species is replaced by NaN. ``x`` itself is not modified.

    Raises
    ------
    KeyError
        If a species column has no entry in ``seq_dict``.
    """
    if seq_dict is None:
        seq_dict = dict_seq_sp_per_sp
    # Do not mutate the object handed in by DataFrame.apply; return a new row.
    x = x.copy()
    sp_l = [col for col in x.index if col != 'MGI_ClusterID']
    for sp, contig in x[sp_l].to_dict().items():
        # ``in`` replaces dict.has_key(), which does not exist in Python 3.
        if not pd.isnull(contig) and contig not in seq_dict[sp]:
            x[sp] = np.nan
    return x

# Row-wise: blank out contigs that have no record in their species' FASTA.
df_total_transcrit_check_pep = df_total_transcrit.apply(
    check_contig_in_pep,
    axis="columns",
)

#===================================================================================================
# Check contig length to remove too short
#===================================================================================================

logger.info("Check contig pep length and remove too short  ")

def check_contig_length(x, seq_dict=None, min_ratio=0.7):
    """Blank out contigs that are much shorter than the median of their MGI row.

    Intended for row-wise use: ``df.apply(check_contig_length, axis=1)``.

    Parameters
    ----------
    x : pandas.Series
        One row of the wide MGI table: an ``MGI_ClusterID`` entry plus one entry
        per species holding a contig ID, or NaN when that species has none.
        Every non-NaN contig must exist in ``seq_dict`` (run
        ``check_contig_in_pep`` first).
    seq_dict : dict, optional
        ``{species: {contig_id: record}}`` lookup whose records expose ``.seq``.
        Defaults to the module-level ``dict_seq_sp_per_sp``.
    min_ratio : float, optional
        A contig is removed when its length is below
        ``min_ratio * median length`` of the row's contigs (default 0.7).

    Returns
    -------
    pandas.Series
        A copy of ``x`` with too-short contigs replaced by NaN. Rows with no
        contig at all are returned unchanged.
    """
    if seq_dict is None:
        seq_dict = dict_seq_sp_per_sp
    mgi = x["MGI_ClusterID"]
    # Do not mutate the object handed in by DataFrame.apply; return a new row.
    x = x.copy()
    sp_l = [col for col in x.index if col != 'MGI_ClusterID']

    length_dict = {}
    for sp, contig in x[sp_l].to_dict().items():
        if not pd.isnull(contig):
            length_dict[sp] = len(seq_dict[sp][contig].seq)

    # Nothing to compare (np.median of an empty list only warns and returns NaN).
    if not length_dict:
        return x

    len_med = np.median(list(length_dict.values()))
    # Same named logger the script configures at start-up.
    log = logging.getLogger("trimming_fasta")
    for sp, length in length_dict.items():
        if length < len_med * min_ratio:
            x[sp] = np.nan
            log.info("remove seq of  %s in %s", sp, mgi)
    return x


# Row-wise: blank out contigs shorter than the row's length threshold.
df_total_transcrit_check_len = df_total_transcrit_check_pep.apply(
    check_contig_length,
    axis="columns",
)

#===================================================================================================
# check enough seq per MGI
#===================================================================================================

logger.info("Check enough seq per MGI")
def count_non_na(x):
    """Count the species that still contribute a contig to one MGI row.

    Parameters
    ----------
    x : pandas.Series
        One row of the wide MGI table (``MGI_ClusterID`` plus one entry per
        species).

    Returns
    -------
    pandas.Series
        ``MGI_ClusterID`` and ``Count``, the number of species entries that are
        not floats. A missing contig is NaN, which is a float.
    """
    species = [label for label in x.index if label not in ['MGI_ClusterID']]
    contigs = x[species].to_dict()
    present = sum(1 for contig in contigs.values() if not isinstance(contig, float))
    return(pd.Series({"MGI_ClusterID": x["MGI_ClusterID"], "Count": present}))

# Per-MGI number of species that still have a contig after both checks.
count_mgi_df = df_total_transcrit_check_len.apply(count_non_na, axis="columns")
logger.info(df_total_transcrit_check_len)
logger.info(count_mgi_df)

nb_sp = len(df_attribut_tsv.index)
logger.info("nb_sp: %s", nb_sp)

# Keep an MGI when more than two species (i.e. at least three) contribute.
is_min3 = count_mgi_df["Count"] > 2.0
mgi_min3 = count_mgi_df["MGI_ClusterID"][is_min3]

logger.info(count_mgi_df)
logger.info(df_total_transcrit_check_len.shape)

keep_rows = df_total_transcrit_check_len["MGI_ClusterID"].isin(mgi_min3)
df_total_transcrit_check_pep_check_min3_sp = df_total_transcrit_check_len[keep_rows]

logger.info("min3_sp: %s", df_total_transcrit_check_pep_check_min3_sp.shape)


#===================================================================================================
# Create output files
#===================================================================================================

logger.info("Create output files")

def write_mgi_fasta(x, output_dir):
    """Write the retained sequences of one MGI row to ``<output_dir>/<MGI>.fasta``.

    Parameters
    ----------
    x : pandas.Series
        One row of the filtered wide MGI table: ``MGI_ClusterID`` plus one entry
        per species holding a contig ID, or NaN when that species has none.
    output_dir : str
        Directory that receives the FASTA file.

    Each record ID is the contig ID prefixed with the species initials: the
    first two letters, upper-cased, of every "_"-separated word of the species
    name. The species name itself becomes the record description.

    Exits the program with status 42 when a contig is missing from its
    species' sequence dictionary.
    """
    mgi = x["MGI_ClusterID"]
    sp_l = [col for col in x.index if col not in ['MGI_ClusterID']]
    dict_sp2contig = x[sp_l].to_dict()
    logger.debug(x.to_dict())

    mgi_seq_l = []
    for sp, contig in dict_sp2contig.items():
        if pd.isnull(contig):  # no contig kept for this species
            continue
        # Only a failed lookup means "contig not present". Any other error
        # propagates with its own traceback instead of being mislabeled.
        try:
            seq = dict_seq_sp_per_sp[sp][contig].seq
        except KeyError:
            logger.error("%s is not present in %s", contig, sp)
            sys.exit(42)
        # Prefix the id with the species initials so it stays unique for the
        # rest of the pipeline.
        # NOTE(review): two species with the same initials would collide.
        initial = "".join([mot[:2].upper() for mot in sp.split('_')])
        record = SeqRecord(seq, id=initial + "_" + contig, description=sp)
        mgi_seq_l.append(record)

    SeqIO.write(mgi_seq_l, "%s/%s.fasta" % (output_dir, mgi), "fasta")

# Write one FASTA per retained MGI with an explicit row loop. DataFrame.apply is
# meant for transformations; on an empty frame, recent pandas versions call the
# function once with a placeholder all-NaN row to infer the result shape, which
# here would write a spurious "nan.fasta".
for _, row in df_total_transcrit_check_pep_check_min3_sp.iterrows():
    write_mgi_fasta(row, output_dir_min3)


#===================================================================================================
# Write summary table
#===================================================================================================

# Species count per MGI joined with the retained contig IDs, for MGIs with at
# least three species.
df_finale_stat_min3_sp = pd.merge(
    count_mgi_df[count_mgi_df["MGI_ClusterID"].isin(mgi_min3)],
    df_total_transcrit_check_pep_check_min3_sp,
    on='MGI_ClusterID',
    how='outer',
)
logger.info(df_total_transcrit_check_pep_check_min3_sp.shape)

out_table_min3 = "%s_out_table_min3_sp.tsv" % letter
df_finale_stat_min3_sp.to_csv(out_table_min3, sep="\t", index=False)
# Log the file name actually written (it carries the letter prefix).
logger.info("Write summary output (min 3 )  in %s", out_table_min3)


