#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May  2 14:05:38 2020

@author: christospapadopoulos
"""

from Bio import Phylo
import sys,os,time
#from Bio.SubsMat import MatrixInfo as matlist
import tools
from pyHCA import HCA
from pyHCA.core.annotateHCA import _annotation_aminoacids
from pyHCA.core.classHCA import compute_disstat



def _node_number(label):
    """Return the node number encoded in an ancestor label, or None.

    A label of the form ``#<int>`` or ``N<int>`` yields that integer.  Any
    other label -- including names that merely start with ``#`` or ``N`` but
    are not followed by an integer (e.g. a species name) -- yields None
    instead of raising ValueError.
    """
    for prefix in ('#', 'N'):
        if label.startswith(prefix):
            try:
                return int(label.split(prefix)[1])
            except ValueError:
                return None
    return None


def get_ancestors(clades,ancestors,ancestor_to_keep):
    """Select ancestral sequences and key them by their clade.

    Parameters
    ----------
    clades : sequence
        Indexable collection; node number ``k`` maps to ``clades[k - 1]``.
    ancestors : dict
        Maps sequence labels to (possibly gapped) sequences.  Only labels
        of the form ``#<int>`` or ``N<int>`` are considered.
    ancestor_to_keep : int or "ALL"
        ``"ALL"`` keeps every node; an integer keeps only the first label
        (in iteration order) whose node number equals it.

    Returns
    -------
    dict
        ``{'anc_' + str(clade): sequence}`` with every ``-`` removed from
        the sequence.  Empty when nothing matches.

    Raises
    ------
    IndexError
        If a selected node number has no corresponding entry in *clades*.
    """
    keep_all = ancestor_to_keep == "ALL"
    dico = {}
    for seq in ancestors:
        number = _node_number(seq)
        if number is None:
            # Not an ancestral-node label: ignore it rather than crash.
            continue
        if not keep_all and number != ancestor_to_keep:
            continue
        dico['anc_' + str(clades[number - 1])] = ancestors[seq].replace('-', '')
        if not keep_all:
            # A single ancestor was requested and has been found.
            break
    return dico

def get_species(clades,ancestors,ancestor_to_keep):
    """Select the sequence of one species, identified by its clade position.

    Parameters
    ----------
    clades : sequence
        Indexable collection; ``str(clades[ancestor_to_keep - 1])`` is the
        label looked up in *ancestors*.
    ancestors : dict
        Maps sequence labels to (possibly gapped) sequences.
    ancestor_to_keep : int or "ALL"
        1-based position in *clades*.  ``"ALL"`` is not supported: a
        message is printed and the program exits with status 1.

    Returns
    -------
    dict
        ``{label: sequence}`` with every ``-`` removed from the sequence,
        or an empty dict when the label is not present in *ancestors*.

    Raises
    ------
    IndexError
        If ``ancestor_to_keep - 1`` is out of range for *clades*.
    SystemExit
        If *ancestor_to_keep* is ``"ALL"``.
    """
    if ancestor_to_keep == "ALL":
        print("For the moment I dont work with all the species")
        # sys.exit (not the interactive-only exit() helper), with a
        # non-zero status so callers can detect the failure.
        sys.exit(1)

    wanted = str(clades[ancestor_to_keep - 1])
    if wanted in ancestors:
        return {wanted: ancestors[wanted].replace('-', '')}
    return {}
        
def write_fasta(seqs,file):
    """Write the sequences of *seqs* to *file* in FASTA format.

    Each key of *seqs* becomes a ``>`` header line, immediately followed by
    one line holding the corresponding value.  *file* is opened in write
    mode, so any existing content is replaced.
    """
    with open(file, 'w') as handle:
        handle.writelines(
            '>{}\n{}\n'.format(header, seqs[header]) for header in seqs
        )


def get_hca_barcode(hca,orf):
    '''
    Build the HCA barcode string of one sequence.

    The barcode starts as one "." per element of hca.get_seqbin(orf).
    Each cluster from hca.get_clusters(orf) is rendered with str() and
    split on tabs: field 1 is the start, field 2 the stop and the last
    field the cluster pattern (presumably 1-based inclusive positions --
    verify against pyHCA).  The pattern is spliced over positions
    start..stop of the barcode.

    Clusters whose pattern is 4 characters long or shorter are skipped.
    '''
    barcode = "." * len(hca.get_seqbin(orf))
    for cluster in hca.get_clusters(orf):
        fields = str(cluster).split('\t')
        pattern = fields[-1]
        if len(pattern) <= 4:
            continue
        start = int(fields[1])
        stop = int(fields[2])
        barcode = barcode[:start - 1] + pattern + barcode[stop:]
    return barcode
       
def calculate_HCA_barcodes(sequences):
    """
    Return the HCA barcode of every sequence in *sequences*.

    A single HCA object is built from all the sequences (values) and their
    names (keys); get_hca_barcode is then called once per name.

    Returns a dict mapping each sequence name to its barcode string.
    """
    from pyHCA import HCA
    residues = list(sequences.values())
    names = list(sequences.keys())
    hca = HCA(seq=residues, querynames=names)
    return {orf: get_hca_barcode(hca=hca, orf=orf) for orf in names}


def _option_value(flag):
    """Return the command-line token that follows *flag*, or None.

    None is returned when *flag* is absent from sys.argv or when it is the
    last token, so that no value follows it.
    """
    try:
        return sys.argv[sys.argv.index(flag) + 1]
    except (ValueError, IndexError):
        # ValueError: flag not given; IndexError: flag given without value.
        return None


def _required_option(flag, complaint):
    """Return the value of *flag*; print *complaint* and exit(1) if missing."""
    value = _option_value(flag)
    if value is None:
        print(complaint)
        sys.exit(1)
    return value


# Command-line interface:
#   -fasta     FASTA file of the gene sequence              (required)
#   -ancs      FASTA file of the ancestor sequences         (required)
#   -newick    Newick tree of the sequences                 (required)
#   -anc_list  list of ancestors; defaults to "ALL"         (optional)
#   -type      analysis type                                (required)
fasta_file = _required_option("-fasta", 'Not FASTA of the gene sequence found')
fasta_anc = _required_option("-ancs", 'Not FASTA of the ancestors sequences found')
newick_file = _required_option("-newick", 'Not Newick of the sequences found')

ancestors_list = _option_value("-anc_list")
if ancestors_list is None:
    print('Not List of ancesotros found')
    print('I suppose you want ALL the ancestors')
    ancestors_list = "ALL"

analysis_type = _required_option("-type", 'Not Analysis type found')



# 1. Load the gene FASTA; only its first record is used.
gene = tools.read_multiFASTA(fasta_file=fasta_file)
name, gseq = list(gene.items())[0]
gsize = len(gseq)

# --------------------------------------- #

# 2. Parse the Newick tree; the first terminal is left out of `clades`.
newick = Phylo.read(file=newick_file, format='newick')
clades = newick.get_terminals()[1:]

# --------------------------------------- #

# 3. Read the ancestor sequences and associate their names.
#    Either every ancestor is kept, or the single one listed for this gene.
ancestor_to_keep = (
    "ALL" if ancestors_list == "ALL"
    else int(tools.read_anc_list(ancestors_list)[name])
)
ancestors = tools.read_multiFASTA(fasta_file=fasta_anc)
# NOTE(review): any analysis type other than 'anc' or 'sp' leaves
# `ancestors` as the unfiltered FASTA content.
_selector = {'anc': get_ancestors, 'sp': get_species}.get(analysis_type)
if _selector is not None:
    ancestors = _selector(clades, ancestors, ancestor_to_keep)

# --------------------------------------- #

# 4. HCA of the gene sequence, computed in-process with pyHCA.
#    (An earlier version shelled out to `hcatk segment -m domain` and read
#    the result back with tools.read_HCA.)
_gene_clusters = _annotation_aminoacids(seq=gseq, method="domain", verbose=False)["cluster"]
gene_score, pvalue = compute_disstat(0, gsize, _gene_clusters)
gene_barcode = calculate_HCA_barcodes(gene)
gene_HCA = {
    name: {
        "sequence": gene_barcode[name],
        "HCA_score": gene_score,
    }
}


import shlex  # stdlib; quotes file names interpolated into shell commands

# Location of the lalign36 executable (FASTA36 suite).  The LALIGN36
# environment variable overrides the historical hard-coded path.
lalign_bin = os.environ.get(
    'LALIGN36',
    '/Users/christospapadopoulos/Documents/programs/fasta36-36.3.8/bin/lalign36')

# Per-gene outputs, shared by all ancestors:
#   <name>.frags_ali  alignment of the gene with every aligned fragment
#   <name>.frags      aligned fragments (amino acids), FASTA-like
#   <name>.frags_nt   aligned fragments (nucleotides), FASTA-like
with open(name+'.frags_ali','w') as fali, \
        open(name+'.frags','w') as wfrag, \
        open(name+'.frags_nt','w') as wfrag_nt:

    for ancestor in ancestors:
        prefix = name + '_' + ancestor
        frags_fasta = prefix + '.ancfrags_tmp'
        lalign_out = prefix + '.lalign_tmp'
        gene_score_2dp = round(float(gene_HCA[name]['HCA_score']), 2)

        # Cut the ancestor sequence into fragments (amino-acid and
        # nucleotide versions, keyed by the same fragment names).
        fragments, fragments_nt = tools.cut_into_peaces(sequence=ancestors[ancestor])

        # HCA barcode and score of every fragment.
        fragments_barcode = calculate_HCA_barcodes(fragments)
        frag_HCA = {}
        for frag in fragments:
            frag_clusters = _annotation_aminoacids(seq=fragments[frag], method="domain", verbose=False)["cluster"]
            frag_score, pvalue = compute_disstat(0, len(fragments[frag]), frag_clusters)
            frag_HCA[frag] = {
                "sequence": fragments_barcode[frag],
                "HCA_score": frag_score,
            }

        write_fasta(seqs=fragments, file=frags_fasta)

        # Header of this ancestor's section in each output file.
        fali.write('{:>30s}\n'.format(ancestor))
        fali.write('{:>30s}\t:\t{}\n'.format(name, gseq))
        fali.write('{:>30s}\t:\t{}\n'.format(name + ' (' + str(gene_score_2dp) + ')', gene_HCA[name]['sequence']))

        wfrag.write('{}:\n'.format(ancestor))
        wfrag_nt.write('{}:\n'.format(ancestor))

        # Align the gene against the fragments with lalign36.
        # Every file name is shell-quoted: names come from FASTA headers
        # and command-line paths and may contain spaces, '|' or ';'.
        # NOTE(review): lalign reads "<name>.pfasta" from the working
        # directory, so the copied gene FASTA is assumed to carry exactly
        # that file name -- confirm.
        os.system("cp {} .".format(shlex.quote(fasta_file)))
        cmd = '{} -m9 {} {} > {}'.format(
            shlex.quote(lalign_bin),
            shlex.quote(name + '.pfasta'),
            shlex.quote(frags_fasta),
            shlex.quote(lalign_out))
        os.system(cmd)

        # Parse the alignments and decide which fragments are aligned.
        LALIGN = tools.parse_Lalign(lalign=lalign_out, name=name)
        try:
            LALIGN = tools.decide_fragments_aligned(LALIGN, gsize, frag_HCA)
        except Exception:
            # Best effort: report the gene and continue with the raw parse.
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            print('PROBLEM WITH LALIGN\t:\t', name)

        # TODO: fix the writing, and add an option for selecting one
        # ancestor versus all of them.
        for frag in LALIGN:
            hit = LALIGN[frag]
            fali.write('{:>30s}\t:\t{}\n'.format(frag, hit['align']))
            frag_score_2dp = round(float(frag_HCA[frag]['HCA_score']), 2)
            fali.write('{:>30s}\t:\t{}\n'.format(frag + ' (' + str(frag_score_2dp) + ')', hit['HCA_align']))
            wfrag.write('>{}.{}|{}\n{}\n'.format(name, frag, frag_score_2dp, fragments[frag]))
            wfrag_nt.write('>{}.{}|{}\n{}\n'.format(name, frag, frag_score_2dp, fragments_nt[frag]))

            # One summary line per aligned fragment on stdout.
            print(name, frag, ancestor, hit['position'],
                  gene_score_2dp,
                  frag_score_2dp,
                  hit['coverage'], hit['frag_coverage'],
                  gene_HCA[name]['sequence'], frag_HCA[frag]['sequence'], fragments[frag])
        # Blank line between ancestors (not written to the nucleotide file).
        wfrag.write('\n')
        fali.write('\n')

        # Fragments that did not end up in LALIGN.
        with open(prefix + '.NONfrags', 'w') as nfw:
            for fr in fragments:
                if fr not in LALIGN:
                    nfw.write('{}\t{}\t{}\n'.format(fragments[fr], frag_HCA[fr]["sequence"], frag_HCA[fr]['HCA_score']))

        # Remove the temporary files of this ancestor.
        # NOTE(review): if fasta_file itself is "<name>.pfasta" in the
        # working directory, this deletes the input file -- confirm intent.
        for leftover in (lalign_out,
                         prefix + '.ancfrags_HCA_tmp',
                         frags_fasta,
                         name + '.hca_tmp',
                         name + '.pfasta'):
            os.system('rm -fr ' + shlex.quote(leftover))
    
    
    
   
