#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Mar  7 01:05:40 2021
@author: christospapadopoulos
"""

import sys,os
from Bio.Blast import NCBIXML
from Bio.Seq import Seq
#from Bio.Alphabet import IUPAC
from Bio import SeqIO
import subprocess

def read_multiFASTA(fasta_file):
    """Read a (multi-)FASTA file into a ``{name: sequence}`` dict.

    The key is the first whitespace-separated token of the header line,
    without its leading '>'. The value is the concatenation of the record's
    sequence lines, each stripped of surrounding whitespace. Blank or
    whitespace-only lines are ignored. A repeated header restarts that
    record, so the later sequence replaces the earlier one.

    Raises:
        ValueError: if sequence data appears before the first header.
    """
    # Collect line fragments and join once at the end. Repeated
    # ``s = s + line`` on a string also referenced by the dict copies the
    # whole sequence every line, which is quadratic for chromosome-size records.
    chunks = {}
    name = None
    with open(fasta_file, 'r') as fasta:
        for line in fasta:
            if line.startswith('>'):
                name = line.split()[0][1:]
                chunks[name] = []
                continue
            seq = line.strip()
            if not seq:
                continue
            if name is None:
                raise ValueError(
                    "sequence data before the first FASTA header in {}".format(fasta_file))
            chunks[name].append(seq)
    return {key: ''.join(parts) for key, parts in chunks.items()}

def read_blast_denovo(blast_files_path):
    """Collect, per kept query and subject organism, the first BLAST hit.

    The subject-organism label of each file is derived from its basename.
    Take the text up to the first '.', keep the part after the first "-vs-",
    and cut it at the first "MEGA" if present. For example,
    ".../X-vs-SparSUPERMEGA.xml" gives "SparSUPER".

    Only queries present in the module-level ``genes_keep`` dict are kept.

    Parameters
    ----------
    blast_files_path : iterable of str
        Paths of BLAST XML result files, readable by ``NCBIXML.parse``.

    Returns
    -------
    dict
        ``{query: {organism: value}}``. The value is the string 'NaN' when
        the query has no alignment. Otherwise it is
        ``{'hsp': <first HSP of the first alignment>, 'hit': <that alignment's hit_def>}``.
        If the first alignment has no HSP, no entry is stored for that organism.
    """
    blast_denovo = {}

    for mega_blast in blast_files_path:

        sbj_organism = mega_blast.split("/")[-1].split(".")[0].split("-vs-")[1].split("MEGA")[0]

        print('\r\t' + 'BLAST reading for\t:\t' + sbj_organism, end = '')

        # `with` guarantees the XML handle is closed. A plain for-loop ends
        # cleanly at the last record, while genuine parse errors now surface
        # instead of being swallowed as "end of file".
        with open(mega_blast, "r") as handle:
            for record in NCBIXML.parse(handle):
                if record.query not in genes_keep:
                    continue

                per_query = blast_denovo.setdefault(record.query, {})

                if not record.alignments:
                    per_query[sbj_organism] = 'NaN'
                    continue

                # Keep only the first HSP of the first alignment, in the
                # order they appear in the XML.
                first_alignment = record.alignments[0]
                if first_alignment.hsps:
                    per_query[sbj_organism] = {
                        'hsp': first_alignment.hsps[0],
                        'hit': first_alignment.hit_def,
                    }

    return(blast_denovo)
        
def read_genes_to_keep(file):
    """Read the TO_KEEP table into ``{gene: {species: status}}``.

    The first line is a header and is skipped. Each remaining non-blank line
    is split on whitespace:
      * column 0 is the gene name;
      * columns 2-6 are the statuses for Spar, Smik, Skud, Sarb and Sbay;
      * column 1 is not read.
    The inner dicts keep that species order. A repeated gene name keeps only
    its last row.
    """
    # (species, column index); the order defines the inner dict's key order,
    # which later loops iterate over.
    species_columns = (('Spar', 2), ('Smik', 3), ('Skud', 4), ('Sarb', 5), ('Sbay', 6))
    dico = {}
    with open(file, 'r') as f:
        next(f, None)  # header line
        for line in f:
            fields = line.split()
            if not fields:  # tolerate blank / trailing empty lines
                continue
            dico[fields[0]] = {sp: fields[col] for sp, col in species_columns}
    return(dico)
 


def extract_coordinates(name):
    """Parse a "<chrom>_<strand>_<a>-<b>" identifier.

    Returns:
        tuple: ``(chrom, strand, start, stop)``. ``start``/``stop`` are the
        numerically smaller/larger of ``a`` and ``b``, returned as the
        original strings; callers convert them with ``int()``.

    Raises:
        ValueError: if the strand field is neither '+' nor '-', or a bound
        is not an integer.
    """
    fields = name.split('_')
    chrom = fields[0]
    strand = fields[1]
    if strand not in ('+', '-'):
        raise ValueError("unexpected strand {!r} in {!r}".format(strand, name))
    bounds = fields[2].split('-')
    # Compare as integers: min/max on the raw strings is lexicographic and
    # would rank '1000' below '999', swapping start and stop.
    start = min(bounds, key=int)
    stop  = max(bounds, key=int)
    return(chrom,strand,start,stop)

def extract_dna_seq(org,coord):
    """Return the DNA of a genomic region from the module-level ``genomes``.

    Args:
        org: species key of ``genomes`` (e.g. 'Scer').
        coord: ``(chrom, strand, start, stop)`` with 1-based inclusive
            bounds, given as str or int.

    Returns:
        str: the sequence, reverse-complemented when ``strand == '-'``.
        Bounds that extend before position 1 are clamped to the chromosome
        start; bounds past the end are truncated by slicing.
    """
    chrom, strand = coord[0], coord[1]
    start, stop = int(coord[2]), int(coord[3])
    # Clamp to 0: a negative slice index would count from the END of the
    # chromosome, returning an empty or unrelated sequence for regions
    # extended past the chromosome start.
    dna_seq = genomes[org][chrom][max(start - 1, 0):max(stop, 0)]
    if strand == '-':
        dna_seq = str(Seq(dna_seq).reverse_complement())
    return(dna_seq)

def elongate_the_hit(sbjct):
    """Widen a subject hit to span the whole query gene.

    Shifts the subject region's bounds outward by the query nucleotides
    lying outside the aligned HSP, and inward by the subject nucleotides
    lying outside the HSP. The names suggest "M" is the start-codon side and
    "STOP" the stop-codon side of the query.

    NOTE(review): besides its argument, this function reads the MODULE-LEVEL
    globals ``query_start``, ``query_end``, ``sbjct_start``, ``sbjct_end``
    and ``gene_seq``. The script's main loop assigns them immediately before
    each call, so calling this from elsewhere uses stale values.

    Parameters
    ----------
    sbjct : str or tuple
        A hit name "<chrom>_<strand>_<a>-<b>" (parsed by extract_coordinates),
        or an already-parsed 4-tuple ``(chrom, strand, start, stop)``.
        Anything of length 4 is taken as the tuple form, so a 4-character
        string would be misread.

    Returns
    -------
    tuple
        ``(chrom, strand, first_pos, last_pos)`` with positions truncated to int.

    Raises
    ------
    UnboundLocalError
        If the strand is neither '+' nor '-' (no branch sets first_pos).
    """
    if len(sbjct)==4:
        coords = sbjct
    else:
        coords = extract_coordinates(name = sbjct)
    # Subject region length in codons (true division, so possibly a float).
    hit_seq_size = (int(coords[3]) - (int(coords[2])-1))/3
    # Nucleotides to add for query residues outside the HSP; the query
    # length in codons is taken as len(gene_seq)/3.
    # NOTE(review): with query_start == sbjct_start == 1 the '+' branch gives
    # first_pos = start - 3 and the '-' branch gives last_pos = stop + 2.
    # Confirm the "+ 1" here and the "- 1" in the '-' branch are intended.
    positions_to_M    = (query_start + 1) * 3
    positions_to_STOP =((len(gene_seq)/3) - query_end)*3
    # Nucleotides of the subject region outside the HSP, to be removed.
    reduce_from_M    = sbjct_start * 3
    reduce_from_STOP = (hit_seq_size - sbjct_end)*3
    
    if coords[1] == '+':
        first_pos = (int(coords[2]) + reduce_from_M) - positions_to_M
        last_pos  = (int(coords[3]) - reduce_from_STOP) + positions_to_STOP
    elif coords[1] == '-':
        # On '-' the M-side offsets apply to the higher coordinate (coords[3])
        # and the STOP-side offsets to the lower one (coords[2]).
        first_pos = (int(coords[2]) + reduce_from_STOP) - positions_to_STOP 
        last_pos  = (int(coords[3]) - reduce_from_M) + positions_to_M -1
        
    # int() truncates the float results of the "/3" divisions above.
    new_coords = (coords[0],coords[1],int(first_pos),int(last_pos))
    return(new_coords)

def create_Newick(gene, sequences=None):
    """Build a ladder-shaped Newick topology for the species of ``gene``.

    Species are ordered Scer, Spar, Smik, Skud, Sarb, Sbay. Each successive
    species is joined as sister to the clade of all previous ones. For
    example, {Scer, Spar, Skud} gives "((Scer,Spar),Skud)". No trailing ';'
    is added.

    Args:
        gene: key of ``sequences``.
        sequences: ``{gene: {species: sequence}}``; defaults to the
            module-level ``sequences_dico``.

    Raises:
        ValueError: if fewer than two species are present (a tree needs two
            leaves), or a species is not one of the six known ones.
    """
    order = ['Scer','Spar','Smik','Skud','Sarb','Sbay']
    if sequences is None:
        sequences = sequences_dico
    species = sorted(sequences[gene], key=order.index)
    if len(species) < 2:
        raise ValueError("need at least 2 species to build a tree for {!r}, got {}".format(gene, species))

    form = '({},{})'.format(species[0], species[1])
    for name in species[2:]:
        form = '({},{})'.format(form, name)
    return(form)

def decide_the_ancestor(gene, sequences=None, statuses=None):
    """Return the 1-based rank of the first species with an 'IGORF'/'IGR' status.

    The non-Scer species present for ``gene`` are ordered Spar, Smik, Skud,
    Sarb, Sbay. The result is the 1-based position, within that list, of the
    first species whose status is 'IGORF' or 'IGR'. If none has such a
    status, the length of the list is returned — the same value as when only
    the last species matches.

    Args:
        gene: key of ``sequences`` and ``statuses``.
        sequences: ``{gene: {species: sequence}}``; defaults to the
            module-level ``sequences_dico``.
        statuses: ``{gene: {species: status}}``; defaults to the
            module-level ``genes_keep``.

    Raises:
        ValueError: if ``gene`` has no species other than Scer.
    """
    order = ['Spar','Smik','Skud','Sarb','Sbay']
    if sequences is None:
        sequences = sequences_dico
    if statuses is None:
        statuses = genes_keep
    # Exclude Scer by name rather than by dict position ([1:]).
    species = sorted((sp for sp in sequences[gene] if sp != 'Scer'), key=order.index)
    if not species:
        raise ValueError("no species besides Scer for {!r}".format(gene))
    for rank, sp in enumerate(species, start=1):
        if statuses[gene][sp] in ('IGORF', 'IGR'):
            return rank
    return len(species)

def elongate_the_borders(coord,size_nt):
    """Widen a ``(chrom, strand, start, stop)`` region by ``size_nt`` per side.

    ``start`` and ``stop`` may be str or int and come back as numbers;
    chromosome and strand pass through untouched. No clamping is applied,
    so ``start`` may become zero or negative.
    """
    chrom, strand = coord[0], coord[1]
    return (chrom, strand, int(coord[2]) - size_nt, int(coord[3]) + size_nt)
    

def extract_coordinates_for_IGR(hsp_tmp,hit,org=None):
    """Locate, in genome coordinates, an HSP found inside an intergenic region.

    The first two '_'-separated fields of ``hit`` are read as the chromosome
    and the "<a>-<b>" span of the IGR. They are re-parsed as a '+' strand
    name to get the IGR start. The HSP subject offsets are added to that
    start, and the strand is chosen by whether the forward or the
    reverse-complement translation of that DNA equals the ungapped HSP
    subject sequence.

    Args:
        hsp_tmp: BLAST HSP object (uses ``sbjct_start``, ``sbjct_end``, ``sbjct``).
        hit: IGR identifier (the HSP's hit_def).
        org: species key of ``genomes``. Defaults to the module-level loop
            variable ``sp``, which is what the original code implicitly used.

    Returns:
        ``(chrom, '+'|'-', start, end)``, or None when neither translation
        matches.

    NOTE(review): the genome start is computed as IGR start + sbjct_start,
    with no "- 1". If both are 1-based this is shifted by one nt; confirm
    against how IGR names are generated.
    """
    if org is None:
        org = sp
    hit = hit.split("_")[0]+"_+_"+hit.split("_")[1]
    IGR_coords = extract_coordinates(name=hit)
    igr_start = int(IGR_coords[2])

    start = igr_start + (hsp_tmp.sbjct_start)
    end   = igr_start + (hsp_tmp.sbjct_end)

    coords = (IGR_coords[0],'+',start,end)
    target = hsp_tmp.sbjct.replace("-","")
    # Extract once; both strand checks use the same forward-strand DNA.
    dna = Seq(extract_dna_seq(org=org,coord=coords))
    if str(dna.translate()) == target:
        return(coords)
    if str(dna.reverse_complement().translate()) == target:
        return(IGR_coords[0],'-',start,end)
    return(None)
    
path = sys.argv[sys.argv.index("-path")+1]
genes_keep = read_genes_to_keep(path+"/intermediate/TO_KEEP.tab")       

files = os.listdir(path + '/intermediate/denovo_vs_all_ORFs_more_than_12nt_xml/')
xmls = []
for i in files:
    if i.endswith("MEGA.xml"):
        xmls.append(path+"intermediate/denovo_vs_all_ORFs_more_than_12nt_xml/"+i)

all_blasts = read_blast_denovo(blast_files_path = xmls)

path_IGR = path + '/intermediate/denovo_vs_IGR_xml/'
files_IGR = os.listdir(path_IGR)
xmls_IGR = []
for i in files_IGR:
    if i.endswith("IGR.xml"):
        xmls_IGR.append(path_IGR+i)

all_blasts_IGR = read_blast_denovo(blast_files_path = xmls_IGR)
    

genomes_path = path + '/inputs/genomes/'
genomes = {}
for i in os.listdir(genomes_path):
    if i.endswith('.fai'):
        continue
    genomes[i.split('.')[0]] = read_multiFASTA(genomes_path + i)
  
 
    
# --- Sequence extraction --------------------------------------------------
# For every kept gene, store the S. cerevisiae gene and each species'
# homologous region, all extended by 100 nt per side:
#   sequences_dico[gene][species] = nucleotide sequence.
# 'Scer' is always inserted first. Species whose status is not 'CDS',
# 'IGORF' or 'IGR' are left out.
# NOTE(review): this loop assigns the module-level globals gene_seq,
# query_start, query_end, sbjct_start and sbjct_end that elongate_the_hit()
# reads, and the global `sp` that extract_coordinates_for_IGR() reads.
# Statement order therefore matters.
sequences_dico = {}   
for gene in genes_keep:
    
    # First GFF line containing both the gene name and "CDS". grep matches
    # `gene` as a regex anywhere in the line, so an ID that is a substring of
    # another can match unrelated lines.
    # NOTE(review): for genes with several CDS lines only the first segment's
    # coordinates are used; if nothing matches, `out` is empty and the next
    # line raises IndexError. The command runs through the shell with `gene`
    # and `path` interpolated unquoted.
    proc = subprocess.Popen(["grep {} {}/inputs/annotations/Scer.gff | grep CDS".format(gene,path)], stdout=subprocess.PIPE, shell=True)
    out = proc.communicate()[0]
    # Whitespace tokens 0, 6, 3, 4 of the output = seqid, strand, start, end
    # of the first matching GFF line -> "<chrom>_<strand>_<start>-<end>".
    gene_name = out.decode().split()[0] + "_" + out.decode().split()[6] + "_" + out.decode().split()[3] + "-" + out.decode().split()[4]
    
    # Extend the gene borders by 100 nt on each side.
    prolog_coords = elongate_the_borders(coord=extract_coordinates(name=gene_name),size_nt=100)
    # Nucleotide sequence of the gene itself (gene_seq is also read as a
    # global by elongate_the_hit) and of the extended region.
    gene_seq = extract_dna_seq(org='Scer',coord = extract_coordinates(name=gene_name))
    gene_seq_prolonged = extract_dna_seq(org='Scer',coord = prolog_coords)

    sequences_dico[gene]={}
    sequences_dico[gene]['Scer'] = gene_seq_prolonged
    
    for sp in genes_keep[gene]:
        # 'CDS': take the whole hit ORF, extended by 100 nt per side.
        # NOTE(review): if read_blast_denovo stored 'NaN' (no alignment) for
        # this species, ['hit'] on a str raises TypeError.
        if genes_keep[gene][sp] == 'CDS':
#            print(all_blasts[gene].keys())
            sbjct  = all_blasts[gene][sp+"SUPER"]['hit']
            # new_seq (the unextended hit) is computed but not stored.
            new_seq = extract_dna_seq(org=sp,coord=extract_coordinates(name=sbjct))
            new_seq_prolognged = extract_dna_seq(org=sp,coord=elongate_the_borders(coord=extract_coordinates(name=sbjct),size_nt=100))
            sequences_dico[gene][sp] = new_seq_prolognged
            print(gene,sp,len(sequences_dico[gene][sp]),len(gene_seq_prolonged))

        # 'IGORF': project the HSP onto the hit ORF so it spans the whole
        # query gene (elongate_the_hit reads the globals set here), then
        # extend by 100 nt per side.
        if genes_keep[gene][sp] == 'IGORF':
        
            sbjct  = all_blasts[gene][sp+"SUPER"]['hit']
            query_align = all_blasts[gene][sp+"SUPER"]['hsp'].query
            query_start = all_blasts[gene][sp+"SUPER"]['hsp'].query_start
            query_end   = all_blasts[gene][sp+"SUPER"]['hsp'].query_end
            sbjct_align = all_blasts[gene][sp+"SUPER"]['hsp'].sbjct
            sbjct_start = all_blasts[gene][sp+"SUPER"]['hsp'].sbjct_start
            sbjct_end   = all_blasts[gene][sp+"SUPER"]['hsp'].sbjct_end
            
            new_coord = elongate_the_hit(sbjct)
            prolongate_coord = elongate_the_borders(coord=new_coord,size_nt=100)
            new_seq = extract_dna_seq(org = sp,coord = new_coord)
            new_seq_prolognged = extract_dna_seq(org = sp,coord = prolongate_coord)
            sequences_dico[gene][sp] = new_seq_prolognged
            print(gene,sp,len(sequences_dico[gene][sp]),len(gene_seq_prolonged))
    
        # 'IGR': place the HSP in genome coordinates, treat the matched span
        # itself as the "subject" (sbjct_start = 0), then extend as above.
        # NOTE(review): extract_coordinates_for_IGR may return None, which
        # makes elongate_the_hit fail at len(None). len(hsp.sbjct) also
        # counts gap characters ('-').
        if genes_keep[gene][sp] == 'IGR':
            imaginary_coords = extract_coordinates_for_IGR(hsp_tmp=all_blasts_IGR[gene][sp+'IGR']['hsp'],hit=all_blasts_IGR[gene][sp+'IGR']['hit'])
            query_start = all_blasts_IGR[gene][sp+'IGR']['hsp'].query_start
            query_end   = all_blasts_IGR[gene][sp+'IGR']['hsp'].query_end
            sbjct_start = 0
            sbjct_end   = len(all_blasts_IGR[gene][sp+'IGR']['hsp'].sbjct)
            new_coord   = elongate_the_hit(sbjct=imaginary_coords)
            prolongate_coord = elongate_the_borders(coord=new_coord,size_nt=100)
            new_seq = extract_dna_seq(org = sp,coord = new_coord)
            new_seq_prolognged = extract_dna_seq(org = sp,coord = prolongate_coord)
            sequences_dico[gene][sp] = new_seq_prolognged
            print(gene,sp,len(sequences_dico[gene][sp]),len(gene_seq_prolonged))
    

# --- Outputs --------------------------------------------------------------
# For every gene with at least one homologous sequence besides Scer, write:
#   * one row of New_relationships_list.tab (species statuses + ancestor);
#   * "<gene>\t<ancestor>" to ancestors_to_keep.txt;
#   * FASTA_to_be_aligned/<gene>.toali with the extended sequences, and
#     FASTA_to_be_aligned/<gene>.newick with the species topology.
fasta_dir = path + '/intermediate/FASTA_to_be_aligned/'
# Create the per-gene output directory if needed; otherwise the first
# open() below fails with FileNotFoundError on a fresh project.
os.makedirs(fasta_dir, exist_ok=True)
with open(path + '/intermediate/ancestors_to_keep.txt','w') as fanc, open(path + '/intermediate/New_relationships_list.tab','w') as wtab:
    wtab.write('{:<30s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>10s}\n'.format('Gene','Scer','Spar','Smik','Skud','Sarb','Sbay','Ancestor'))
    for gene_id, gene_seqs in sequences_dico.items():
        # Only the Scer sequence: nothing to align.
        if len(gene_seqs) == 1:
            continue
        newick = create_Newick(gene_id)
        # Computed once and reused for both the table row and ancestors file.
        ancestor = decide_the_ancestor(gene_id)
        status = genes_keep[gene_id]

        wtab.write('{:<30s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>8s}\t{:>10d}\n'.format(gene_id,'CDS',status['Spar'],status['Smik'],status['Skud'],status['Sarb'],status['Sbay'],ancestor))

        with open(fasta_dir + gene_id + '.toali','w') as fali, open(fasta_dir + gene_id + '.newick','w') as fnew:
            fnew.write('{};'.format(newick))
            fanc.write('{}\t{}\n'.format(gene_id, ancestor))
            fali.writelines('>{}\n{}\n'.format(species, seq) for species, seq in gene_seqs.items())
