#!/usr/bin/env python

import sys
print(sys.version)

import torch

import logging
# This setup allows logging in imported functions
from vcmsa.logger_module import logger

import pickle
from vcmsa.vcmsa_embed import parse_fasta_for_embed, get_embeddings
from vcmsa.vcmsa_utils import *
 

##printing for debugging with icecream
try:
    from icecream2 import ic
    ic.configureOutput(includeContext=True, outputFunction=print) ##prints line number and function
except ImportError:  # Graceful fallback if IceCream isn't installed.
    ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa

import argparse

from time import time

from scalene import scalene_profiler

# File IO
import os
import shutil
import sys

#import line_profiler
#import atexit
#profile = line_profiler.LineProfiler()

#atexit.register(profile.ic_stats)




#@profile
def get_align_args(argv=None):
    """Define the command-line interface of the alignment script and parse it.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on. Passing a list allows the parser to
            be used programmatically.

    Returns:
        argparse.Namespace holding one attribute per ``dest`` declared below.
        ``fasta_path``, ``out_path`` and ``model_name`` are required; every
        other option has a default (``None``, ``False`` or the value given).
    """
    parser = argparse.ArgumentParser()

    # Input / output locations
    parser.add_argument("-i", "--in", dest="fasta_path", type=str, required=True,
                        help="Path to fasta")
    parser.add_argument("-e", "--emb", dest="embedding_path", type=str, required=False,
                        help="Path to embeddings")
    parser.add_argument("-eo", "--embed_only", dest="embed_only", action="store_true",
                        help="If flagged, stop after saving embeddings (requires --emb_outfile)")
    parser.add_argument("-eout", "--emb_outfile", dest="emb_outfile", type=str, required=False,
                        help="Path to save embeddings, optional. If not present, embeddings won't be saved")
    parser.add_argument("-o", "--outfile", dest="out_path", type=str, required=True,
                        help="Path to outfile")

    # Sequence selection and outlier handling
    parser.add_argument("-bc", "--batch_correct", dest="batch_correct", action="store_true",
                        help="If added, do batch correction on sequences")
    parser.add_argument("-sl", "--seqlimit", dest="seqlimit", type=int, required=False,
                        help="Limit to n sequences. For testing")
    parser.add_argument("-ex", "--exclude", dest="exclude", action="store_true",
                        help="Exclude outlier sequences from initial alignment process")
    parser.add_argument("-fx", "--fully_exclude", dest="fully_exclude", action="store_true",
                        help="Additionally exclude outlier sequences from final alignment")

    # Embedding model configuration
    parser.add_argument("-l", "--layers", dest="layers", nargs="+", type=int,
                        default=[-16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1],
                        help="Which layers (of 25 in protbert) to select, default = -16 -15 -14 -13 -12 -11 -10 -9 -8 -7 -6 -5 -4 -3 -2 -1")
    parser.add_argument("-hd", "--heads", dest="heads", type=str,
                        help="File with one head identifier per line, format layer1_head3")
    parser.add_argument("-st", "--seqsimthresh", dest="seqsimthresh", type=float, required=False, default=0.7,
                        help="Similarity threshold for clustering sequences, default = 0.7")
    parser.add_argument("--log", dest="loglevel", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO")
    parser.add_argument("-m", "--model", dest="model_name", type=str, required=True,
                        help="Model name or path to local model")
    parser.add_argument("-co", "--cpu_only", dest="cpu_only", action="store_true",
                        help="If --cpu_only flag is included, will run on cpu even if gpu available")

    # Alignment behaviour
    parser.add_argument("-pca", "--pca_plot", dest="pca_plot", action="store_true",
                        help="If flagged, output 2D pca plot of amino acid clusters")
    parser.add_argument("-p", "--padding", dest="padding", type=int, required=False, default=0,
                        help="Number of characters of X to add to start and end of sequence (can be important for fragment sequences), default: 0")
    parser.add_argument("-mi", "--max_iterations", dest="max_iterations", type=int, required=False, default=100,
                        help="Number of iterations of limited search to run before increasing sequence threshold, default: 100")
    parser.add_argument("-l2", "--headnorm", dest="headnorm", action="store_true",
                        help="Take L2 normalization of each head")
    parser.add_argument("-mnc", "--maximum_noncrossing", dest="mnc", action="store_true",
                        help="If flagged, perform maximum non-crossing")
    parser.add_argument("-wmnc", "--weighted-maximum_noncrossing", dest="wmnc", action="store_true",
                        help="If flagged, perform weighted maximum non-crossing")

    # Reciprocal-best-hit network output
    parser.add_argument("-ro", "--rbh_only", dest="rbh_only", action="store_true",
                        help="If flagged, stop after saving reciprocal best hit network (requires --rbh_outfile)")
    parser.add_argument("-rout", "--rbh_outfile", dest="rbh_outfile", type=str, required=False,
                        help="Path to save reciprocal best hit network, optional. If not present, network won't be saved")
    parser.add_argument("-mis", "--maximum_increasing_subsequence", dest="mis", action="store_true",
                        help="If flagged, perform maximum increasing subsequence")
    parser.add_argument("-bhf", "--best_hits_filtering", dest="bhf", type=str, required=False, default="rbh",
                        help='How to filter best hits, default:"rbh"')

    # Sequence-similarity output
    parser.add_argument("-so", "--sim_only", dest="sim_only", action="store_true",
                        help="If flagged, stop after calculating sequence similarities (requires --sim_outfile)")
    parser.add_argument("-sout", "--sim_outfile", dest="sim_outfile", type=str, required=False,
                        help="Path to save sequence similarities, optional. If not present, sequence similarities won't be saved")

    # Batch-correction variants and clustering
    parser.add_argument("-zc", "--zero_center", dest="zero_center", action="store_true",
                        help="If added, do zero centering on sequences in batch correction, rather than running ComBat (requires --batch_correct)")
    parser.add_argument("-rb", "--ref_only", dest="ref_only", action="store_true",
                        help="If added, add ref_only parameter to batch correction")
    parser.add_argument("-dc", "--do_clustering", dest="do_clustering", action="store_true",
                        help="If flagged, does sequence clustering into subgroups")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


 
#@profile
def get_seq_groups(seqs, seq_names, embedding_dict, exclude, do_clustering, outfile_name, record_dir, seqsim_thresh = 0.75, pca_plot = False, sim_outfile = None, sim_only = None):
    """Split sequences into alignment groups based on sequence-level embedding similarity.

    A per-sequence embedding is the mean of that sequence's amino-acid
    embeddings (after padding removal). A similarity graph over sequences is
    obtained from ``get_seqsims``. With ``do_clustering`` the graph is
    thresholded, length-corrected and clustered with walktrap; when
    ``exclude`` is also set, the threshold/cluster cycle is repeated until
    the set of excluded sequences stops changing. Without ``do_clustering``
    all (non-excluded) sequences form a single group.

    Args:
        seqs: sequences, indexed by sequence number.
        seq_names: names, indexed the same way as ``seqs``.
        embedding_dict: mapping with key ``'aa_embeddings'``; the array is
            indexed by sequence on axis 0 and ``shape[1]`` is used as the
            padded sequence length.
        exclude: if true, remove outlier sequences via ``candidate_to_remove``.
        do_clustering: if true, cluster sequences into several groups.
        outfile_name, record_dir: used to name the optional PCA plot.
        seqsim_thresh: minimum (raw and length-corrected) similarity for an
            edge to be kept when clustering.
        pca_plot: if true, write a PCA plot of the un-padded embeddings.
        sim_outfile: optional path for a CSV of pairwise similarities
            (clustering mode only). ``None`` falls back to a module-level
            ``sim_outfile`` if one exists.
        sim_only: if true, exit the process after the similarity step
            (clustering mode only). ``None`` falls back to a module-level
            ``sim_only`` if one exists, else ``False``.

    Returns:
        Tuple ``(cluster_seqnums_list, cluster_seqs_list, cluster_names_list,
        group_hstates_list, to_exclude)``. The first four are parallel lists
        with one entry per group; ``to_exclude`` lists sequence numbers that
        were removed as outliers or did not end up in any group.

    Raises:
        SystemExit: with code 0 when ``sim_only`` is in effect.
    """
    # This function used to read `sim_outfile` / `sim_only` as free names, which
    # only exist when the file is executed as a script (they are assigned under
    # ``__main__``). Explicit arguments win; otherwise fall back to those
    # globals so existing call sites behave exactly as before.
    module_globals = globals()
    if sim_outfile is None:
        sim_outfile = module_globals.get("sim_outfile")
    if sim_only is None:
        sim_only = module_globals.get("sim_only", False)

    aa_embeddings = embedding_dict['aa_embeddings'] # this is numseqs x padded_seqlen x embedding_dim
    padded_seqlen = aa_embeddings.shape[1]

    logging.debug("Flattening hidden states list")
    hidden_states = np.array(reshape_flat(aa_embeddings)) # This is now numseqs*padded_seqlen x embedding_dim
    seqnums = list(range(len(seqs)))
    seqs_aas, seq_to_length = get_seqs_aas(seqs, seqnums)
    index_to_aa, hidden_states, seqnum_to_index, batch_list = remove_maxlen_padding(hidden_states, seqs_aas, padded_seqlen)
    if pca_plot:
        png_hidden_out = "{}/{}.initial.prebatch.png".format(record_dir, outfile_name)
        do_pca_plot(hidden_states, index_to_aa, png_hidden_out, seq_to_length = seq_to_length)

    # One embedding per sequence: mean over that sequence's amino-acid rows.
    sequence_embedding_list = [
        np.mean(np.take(hidden_states, seqnum_to_index[i], axis = 0), axis = 0)
        for i in range(len(seqs_aas))
    ]
    sequence_array = np.array(sequence_embedding_list, dtype="float32")
    G, s_index = get_seqsims(sequence_array, k = len(seqs))

    to_exclude = []
    group_hstates_list = []
    cluster_seqnums_list = []
    cluster_names_list = []
    cluster_seqs_list = []

    logging.debug("Removing spaces from sequences")
    if do_clustering:
        # Repeat threshold -> cluster -> exclude until the excluded set is stable.
        while True:
            group_hstates_list = []
            cluster_seqnums_list = []
            cluster_names_list = []
            cluster_seqs_list = []

            prev_to_exclude = to_exclude
            excluded = set(to_exclude)

            outputs_list = []
            edgelist = []
            weightlist = []
            for edge in G.es():
                source = G.vs[edge.source]["name"]
                target = G.vs[edge.target]["name"]
                if target not in excluded and source not in excluded and edge['weight'] >= seqsim_thresh:
                    len_source = len(seqs[source])
                    len_target = len(seqs[target])
                    # Down-weight pairs of very different length: the factor is 1
                    # for equal lengths and approaches 0.5 as they diverge.
                    length_diff_correction = 1 - abs(0.5 - len_source / (len_source + len_target))
                    corrected_weight = edge['weight'] * length_diff_correction
                    logging.debug("seqsim: {} {} {} {} {} {}".format(source, target, edge['weight'], len_source, len_target, corrected_weight))
                    outputs_list.append(",".join([str(source), str(target), str(edge['weight']), str(len_source), str(len_target), str(corrected_weight)]))
                    if corrected_weight >= seqsim_thresh:
                        edgelist.append([source, target])
                        weightlist.append(corrected_weight)

            if sim_outfile:
                with open(sim_outfile, "w") as outfile:
                    outfile.write("seq1,seq2,seqsim,seqlen1,seqlen2,correctedseqsim\n")
                    for line in outputs_list:
                        outfile.write("{}\n".format(line))
            if sim_only:
                sys.exit(0)

            # Rebuild G from the surviving edges; sequences without any
            # surviving edge disappear here and are reported via to_exclude.
            G = igraph.Graph.TupleList(edges=edgelist, directed=False)
            G.es['weight'] = weightlist
            G = G.simplify(combine_edges=max)

            islands = G.connected_components(mode = "weak")
            for island in islands:
                logging.debug("island {}".format(island))
            first_members = [x[0] for x in islands]
            # Artificially connect islands to avoid igraph error
            # (each pair is linked in both directions, as before).
            logging.debug(first_members)
            for x in first_members:
                for y in first_members:
                    if x != y:
                        G.add_edge(x, y, weight = 0.0000001)

            islands = G.connected_components(mode = "weak")
            for island in islands:
                logging.debug("island %s", island)

            seq_clusters = G.community_walktrap(steps = 3, weights = 'weight').as_clustering()
            if len(seq_clusters.subgraphs()) == len(G.vs()):
                seq_clusters = G.connected_components(mode = "weak") # walktrap can cluster nodes individually. See UBQ

            for seq_cluster_G in seq_clusters.subgraphs():
                # Do exclusion within clusters
                if exclude:
                    clust_names = seq_cluster_G.vs()["name"]
                    cluster_to_exclude = candidate_to_remove(seq_cluster_G, clust_names, z = -5)
                    to_delete_ids_sub_G = [v.index for v in seq_cluster_G.vs if v['name'] in cluster_to_exclude]
                    seq_cluster_G.delete_vertices(to_delete_ids_sub_G)

                    to_exclude = to_exclude + cluster_to_exclude
                    to_exclude = list(set(to_exclude))
                    if to_exclude:
                        logging.debug("Excluding following sequences: {}".format(",".join([str(x) for x in to_exclude])))

                seq_cluster = sorted(seq_cluster_G.vs()['name'])
                cluster_seqnums_list.append(seq_cluster)
                group_hstates_list.append(np.take(aa_embeddings, seq_cluster, axis = 0))
                cluster_names_list.append([seq_names[i] for i in seq_cluster])
                cluster_seqs_list.append([seqs[i] for i in seq_cluster])
                to_exclude = list(set(to_exclude))

            if set(to_exclude) == set(prev_to_exclude):
                break
    else:
        if exclude:
            clust_names = G.vs()["name"]
            to_exclude = candidate_to_remove(G, clust_names, z = -3)
            to_delete_ids = [v.index for v in G.vs if v['name'] in to_exclude]
            G.delete_vertices(to_delete_ids)
            logging.debug("Excluding following sequences: {}".format(",".join([str(x) for x in to_exclude])))
        else:
            logging.debug("Not removing outlier sequences")
            to_exclude = []

        # Everything still in the graph becomes one single group.
        kept_seqnums = sorted(set(v['name'] for v in G.vs))
        cluster_seqs_list = [[seqs[i] for i in kept_seqnums]]
        cluster_names_list = [[seq_names[i] for i in kept_seqnums]]
        group_hstates_list = [np.take(aa_embeddings, kept_seqnums, axis = 0)]
        cluster_seqnums_list = [kept_seqnums]
        to_exclude = list(set(to_exclude))

    # Anything lost in the processes
    grouped_seqnums = set(flatten(cluster_seqnums_list))
    dropped_seqs = [x for x in range(len(seqs)) if x not in grouped_seqnums]
    to_exclude = list(set(to_exclude + dropped_seqs))

    return(cluster_seqnums_list, cluster_seqs_list,  cluster_names_list, group_hstates_list, to_exclude)


#@profile
def do_msa(seqs, seq_names, seqnums, hstates_list,  minscore1 = 0.5, alignment_group = 0, args = None):
    """
    Control for running whole alignment process on one group of sequences.

    Order of work: flatten and un-pad the embeddings, optionally batch-correct
    them, L2-normalise, build a similarity index and run a k-nearest-neighbour
    search, turn the hits into a graph, derive initial clusters, then run the
    gap-filling control loop that alternates between clustering unassigned
    amino acids and searching for best matches. An intermediate ``.aln`` file
    is written to ``args.record_dir`` after every gap-filling round.

    Args:
        seqs: sequences of this alignment group.
        seq_names: names matching ``seqs``.
        seqnums: sequence numbers, passed through to the helper functions.
        hstates_list: embedding array; ``shape[1]`` is used as the padded
            sequence length.
        minscore1: kept for backward compatibility; currently unused (the
            initial best-hit threshold is fixed at 0.7 below).
        alignment_group: integer used in output file names.
        args: namespace of options; required despite the ``None`` default.
            Attributes read here: fasta_path, pca_plot, record_dir,
            outfile_name, batch_correct, zero_center, ref_only,
            max_iterations, cpu_only, rbh_outfile, rbh_only, mis, mnc, wmnc,
            bhf. The namespace is also forwarded to helper functions.

    Returns:
        Tuple ``(alignment, index, hidden_states, index_to_aa)``.

    Raises:
        ValueError: if ``args.bhf`` is neither ``"rbh"`` nor ``"no_rbh"``.
        SystemExit: with code 0 when ``args.rbh_only`` is set.
    """
    fasta_path = args.fasta_path
    pca_plot = args.pca_plot
    record_dir = args.record_dir
    outfile_name = args.outfile_name
    batch_correct = args.batch_correct
    zero_center = args.zero_center
    ref_only = args.ref_only
    cpu_only = args.cpu_only
    rbh_outfile = args.rbh_outfile
    rbh_only = args.rbh_only
    mis = args.mis
    mnc = args.mnc
    wmnc = args.wmnc
    bhf = args.bhf

    # Fail early: an unknown value used to surface as an UnboundLocalError
    # from inside a worker thread, after the expensive KNN search.
    if bhf not in ("rbh", "no_rbh"):
        raise ValueError('Unknown best_hits_filtering value {!r}; expected "rbh" or "no_rbh"'.format(bhf))

    seqlens = [len(x) for x in seqs]
    numseqs = len(seqs)
    # Allow at least one gap-filling round per sequence.
    max_iterations = max(args.max_iterations, numseqs)
    padded_seqlen = hstates_list.shape[1]

    logging.debug("Flattening hidden states list")
    hidden_states = np.array(reshape_flat(hstates_list), dtype="float32")
    logging.debug("embedding_shape: {}".format(hidden_states.shape))

    logging.debug("Convert index position to amino acid position")
    seqs_aas, seq_to_length = get_seqs_aas(seqs, seqnums)
    index_to_aa, hidden_states, seqnum_to_index, batch_list = remove_maxlen_padding(hidden_states, seqs_aas, padded_seqlen)

    def plot_stage(stage):
        # Plots whatever hidden_states currently holds, labelled with `stage`.
        png_hidden_out = "{}/{}.alignment_group{}.{}.png".format(record_dir, outfile_name, alignment_group, stage)
        do_pca_plot(hidden_states, index_to_aa, png_hidden_out, seq_to_length = seq_to_length)

    if pca_plot:
        plot_stage("prebatch")

    if batch_correct:
        hidden_states = do_batch_correct(hidden_states, list(range(len(seqs_aas))), batch_list, zero_center, ref_only)
        hidden_states = np.asarray(hidden_states, order='C') # Required for faiss as well as pca plot
        if pca_plot:
            plot_stage("postbatch")

    logging.debug("maxlen padding removed")

    # NOTE(review): this "postnorm" plot is drawn *before* normalize_L2 below,
    # so it shows the same data as the previous stage. Order kept as-is;
    # confirm whether the plot should move after the normalisation.
    if pca_plot:
        plot_stage("postnorm")

    faiss.normalize_L2(hidden_states)

    index = build_index_flat(hidden_states, scoretype = "cosinesim", normalize_l2 = False) # Already normalized

    # Only touch faiss' GPU API when it will be used: CPU-only faiss builds do
    # not provide StandardGpuResources, and --cpu_only must keep the index on CPU.
    use_gpu = (not cpu_only) and torch.cuda.is_available() and hasattr(faiss, "StandardGpuResources")
    if use_gpu:
        res = faiss.StandardGpuResources()  # use a single GPU
        index = faiss.index_cpu_to_gpu(res, 0, index)
    else:
        logging.debug("KNN search significantly faster on a GPU")
    logging.debug("Index built")

    k = min(numseqs * 20, 2048)
    D1, I1 = index.search(hidden_states, k = k)
    logging.debug("KNN done")

    # I2 is a dictionary of dictionaries of each aa1: { 1: [[aa1:1.0], [aa2:0.8]}
    I2 = split_distances_to_sequence(D1, I1, index_to_aa, numseqs, seqlens)
    logging.debug("Split results into proteins done")

    logging.debug("get best hitlist")
    new_hitdict = get_besthits(I2, minscore = 0.7)

    # merge the per-key graphs into one; this can be slow for large networks
    G = _build_hit_graph(new_hitdict, bhf, mis, mnc, wmnc)

    if rbh_outfile:
        _write_rbh_network(G, rbh_outfile, fasta_path)
    if rbh_only:
        sys.exit(0)

    # Skip the edge walk entirely unless debug records would be emitted.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        for e in G.es:
            source_name = G.vs[e.source]['name']
            target_name = G.vs[e.target]['name']
            logging.debug("initial rbh: {}, {}, {}".format(source_name.seqnum, target_name.seqnum, e['weight']))

    logging.debug("got reciprocal best hits")

    # First-pass clusters must cover more than half of the sequences.
    if len(seqs) > 2:
        minclustsize = len(seqs) // 2 + 1
    else:
        minclustsize = 2

    clusters_list, all_alternates_dict = first_clustering(G, betweenness_cutoff = 0.1, ignore_betweenness = False, apply_walktrap = True)

    clusters_list = [x for x in clusters_list if len(x) > 1]
    for x in clusters_list:
        logging.debug("First clusters {}".format(x))

    new_clusters_list = dedup_clusters(clusters_list, G, minclustsize)

    clusters_filt = []
    too_small = []
    for clust in new_clusters_list:
        if len(clust) >= minclustsize:
            clusters_filt.append(clust)
        elif len(clust) > 2:
            too_small.append(clust)
    for x in clusters_filt:
        logging.debug("First clusters with small removed {}".format(x))
    for x in too_small:
        logging.debug("collected as too_small {}".format(x))

    cluster_order, clustid_to_clust, pos_to_clustid, alignment = organize_clusters(clusters_filt, seqs_aas, gapfilling_attempt = 0, minclustsize = minclustsize, all_alternates_dict = all_alternates_dict, seqnames = seq_names, args = args)

    alignment = make_alignment(cluster_order, seqnums, clustid_to_clust, seq_names)

    # Gap filling starts with a looser minimum cluster size.
    if len(seqnums) <= 2:
        minclustsize = 2
    elif len(seqnums) < 5:
        minclustsize = len(seqnums) - 1
    else:
        minclustsize = 4

    ignore_betweenness = False
    minscore = 0.5
    betweenness_cutoff = 0.30
    history_unassigned = {'onebefore':[], 'twobefore':[], 'threebefore':[]}
    most_complete_alignment = alignment
    too_small = []
    rbh_dict = {}
    match_dict = {}
    # Rounds that cluster (odd, and strictly below max_iterations); all other
    # rounds use the looser best-match search.
    clustering_attempts = range(1, max_iterations, 2)

    ############## CONTROL LOOP ###################
    for gapfilling_attempt in range(1, max_iterations + 1):
        logging.debug("Main gapfilling_attempt {}".format(gapfilling_attempt))
        # From round 7 on, relax the cluster size on every odd round (floor of 2).
        if gapfilling_attempt > 6 and minclustsize > 2 and gapfilling_attempt % 2 == 1:
            minclustsize = minclustsize - 1

        logging.debug(f"too small: {too_small}\n")
        unassigned = get_unassigned_aas(seqs_aas, pos_to_clustid, too_small)

        if len(unassigned) == 0:
            logging.debug("Alignment complete after gapfilling attempt {}".format(gapfilling_attempt - 1))
            alignment = make_alignment(cluster_order, seqnums, clustid_to_clust, seq_names)
            return(alignment, index, hidden_states, index_to_aa)

        # Stalled (same unassigned set as two or three rounds ago) or out of rounds.
        stalled_or_last = (
            gapfilling_attempt == max_iterations
            or unassigned == history_unassigned['threebefore']
            or unassigned == history_unassigned['twobefore']
        )
        if stalled_or_last and gapfilling_attempt > 10:
            if minscore > 0.1 and gapfilling_attempt != max_iterations:
                # First stall: loosen the score threshold once and keep going.
                minscore = 0.1
                logging.debug("reducing minscore to {} at gapfilling attempt {}".format(minscore, gapfilling_attempt))
            ################ Final stage #############3
            else:
                alignment = make_alignment(cluster_order, seqnums, clustid_to_clust, seq_names)
                # NOTE(review): the replacement below only rebinds `alignment`,
                # which is overwritten by fill_in_hopeless; the cluster state
                # passed to fill_in_hopeless is still the current one.
                if alignment.numassigned < most_complete_alignment.numassigned:
                    logging.debug("Replacing current alignment with previous more complete alignment")
                    alignment = most_complete_alignment
                else:
                    logging.debug("Currently alignment is the most complete so far")
                logging.debug("Placing remaining amino acids")
                cluster_order, clustid_to_clust, pos_to_clustid, alignment = fill_in_hopeless(unassigned, seqs_aas, cluster_order, clustid_to_clust, pos_to_clustid, index, hidden_states, gapfilling_attempt, args = args)
                # Result unused; call kept because the helper is not visible here.
                unassigned = get_unassigned_aas(seqs_aas, pos_to_clustid)
                alignment = make_alignment(cluster_order, seqnums, clustid_to_clust, seq_names)

                return(alignment,   index, hidden_states, index_to_aa)

        history_unassigned['threebefore'] = history_unassigned['twobefore']
        history_unassigned['twobefore'] = history_unassigned['onebefore']
        history_unassigned['onebefore'] = unassigned

        if gapfilling_attempt in clustering_attempts:
            logging.debug("Align by clustering within guideposts during gapfilling {}".format(gapfilling_attempt))
            # Don't allow modification of previous guideposts
            apply_walktrap = gapfilling_attempt > 4

            if 3 < gapfilling_attempt < 15:
                # This removes 'stranded' amino acids, where neither the previous or next amino acid are placed adjacent.
                # If a good cluster, will be added right back
                logging.debug("running address stranded")
                cluster_order, clustid_to_clust = address_stranded(alignment)
                alignment = make_alignment(cluster_order, seqnums, clustid_to_clust, seq_names)
                clusterlist = list(clustid_to_clust.values())
                pos_to_clustid, clustid_to_clust = get_cluster_dict(clusterlist)
                unassigned = get_unassigned_aas(seqs_aas, pos_to_clustid)

            logging.debug("Align by clustering reciprocal best hit network")
            start_clustering = time()
            cluster_order, clustid_to_clust, pos_to_clustid, alignment, too_small, rbh_dict, all_new_G = fill_in_unassigned_w_clustering(unassigned, seqs_aas, cluster_order, clustid_to_clust, pos_to_clustid, I2,  gapfilling_attempt, minscore = minscore ,minclustsize = minclustsize, ignore_betweenness = ignore_betweenness, betweenness_cutoff = betweenness_cutoff, apply_walktrap = apply_walktrap, rbh_dict = rbh_dict, seqnames = seq_names, args = args)
            logging.debug("Time to cluster {}".format(time() - start_clustering))
        else:
            logging.debug("Align by best match (looser)")
            logging.debug("Add aa's to existing clusters")
            if 3 < gapfilling_attempt < 15:
                # Rebuild the lookup tables from the current clusters
                # (address_stranded is not called in this branch).
                alignment = make_alignment(cluster_order, seqnums, clustid_to_clust, seq_names)
                clusterlist = list(clustid_to_clust.values())
                pos_to_clustid, clustid_to_clust = get_cluster_dict(clusterlist)
                unassigned = get_unassigned_aas(seqs_aas, pos_to_clustid)
            logging.debug("starting looser")
            cluster_order, clustid_to_clust, pos_to_clustid, alignment, match_dict = fill_in_unassigned_w_search(unassigned, seqs_aas, cluster_order, clustid_to_clust, pos_to_clustid, index, hidden_states,  index_to_aa, gapfilling_attempt, minclustsize = minclustsize, remove_both = True, match_dict= match_dict, seqnames = seq_names, args = args, I2 = I2)
            logging.debug("looser complete")

            too_small = []

        # THINGS TO DO AT THE END OF GAPFILLING CYCLE
        if alignment.numassigned > most_complete_alignment.numassigned:
            logging.debug("This is the most complete alignment so far")
            most_complete_alignment = alignment

        outaln = "{}/{}.alignment_clustering_group_{}_gapfilling_{:04}.aln".format(record_dir, outfile_name, alignment_group, gapfilling_attempt)
        with open(outaln, "w") as outfile:
            outfile.write(alignment.format_aln("clustal"))

    return( alignment,  index, hidden_states,  index_to_aa)


def _build_hit_graph(hitdict, bhf, mis, mnc, wmnc):
    """Build one graph per best-hit list in ``hitdict`` and return their union by vertex name."""

    def hits_to_graph(hits):
        # bhf is validated by do_msa: "rbh" filters, "no_rbh" keeps hits as-is.
        filtered = get_rbhs(hits) if bhf == "rbh" else hits
        if mis:
            filtered = maximum_increasing(filtered)
        # if mnc and wmnc flags are both on, perform weighted
        if wmnc:
            return graph_from_hits_noncrossing(filtered, wmnc = wmnc)
        if mnc:
            return graph_from_hits_noncrossing(filtered)
        return graph_from_rbh(filtered)

    full_parallel_rbhtime = time()
    with ThreadPoolExecutor() as executor:
        graphs = list(executor.map(hits_to_graph, hitdict.values()))
    # Parallel reduces time by about a third
    logging.debug("Time full parallel rbh {}".format(time() - full_parallel_rbhtime))

    return igraph.union(graphs, byname=True)


def _write_rbh_network(G, rbh_outfile, fasta_path):
    """Write every edge of ``G`` to ``rbh_outfile`` as one comma-separated line."""
    with open(rbh_outfile, "w") as fOut:
        for e in G.es:
            source_name = G.vs[e.source]['name']
            target_name = G.vs[e.target]['name']
            fOut.write("{},{},{},{},{},{}\n".format(fasta_path, source_name.seqnum, target_name.seqnum, source_name, target_name, e['weight']))


#@profile
def tracefunc(frame, event, arg, indent=[0]):
    """Trace hook in the shape expected by ``sys.settrace``.

    Adds 2 to ``indent[0]`` for every "call" event whose frame's string
    form contains the top-level package name of this module; all other
    events leave the counter alone.

    The default ``indent`` list is shared across calls on purpose: it is
    the persistent counter when no list is passed in. ``arg`` is unused.

    Returns this same function, so it keeps being used as the trace hook.
    """
    if event != "call":
        return tracefunc

    # Only the first dotted component of the module name is matched.
    root_package = __name__.partition('.')[0]
    if root_package in str(frame):
        indent[0] = indent[0] + 2
    return tracefunc

if __name__ == '__main__':
    # Script entry point. In order, it:
    #   1. parses the command line and prepares the record directory + log,
    #   2. loads embeddings from a pickle or computes them with get_embeddings,
    #   3. splits the sequences into groups with get_seq_groups,
    #   4. aligns every multi-sequence group with do_msa and writes it out,
    #   5. merges the per-group alignments (and any kept excluded sequences)
    #      into the final output file with mafft.

    import subprocess
    import sys

    # NOTE(review): tracefunc only bumps a counter, but installing it as the
    # trace hook makes the interpreter invoke it while the script runs.
    # Presumably a debugging leftover -- confirm before removing.
    sys.settrace(tracefunc)

    args = get_align_args()

    # Everything assigned here becomes a module-level global when the script
    # runs, so every name is kept even if this block never reads it itself.
    fasta_path = args.fasta_path
    embedding_path = args.embedding_path
    outfile = args.out_path
    exclude = args.exclude
    fully_exclude = args.fully_exclude
    layers = args.layers
    heads = args.heads
    model_name = args.model_name
    emb_outfile = args.emb_outfile
    embed_only = args.embed_only
    sim_outfile = args.sim_outfile
    sim_only = args.sim_only
    pca_plot = args.pca_plot
    seqlimit = args.seqlimit
    headnorm = args.headnorm
    padding = args.padding
    seqsim_thresh = args.seqsimthresh
    do_clustering = args.do_clustering
    # NOTE(review): the command-line value is overridden right here, so
    # clustering is always on. Set this to False by hand for pca tests.
    do_clustering = True

    loglevel = args.loglevel
    cpu_only = args.cpu_only
    # get place to store output file
    outfile_path = os.path.dirname(outfile)
    outfile_name = os.path.splitext(os.path.basename(outfile))[0]  # get outfile without extension
    rbh_outfile = args.rbh_outfile
    mis = args.mis
    mnc = args.mnc
    wmnc = args.wmnc
    rbh_only = args.rbh_only
    bhf = args.bhf

    # Intermediate files go to a per-run directory next to the output file.
    if outfile_path:
        record_dir = "{}/alignment_files_{}".format(outfile_path, outfile_name)
    else:
        record_dir = "alignment_files_{}".format(outfile_name)
    if os.path.exists(record_dir):
        shutil.rmtree(record_dir, ignore_errors=True)

    # makedirs also creates a missing parent of the output path and does not
    # fail when the best-effort rmtree above left the directory behind.
    os.makedirs(record_dir, exist_ok=True)

    # Drop a stale file from a previous run.
    if rbh_outfile:
        if os.path.exists(rbh_outfile):
            os.remove(rbh_outfile)

    # Stored on args so functions that receive args can read them.
    args.outfile_name = outfile_name
    args.record_dir = record_dir

    logname = "{}/{}.align.log".format(record_dir, outfile_name)
    log_format = "%(asctime)s::%(levelname)s::"\
             "%(filename)s::%(lineno)d::%(message)s"

    logging.basicConfig(filename=logname, level=loglevel, format=log_format)
    logging.debug("This message is from the main script")

    # Optional file with one head name per line.
    if heads is not None:
        with open(heads, "r") as f:
            headnames = f.readlines()
            headnames = [x.replace("\n", "") for x in headnames]
    else:
        headnames = None
    logging.debug("Check for torch")
    logging.debug(torch.cuda.is_available())

    minscore1 = 0.5

    logging.debug("model: {}".format(model_name))
    logging.debug("fasta: {}".format(fasta_path))
    logging.debug("padding: {}".format(padding))
    logging.debug("first score thresholds: {}".format(minscore1))

    seq_names, seqs, seqs_spaced = parse_fasta_for_embed(fasta_path, padding = padding)

    # Optionally keep only the first `seqlimit` sequences.
    if seqlimit:
        seq_names = seq_names[0:seqlimit]
        seqs = seqs[0:seqlimit]
        seqs_spaced = seqs_spaced[0:seqlimit]

    # The first step is to get embeddings
    if embedding_path:
        # NOTE(security): pickle.load can execute arbitrary code while
        # loading; only pass --emb files that come from a trusted source.
        with open(embedding_path, "rb") as f:
            embedding_dict = pickle.load(f)

    else:
        seqlens = [len(x) for x in seqs]
        embedding_dict = get_embeddings(seqs_spaced,
                                    model_name,
                                    seqlens = seqlens,
                                    get_sequence_embeddings = True,
                                    get_aa_embeddings = True,
                                    layers = layers,
                                    padding = padding,
                                    heads = headnames,
                                    cpu_only = cpu_only,
                                    half = False)
        if emb_outfile:
            with open(emb_outfile, "wb") as fOut:
                pickle.dump(embedding_dict, fOut, protocol=pickle.HIGHEST_PROTOCOL)

        if embed_only:
            # sys.exit instead of the site-provided exit() builtin, which is
            # not guaranteed to exist in every interpreter setup.
            sys.exit(0)

    # The sequences are divided into groups
    cluster_seqnums_list, cluster_seqs_list,  cluster_names_list, cluster_hstates_list, to_exclude = get_seq_groups(seqs ,seq_names, embedding_dict,  exclude, do_clustering, outfile_name, record_dir, seqsim_thresh = seqsim_thresh, pca_plot = pca_plot )

    # aln_fasta_list holds one list of fasta entries per merge unit: each
    # aligned group, each single-sequence group, each kept excluded sequence.
    aln_fasta_list = []
    excluded_records = []
    for excluded_seqnum in to_exclude:

        excluded_record = SeqRecord(Seq(seqs[excluded_seqnum]), id=seq_names[excluded_seqnum], description = '')
        excluded_records.append(excluded_record)
        # Option to keep poor matches out
        if fully_exclude != True:
            aln_fasta_list.append([">{}\n{}\n".format(seq_names[excluded_seqnum], seqs[excluded_seqnum])])

    with open("{}/{}.excluded.fasta".format(record_dir, outfile_name), "w") as output_handle:
        SeqIO.write(excluded_records, output_handle, "fasta")

    alignments = []
    hidden_states_list = []
    index_to_aas_list = []

    # For each sequence group, do a sub alignment
    for i, group_names in enumerate(cluster_names_list):
        group_seqs = cluster_seqs_list[i]
        group_seqnums = cluster_seqnums_list[i]
        group_embeddings = cluster_hstates_list[i]

        # Record the unaligned members of this group.
        group_seqs_out = "{}/{}.alignment_group{}.fasta".format(record_dir, outfile_name, i)
        group_records = []

        for j, group_seq in enumerate(group_seqs):
            group_records.append(SeqRecord(Seq(group_seq), id=group_names[j], description = ''))

        with open(group_seqs_out, "w") as output_handle:
            SeqIO.write(group_records, output_handle, "fasta")

        if len(group_names) == 1:
            # Nothing to align; pass the lone sequence through to the merge.
            aln_fasta_list.append([">{}\n{}\n".format(group_names[0], group_seqs[0])])

        else:
            # Main function
            alignment, index, hidden_states, index_to_aa = do_msa(group_seqs, group_names, group_seqnums, group_embeddings, minscore1 = minscore1, alignment_group = i, args = args)
            alignments.append(alignment)
            index_to_aas_list.append(index_to_aa)
            hidden_states_list.append(hidden_states)

            cluster_order, clustid_to_clust = clusts_from_alignment(alignment)
            squish = True
            if squish:
                logging.debug("attempt squish, merging adjacent non-overlapping clusters")
                for rep in range(0, 20):
                    prevclust = alignment
                    cluster_order, clustid_to_clust = squish_clusters(alignment, index, hidden_states, index_to_aa)
                    alignment = make_alignment(cluster_order, group_seqnums, clustid_to_clust, group_names)
                    # Stop once a round changes nothing. The comparison has
                    # to be between the two .alignment payloads; comparing
                    # the payload with the previous alignment object itself
                    # would never let the loop stop early.
                    if alignment.alignment == prevclust.alignment:
                        break
            alignment = make_alignment(cluster_order, group_seqnums, clustid_to_clust, group_names)

            if pca_plot:
                png_align_out = "{}/{}.alignment_group{}.fasta.png".format(record_dir, outfile_name,  i)
                do_pca_plot(hidden_states, index_to_aa, png_align_out, clustid_to_clust = clustid_to_clust)

            aln_fasta_list_group = []

            fasta_align_i = alignment.format_aln("fasta")
            clustal_align_i = alignment.format_aln("clustal")

            # Split the fasta text back into one ">name\nseq" entry per sequence;
            # the [1:] drops the empty piece in front of the first ">".
            aln_fasta_list.append([">" + x for x in fasta_align_i.split(">")][1:])

            fasta_align_out = "{}/{}.alignment_group{}.fasta.aln".format(record_dir, outfile_name, i)
            with open(fasta_align_out, "w") as o:
                o.write(fasta_align_i)

            clustal_align_out = "{}/{}.alignment_group{}.clustal.aln".format(record_dir, outfile_name, i)
            with open(clustal_align_out, "w") as o:
                o.write(clustal_align_i)

            # If nothing to merge
            if len(cluster_names_list) == 1 and (( len(excluded_records) == 0 or fully_exclude == True )) :
                with open(outfile, "w") as o:
                    o.write(clustal_align_i)
                sys.exit()

    consolidator = "mafft"
    if consolidator == "mafft":
        # NOTE(review): this is not the exact complement of the
        # "nothing to merge" test above -- confirm the intended condition.
        if len(cluster_names_list) > 1 or ( len(excluded_records) > 0 or fully_exclude == False ) :

            seq_count = 1

            all_fastas_path = "{}/{}.all_fastas_aln.fasta".format(record_dir, outfile_name)
            key_table_path = "{}/{}.key_table.txt".format(record_dir, outfile_name)

            # One fasta with every entry, plus a key table with one line per
            # merge unit listing the 1-based positions of its entries.
            with open(all_fastas_path, "w") as o, open(key_table_path, "w") as tb:
                for group_entries in aln_fasta_list:
                    for entry in group_entries:
                        o.write("{}\n".format(entry))
                        tb.write("{} ".format(seq_count))
                        seq_count = seq_count + 1
                    tb.write("\n")

            # Argument list instead of a shell string: paths containing
            # spaces or shell metacharacters are passed through untouched,
            # and a missing or failing mafft is actually detected (os.system
            # never raises, so the old except branch could not trigger).
            mafft_cmd = ["mafft", "--clustalout", "--merge", key_table_path,
                         "--auto", all_fastas_path]
            try:
                with open(outfile, "w") as merged:
                    subprocess.run(mafft_cmd, stdout=merged, check=True)
            except (OSError, subprocess.CalledProcessError) as E:
                # Best effort, as before: report the failure, do not crash.
                logging.error("mafft merge failed: %s", E)
                ic("Not doing mafft merge")
            else:
                # Echo the merged alignment to stdout.
                with open(outfile, "r") as merged:
                    sys.stdout.write(merged.read())
    
    #for i in range(len(alignments)):
    #                mylist = []
    #                cluster_order, clustid_to_clust = clusts_from_alignment(alignments[i])
    #                for key, value in clustid_to_clust.items():
    #                      clustid_embeddings = []
    #                      indexes = [x.index for x in value]
    #                      clustid_embeddings = np.take(hidden_states_list[i], indexes, 0)
    #                      clustid_embeddings = normalize(clustid_embeddings, axis =1, norm = "l2")
    #                      if len(indexes) > 1:
    #                          cosim = cosine_similarity(clustid_embeddings)
    #                          upper = cosim[np.triu_indices(cosim.shape[0], k = 1)]
    #                          #ic(upper)
    #                          mean_cosim = np.mean(upper)
    #                      else:
    #                         mean_cosim = 0
    #                      #ic(key, mean_cosim, len(indexes))






#@profile
#def consolidate_w_clustering(clusters_dict, seqs_aas_dict):
   # return(0) 
