#!/usr/bin/env python

import pandas as pd
import numpy as np
import os
import argparse
import sys

from Bio import SeqIO
from Bio import Align
from Bio.Seq import Seq
from itertools import groupby
import bisect
from collections import defaultdict

import umap
import copy

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection, LineCollection, QuadMesh

from random import random
import scipy.cluster.hierarchy as shc
from scipy.spatial.distance import squareform

import Levenshtein
import binascii
import pylibsais

from mscope.util import MotifPlacement
from mscope.motif_plot import MotifPlot
from mscope.multiseq_align import *


class Config:
    """Mutable bag of plot-styling and runtime settings for this module.

    A single shared instance (``cfg``) is created at import time; ``args``
    is expected to be filled in later with the parsed CLI namespace.
    """

    def __init__(self):
        # font sizes used by the heatmap / colorbar rendering
        self.cbar_fontsize = 20
        self.heatmap_labels_xfontsize = 20
        self.heatmap_labels_yfontsize = 20
        self.title_fontsize = 30

        # figure resolution and patch-outline styling
        self.dpi = 100
        self.edgecolor = 'black'
        self.linewidth = 0.25
        self.singlebase_edges = True

        # parsed command-line arguments; populated elsewhere
        self.args = None

cfg = Config()


def parse_fasta(filename):
    """Parse a FASTA file into {record description: sequence string}.

    Records whose sequence is empty are skipped.
    """
    return {
        record.description: str(record.seq)
        for record in SeqIO.parse(filename, 'fasta')
        if record.seq != ""
    }

def get_reverse_complements(sequences):
    """Return {description: reverse-complement string} for every input sequence."""
    return {
        name: str(Seq(sequence).reverse_complement())
        for name, sequence in sequences.items()
    }

def prepare_suffix_string(sequence_dict):
    """Prepare a sequence dictionary in the format required by pylibsais.

    Returns a tuple ``(seq, index, keys)``:
      - ``seq``: all sequences joined with ``'$'`` separators, plus a trailing ``'$'``
      - ``index``: numpy array with the position just past each sequence's ``'$'``
      - ``keys``: sequence names, in the same order as ``index``
    """
    keys = list(sequence_dict.keys())
    values = list(sequence_dict.values())
    # cumulative end positions; each sequence contributes len + 1 for its '$'
    ends = np.cumsum([len(v) + 1 for v in values])
    combined = '$'.join(values) + '$'
    return (combined, np.array(ends), keys)

def get_positions(suffix_ar, kmer_idx, kmer_cnt):
    """Return the sorted occurrence positions of a k-mer.

    The k-mer's occurrences are the suffix-array entries in
    ``suffix_ar[kmer_idx:kmer_idx + kmer_cnt]``.
    """
    window = suffix_ar[kmer_idx:kmer_idx + kmer_cnt]
    return np.sort(window)

def get_kmer_sequence(seq, suffix_ar, kmer_idx, kmer_len):
    """Return the k-mer string for the suffix-array entry at ``kmer_idx``."""
    # any occurrence of the k-mer works; use the one recorded at kmer_idx
    start = suffix_ar[kmer_idx]
    end = start + kmer_len
    return seq[start:end]

def pairwise_alignment(seq1, seq2):
    """Return the optimal global alignment score between two sequences.

    Scoring: match 0, mismatch -2, gap open -4, gap extend -1, end gaps -3.
    Scores are <= 0; a score closer to 0 means the sequences are more similar.
    """
    aligner = Align.PairwiseAligner()
    aligner.mode = 'global'
    aligner.match_score = 0
    aligner.mismatch_score = -2
    aligner.open_gap_score = -4
    aligner.extend_gap_score = -1
    aligner.target_end_gap_score = -3
    aligner.query_end_gap_score = -3
    # Every alignment enumerated by align() is co-optimal and shares the same
    # score, so the previous max() over all alignments only wasted time — the
    # number of co-optimal alignments can explode combinatorially for similar
    # sequences. score() computes just the optimal score.
    return aligner.score(seq1, seq2)

def select_best_perm_kmer(new_kmer, kmer_list, previous_kmer):
    """Pick the cyclic rotation of ``new_kmer`` that best matches earlier choices.

    Only rotations that appear as a 'kmer' entry in ``kmer_list`` are candidates.
    The rotation with the highest summed alignment score against the k-mers in
    ``previous_kmer`` wins; ties go to the earliest rotation.
    """
    rotations = [new_kmer[shift:] + new_kmer[:shift] for shift in range(len(new_kmer))]
    known = [entry['kmer'] for entry in kmer_list]
    candidates = [rotation for rotation in rotations if rotation in known]
    scores = [
        sum(pairwise_alignment(prev, candidate) for prev in previous_kmer)
        for candidate in candidates
    ]
    return candidates[scores.index(max(scores))]

def select_best_kmers(k_min, k_max, seq, index, used_kmer, round, min_count=2, min_indiv_count=2, min_consecutive_count=2, min_consecutive_bp=6):
    """Select the best k-mer based on the amount of sequence masked.

    :param k_min: the minimum k-mer length
    :param k_max: the maximum k-mer length
    :param seq: the sequence to search, as prepared by prepare_suffix_string
    :param index: the index of the end of each sequence in seq, as prepared by prepare_suffix_string
    :param used_kmer: k-mers selected in earlier rounds (appended to in place)
    :param round: current selection round; from round 1 on, the rotation of the
                  winning k-mer is chosen to best match previously used k-mers
    :param min_count: the minimum number of times a k-mer should occur in the full combined sequence (including overlaps)
    :param min_indiv_count: the minimum number of times a k-mer should occur in a single sequence (excluding overlaps)
    :param min_consecutive_count: the minimum number of consecutive times a k-mer should occur in a single sequence
    :param min_consecutive_bp: the minimum number of consecutive bases that need to be covered by the k-mer
    :return: tuple (candidate_kmer dict, or {} if none found; suffix array;
             max-suffix array; used_kmer)
    """

    #create suffix and LCP array
    suffix_ar, lcp_ar = pylibsais.sais(seq)
    #determine maximum length of valid suffixes at each position (should stop at $ and # symbol)
    mkmer_ar = pylibsais.max_suffix(seq)

    #get all kmers with min_count or more copies
    #returns no repetitive k-mers, but does return overlapping k-mers
    kmers = list(pylibsais.kmer_count(seq, suffix_ar, lcp_ar, mkmer_ar, k_min, k_max, min_count))
    kmers.sort(key=lambda x: (x[0] * x[2]), reverse=True) #sort on length * count, so that kmers that mask longer sequences are first

    kmers_with_min = []
    res = []
    max_continuous_masked_bp = 0
    evaluated = 0
    #walk across possible kmers
    for kmer_len, kmer_idx, kmer_cnt in kmers:
        #stop if we cannot improve on the current best
        if (kmer_cnt * kmer_len) < max_continuous_masked_bp:
            break

        #determine how much of the sequence is masked by this kmer
        total_masked, max_indiv_seq_count, max_consecutive_count = pylibsais.kmer_mask_potential(suffix_ar, mkmer_ar, index, kmer_len, kmer_idx, kmer_cnt)
        evaluated += 1

        #do not report kmer if it is worse than the current best
        #apply filter constraints (see function parameters)
        if max_consecutive_count * kmer_len < max_continuous_masked_bp or \
            max_indiv_seq_count < min_indiv_count or \
            max_consecutive_count < min_consecutive_count or \
            max_consecutive_count * kmer_len < min_consecutive_bp:
            continue

        kmers_with_min += [{'kmer_len': kmer_len, 'kmer_idx': kmer_idx, 'kmer_cnt':kmer_cnt, 'kmer':get_kmer_sequence(seq, suffix_ar, kmer_idx, kmer_len), 'min_kmer': pylibsais.min_string(get_kmer_sequence(seq, suffix_ar, kmer_idx, kmer_len))}]

        max_continuous_masked_bp = max_consecutive_count * kmer_len

        #get the kmer sequence
        kmer_s = get_kmer_sequence(seq, suffix_ar, kmer_idx, kmer_len)
        min_kmer = pylibsais.min_string(kmer_s)

        #get all positions of the kmer in 'seq' (can be overlapping)
        positions = get_positions(suffix_ar, kmer_idx, kmer_cnt)

        res.append({'kmer':kmer_s, 'min_kmer': min_kmer, 'suffix_cnt': kmer_cnt, 'total_masked': total_masked,
                        'max_indiv_seq_count':max_indiv_seq_count, 'max_consecutive_masked':max_consecutive_count * kmer_len, 'pos':positions, 'idx':kmer_idx})

    #sort kmers on priority: max continuous masked, then max count in individual sequence, then total masked, then length, then alphabetically
    res.sort(key=lambda x: (x['max_consecutive_masked'], x['max_indiv_seq_count'], x['total_masked'], len(x['kmer']), x['kmer']), reverse=True)

    print(f'KMER EVALUATED: {evaluated}')
    print(f'KMER SELECTED: {len(res)}')

    if len(res) == 0:
        return ({}, suffix_ar, mkmer_ar, used_kmer)

    if round == 0:
        candidate_kmer = res[0]
    else:
        selected_kmer_object = {}
        kmer_set_selected = res[0]['min_kmer']
        min_kmer = res[0]['min_kmer']
        #pick the rotation of the winning canonical k-mer that aligns best with earlier selections
        kmer_selected = select_best_perm_kmer(kmer_set_selected, kmers_with_min, used_kmer)
        #FIX: was `if seq.index(kmer_selected*2):`, which raised ValueError when the
        #doubled k-mer was absent and was falsy when it occurred at position 0.
        #Use a membership test, consistent with select_best_kmers_motif_guided.
        if kmer_selected * 2 in seq:
            kmer_set = []
            for kmer in kmers_with_min:
                if kmer['min_kmer'] == kmer_set_selected:
                    kmer_set += [kmer]

            for kmer in kmer_set:
                if kmer['kmer'] == kmer_selected:
                    selected_kmer_object = kmer

            positions = get_positions(suffix_ar, selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
            total_masked, max_indiv_seq_count, max_consecutive_count = pylibsais.kmer_mask_potential(suffix_ar, mkmer_ar, index, selected_kmer_object['kmer_len'], selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
            candidate_kmer = {'kmer':kmer_selected, 'min_kmer': min_kmer, 'suffix_cnt': selected_kmer_object['kmer_cnt'], 'total_masked': total_masked,
                            'max_indiv_seq_count':max_indiv_seq_count, 'max_consecutive_masked':max_consecutive_count * selected_kmer_object['kmer_len'], 'pos':positions, 'idx':selected_kmer_object['kmer_idx']}
        else:
            candidate_kmer = res[0]

    used_kmer += [candidate_kmer['kmer']]
    return (candidate_kmer, suffix_ar, mkmer_ar, used_kmer)





def select_all_kmer(seq, index, mink, maxk):
    """Iteratively select k-mers and mask them out of ``seq``.

    :param seq: concatenated sequence as produced by prepare_suffix_string
    :param index: end index of each sequence within seq
    :param mink: minimum k-mer length
    :param maxk: maximum k-mer length
    :return: (selected_kmers, marked_positions, masked seq), where
             marked_positions is a list of (position, kmer) tuples
    """
    #kmers that are selected
    selected_kmers = []

    #positions that are masked (list of tuples of position and kmer)
    marked_positions = []

    used_kmer = []
    n = 0
    #repeat until no (consecutive) kmers are found
    while True:
        res, sa, mask, used_kmer = select_best_kmers(mink, maxk, seq, index, used_kmer, n)

        if res == {}:
            break

        selected = res
        selected_kmers.append(selected)

        print(f"SELECT KMER: {selected['kmer']}")
        for k, v in selected.items():
            if k != 'kmer':
                print(f"- {k}: {v}")

        print('MASKED:')
        #dry run with '.' to show what would be masked
        rseq, rmarked_pos = pylibsais.kmer_mask(seq, sa, mask, len(selected['kmer']), selected['idx'], selected['suffix_cnt'], 2, '.')
        print(rseq)
        print('\n' * 2)
        if rseq.count('.') == 0:
            #FIX: previously computed `rseq.index(selected['kmer'] * 2)` here, which
            #raised ValueError (instead of the intended RuntimeError) whenever the
            #doubled k-mer was absent, and bound an unused variable otherwise
            raise RuntimeError('No masked positions found')
        #mask sequence with # symbol. The '2' indicates that only stretches of at least 2 consecutive kmers are masked.
        seq, marked_pos = pylibsais.kmer_mask(seq, sa, mask, len(selected['kmer']), selected['idx'], selected['suffix_cnt'], 2, '#')
        marked_positions.extend([(e, selected['kmer']) for e in marked_pos])
        n += 1

    #mask any remaining single occurrences of the selected k-mers
    for selected in selected_kmers:
        seq, marked_pos = pylibsais.kmer_mask_simple(seq, selected['kmer'], '#')
        marked_positions.extend([(e, selected['kmer']) for e in marked_pos])

    return selected_kmers, marked_positions, seq


def select_best_kmers_motif_guided(k_min, k_max, seq, index, used_kmer, round, ref_motifs_dict_r, min_count=2, min_indiv_count=2, min_consecutive_count=2, min_consecutive_bp=6):
    """Select k-mers based on the amount of sequence masked.

    Like select_best_kmers, but when the winning canonical (minimal-rotation)
    k-mer corresponds to a known reference motif, the reference rotation is
    preferred over the alignment-based rotation choice.

    :param k_min: the minimum k-mer length
    :param k_max: the maximum k-mer length
    :param seq: the sequence to search, as prepared by prepare_suffix_string
    :param index: the index of the end of each sequence in seq, as prepared by prepare_suffix_string
    :param used_kmer: k-mers selected in earlier rounds (appended to in place)
    :param round: current selection round (0 = first)
    :param ref_motifs_dict_r: mapping canonical (minimal) rotation -> reference motif
    :param min_count: the minimum number of times a k-mer should occur in the full combined sequence (including overlaps)
    :param min_indiv_count: the minimum number of times a k-mer should occur in a single sequence (excluding overlaps)
    :param min_consecutive_count: the minimum number of consecutive times a k-mer should occur in a single sequence
    :param min_consecutive_bp: the minimum number of consecutive bases that need to be covered by the k-mer
    :return: tuple (candidate_kmer dict, or {} if none found/usable; suffix array;
             max-suffix array; used_kmer)
    """

    #create suffix and LCP array
    suffix_ar, lcp_ar = pylibsais.sais(seq)
    #determine maximum length of valid suffixes at each position (should stop at $ and # symbol)
    mkmer_ar = pylibsais.max_suffix(seq)
    

    #get all kmers with min_count or more copies
    #returns no repetitive k-mers, but does return overlapping k-mers
    kmers = list(pylibsais.kmer_count(seq, suffix_ar, lcp_ar, mkmer_ar, k_min, k_max, min_count))
    kmers.sort(key=lambda x: (x[0] * x[2]), reverse=True) #sort on length * count, so that kmers that mask longer sequences are first


    kmers_with_min = []
    res = []
    max_continuous_masked_bp = 0
    evaluated = 0
    #record every candidate k-mer with its canonical (minimal) rotation; unlike
    #select_best_kmers this is done for ALL k-mers, before any filtering, so that
    #reference rotations can be looked up even if they fail the filters below
    for kmer_len, kmer_idx, kmer_cnt in kmers:
        kmers_with_min += [{'kmer_len': kmer_len, 'kmer_idx': kmer_idx, 'kmer_cnt':kmer_cnt, 'kmer':get_kmer_sequence(seq, suffix_ar, kmer_idx, kmer_len), 'min_kmer': pylibsais.min_string(get_kmer_sequence(seq, suffix_ar, kmer_idx, kmer_len))}]

    #walk across possible kmers
    for kmer_len, kmer_idx, kmer_cnt in kmers:
        #stop if we cannot improve on the current best
        if (kmer_cnt * kmer_len) < max_continuous_masked_bp:
            break

        
        #determine how much of the sequence is masked by this kmer
        total_masked, max_indiv_seq_count, max_consecutive_count = pylibsais.kmer_mask_potential(suffix_ar, mkmer_ar, index, kmer_len, kmer_idx, kmer_cnt)
        evaluated += 1

        #do not report kmer if it is worse than the current best
        #apply filter constraints (see function parameters)
        if max_consecutive_count * kmer_len < max_continuous_masked_bp or \
            max_indiv_seq_count < min_indiv_count or \
            max_consecutive_count < min_consecutive_count or \
            max_consecutive_count * kmer_len < min_consecutive_bp:
            continue

        max_continuous_masked_bp = max_consecutive_count * kmer_len

        #get the kmer sequence
        kmer_s = get_kmer_sequence(seq, suffix_ar, kmer_idx, kmer_len)
        min_kmer = pylibsais.min_string(kmer_s)
        
        #get all positions of the kmer in 'seq' (can be overlapping)
        positions = get_positions(suffix_ar, kmer_idx, kmer_cnt)

        res.append({'kmer':kmer_s, 'min_kmer': min_kmer, 'suffix_cnt': kmer_cnt, 'total_masked': total_masked, 
                        'max_indiv_seq_count':max_indiv_seq_count, 'max_consecutive_masked':max_consecutive_count * kmer_len, 'pos':positions, 'idx':kmer_idx})
    
    #sort kmers on priority: max continuous masked, then max count in individual sequence, then total masked, then length, then alphabetically
    res.sort(key=lambda x: (x['max_consecutive_masked'], x['max_indiv_seq_count'], x['total_masked'], len(x['kmer']), x['kmer']), reverse=True)
    
    print(f'KMER EVALUATED: {evaluated}')
    print(f'KMER SELECTED: {len(res)}')

    if len(res) == 0:
        return ({}, suffix_ar, mkmer_ar, used_kmer)

    else:
        selected_kmer_object = {}
        kmer_set_selected = res[0]['min_kmer']
        min_kmer = res[0]['min_kmer']
        kmer_set = []
        if round == 0:
            #first round: if the canonical winner is a known reference motif,
            #use the reference rotation directly
            if kmer_set_selected in ref_motifs_dict_r:
                for kmer in kmers_with_min: 
                    if kmer["kmer"] == ref_motifs_dict_r[kmer_set_selected]:
                        selected_kmer_object = kmer
                #NOTE(review): if the reference rotation is absent from kmers_with_min,
                #selected_kmer_object stays {} and the lookups below raise KeyError —
                #confirm whether that case can occur here
                positions = get_positions(suffix_ar, selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
                total_masked, max_indiv_seq_count, max_consecutive_count = pylibsais.kmer_mask_potential(suffix_ar, mkmer_ar, index, selected_kmer_object['kmer_len'], selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
                candidate_kmer = {'kmer':ref_motifs_dict_r[kmer_set_selected], 'min_kmer': min_kmer, 'suffix_cnt': selected_kmer_object['kmer_cnt'], 'total_masked': total_masked, 
                                'max_indiv_seq_count':max_indiv_seq_count, 'max_consecutive_masked':max_consecutive_count * selected_kmer_object['kmer_len'], 'pos':positions, 'idx':selected_kmer_object['kmer_idx']}
                
            else:
                candidate_kmer = res[0]

        else:
            #later rounds: reference rotation if known, otherwise pick the rotation
            #that best aligns with previously used k-mers
            if kmer_set_selected in ref_motifs_dict_r:
                kmer_selected = ref_motifs_dict_r[kmer_set_selected]
                for kmer in kmers_with_min:
                    if kmer['min_kmer'] == kmer_set_selected:
                        kmer_set += [kmer]
                
                for kmer in kmer_set:
                    if kmer['kmer'] == kmer_selected:
                        selected_kmer_object = kmer
                
                if selected_kmer_object != {}:

                    print(kmer_set)
                    positions = get_positions(suffix_ar, selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
                    total_masked, max_indiv_seq_count, max_consecutive_count = pylibsais.kmer_mask_potential(suffix_ar, mkmer_ar, index, selected_kmer_object['kmer_len'], selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
                    candidate_kmer = {'kmer':kmer_selected, 'min_kmer': min_kmer, 'suffix_cnt': selected_kmer_object['kmer_cnt'], 'total_masked': total_masked, 
                                    'max_indiv_seq_count':max_indiv_seq_count, 'max_consecutive_masked':max_consecutive_count * selected_kmer_object['kmer_len'], 'pos':positions, 'idx':selected_kmer_object['kmer_idx']}
                else:
                    #reference rotation not among the candidates: signal the caller to stop
                    candidate_kmer = {}
                    
            else:
                kmer_selected = select_best_perm_kmer(kmer_set_selected, kmers_with_min, used_kmer)

                #only use the permuted rotation if it occurs doubled (i.e. consecutively)
                if kmer_selected*2 in seq:
                    kmer_set = []
                    for kmer in kmers_with_min:
                        if kmer['min_kmer'] == kmer_set_selected:
                            kmer_set += [kmer]
                    
                    for kmer in kmer_set:
                        if kmer['kmer'] == kmer_selected:
                            selected_kmer_object = kmer

                    positions = get_positions(suffix_ar, selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
                    total_masked, max_indiv_seq_count, max_consecutive_count = pylibsais.kmer_mask_potential(suffix_ar, mkmer_ar, index, selected_kmer_object['kmer_len'], selected_kmer_object['kmer_idx'], selected_kmer_object['kmer_cnt'])
                    candidate_kmer = {'kmer':kmer_selected, 'min_kmer': min_kmer, 'suffix_cnt': selected_kmer_object['kmer_cnt'], 'total_masked': total_masked, 
                                    'max_indiv_seq_count':max_indiv_seq_count, 'max_consecutive_masked':max_consecutive_count * selected_kmer_object['kmer_len'], 'pos':positions, 'idx':selected_kmer_object['kmer_idx']}
                else:
                    candidate_kmer = res[0]
        if candidate_kmer != {}:
            used_kmer += [candidate_kmer['kmer']]

        return (candidate_kmer, suffix_ar, mkmer_ar, used_kmer)



def select_all_kmer_motif_guided(seq, index, mink, maxk, ref_motifs_list):
    """Iteratively select and mask k-mers, guided by known reference motifs.

    :param seq: concatenated sequence as produced by prepare_suffix_string
    :param index: end index of each sequence within seq
    :param mink: minimum k-mer length
    :param maxk: maximum k-mer length
    :param ref_motifs_list: known motifs whose rotations are preferred during selection
    :return: (selected_kmers, marked_positions, masked seq), where
             marked_positions is a list of (position, kmer) tuples
    """
    #map each reference motif to its canonical (minimal) rotation, and invert
    ref_motifs_dict = {}
    for motif in ref_motifs_list:
        ref_motifs_dict[motif] = min([motif[x:] + motif[:x] for x in range(len(motif))])
    ref_motifs_dict_r = {value: key for key, value in ref_motifs_dict.items()}

    #kmers that are selected
    selected_kmers = []

    #positions that are masked (list of tuples of position and kmer)
    marked_positions = []

    used_kmer = []
    n = 0
    #repeat until no (consecutive) kmers are found
    while True:
        res, sa, mask, used_kmer = select_best_kmers_motif_guided(mink, maxk, seq, index, used_kmer, n, ref_motifs_dict_r)

        if res == {}:
            break

        selected = res
        selected_kmers.append(selected)

        print(f"SELECT KMER: {selected['kmer']}")
        for k, v in selected.items():
            if k != 'kmer':
                print(f"- {k}: {v}")

        print('MASKED:')
        #dry run with '.' to show what would be masked
        rseq, rmarked_pos = pylibsais.kmer_mask(seq, sa, mask, len(selected['kmer']), selected['idx'], selected['suffix_cnt'], 2, '.')
        print(rseq)
        print('\n' * 2)

        if rseq.count('.') == 0:
            kmer = selected['kmer']
            #FIX: was `if rseq.index(kmer):`, which raised ValueError when the k-mer
            #was absent and was falsy (wrongly raising RuntimeError) when the k-mer
            #occurred at position 0
            if kmer in rseq:
                #no consecutive stretch could be masked; mask single occurrences instead
                seq, marked_pos = pylibsais.kmer_mask_simple(seq, selected['kmer'], '#')
                marked_positions.extend([(e, selected['kmer']) for e in marked_pos])
                n += 1
                continue
            else:
                raise RuntimeError('No masked positions found')
        #mask sequence with # symbol. The '2' indicates that only stretches of at least 2 consecutive kmers are masked.
        seq, marked_pos = pylibsais.kmer_mask(seq, sa, mask, len(selected['kmer']), selected['idx'], selected['suffix_cnt'], 2, '#')
        marked_positions.extend([(e, selected['kmer']) for e in marked_pos])
        n += 1

    #mask any remaining occurrences of the reference motifs
    for ref in ref_motifs_list:
        seq, marked_pos = pylibsais.kmer_mask_simple(seq, ref, '#')
        if marked_pos != []:
            print(ref)
            marked_positions.extend([(e, ref) for e in marked_pos])
            #FIX: was `if selected_kmers not in selected_kmers:` (always True), which
            #appended a duplicate entry for reference motifs that were already selected
            if not any(entry.get('kmer') == ref for entry in selected_kmers):
                selected_kmers += [{'kmer': ref}]

    #mask any remaining single occurrences of the selected k-mers
    for selected in selected_kmers:
        seq, marked_pos = pylibsais.kmer_mask_simple(seq, selected['kmer'], '#')
        marked_positions.extend([(e, selected['kmer']) for e in marked_pos])

    return selected_kmers, marked_positions, seq

def mask_all_seq(selected_kmers, marked_positions, seq):
    """Print the selected k-mers and return ``marked_positions`` ordered by position.

    The sort happens in place; the same list object is returned.
    """
    for entry in selected_kmers:
        print(entry['kmer'])

    marked_positions.sort(key=lambda placement: placement[0])
    return marked_positions


def get_grouped_positions(seq_last_index, seq_name, motif_positions, seq_concat, add_single_base=True):
    """Group the positions of the motifs by sequence name.
    For each sequence, gives a list containing tuples of the form (start, end, motif, count).

    :param seq_last_index: list of the last index of each sequence in the concatenated sequence
    :param seq_name: list of the names of the sequences
    :param motif_positions: list of tuples of the form (position, motif). Sorted by position.
    :param seq_concat: the concatenated sequence string the positions refer to
    :param add_single_base: if True, emit unmasked stretches as single-base MotifPlacements

    :return: a dictionary with the sequence name as key and a list of tuples as value of the form (start, end, motif, count)
             a list of unique motifs
    """
    assert len(seq_last_index) == len(seq_name), "Length of seq_last_index and seq_name should be the same"
    assert len(seq_name) > 0, "At least one sequence should be provided"

    sample_positions = {}
    cur_sample_pos = 0          # index of the sequence currently being filled
    current_motifs = []         # MotifPlacements accumulated for the current sequence
    cur_seq_start = 0           # start of the current sequence within seq_concat
    prev_pos = 0                # end of the previously handled motif (absolute coords)
    unique_motifs = set()
    for pos, motif in motif_positions:
        assert pos >= prev_pos, f"Positions should be in increasing order. Found {pos} after {prev_pos}"
        end_pos = pos + len(motif)

        #this motif starts in a later sequence: close out the current one(s)
        while end_pos > seq_last_index[cur_sample_pos]:
            #add end non-motif sequence
            #NOTE(review): the -1 appears to drop the trailing '$' separator — confirm
            if add_single_base and prev_pos < seq_last_index[cur_sample_pos]:
                xseq = seq_concat[prev_pos:(seq_last_index[cur_sample_pos] - 1)]
                current_motifs.append(MotifPlacement(xseq, prev_pos - cur_seq_start, singlebase=True))
                #current_motifs.append((prev_pos - cur_seq_start, seq_last_index[cur_sample_pos] - cur_seq_start - 1, xseq, 0))

            sample_positions[seq_name[cur_sample_pos]] = current_motifs
            prev_pos = cur_seq_start = seq_last_index[cur_sample_pos]
            cur_sample_pos += 1
            current_motifs = []
            assert cur_sample_pos < len(seq_last_index), f"Position {end_pos} not found in {seq_name}"


        #add intermediate non-motif sequence
        if add_single_base and prev_pos < pos:
            xseq = seq_concat[prev_pos:pos]
            current_motifs.append(MotifPlacement(xseq, prev_pos - cur_seq_start, singlebase=True))
        
        assert pos - cur_seq_start >= 0, f"Position {pos} not found in {seq_name[cur_sample_pos]}"
        #extend the previous placement's repeat count if it is the same motif, else start a new one
        if not (current_motifs and  current_motifs[-1].attempt_add(motif)):
            current_motifs.append(MotifPlacement(motif, pos - cur_seq_start))
            unique_motifs.add(motif)

        prev_pos = end_pos

    #add end non-motif sequence
    if add_single_base and prev_pos < seq_last_index[cur_sample_pos]:
        xseq = seq_concat[prev_pos:(seq_last_index[cur_sample_pos] - 1)]
        current_motifs.append(MotifPlacement(xseq, prev_pos - cur_seq_start, singlebase=True))

    sample_positions[seq_name[cur_sample_pos]] = current_motifs
    cur_sample_pos += 1

    #deal with sequences without motifs
    while cur_sample_pos < len(seq_last_index):
        if add_single_base:
            xseq = seq_concat[seq_last_index[cur_sample_pos - 1]:(seq_last_index[cur_sample_pos] - 1)]
            sample_positions[seq_name[cur_sample_pos]] = [MotifPlacement(xseq, 0, singlebase=True)]
        cur_sample_pos += 1
    return sample_positions, list(unique_motifs)


def get_label(value, range_starts, labels):
    """Return (label, range_start) of the range containing ``value``.

    ``range_starts`` must be sorted ascending; ``labels`` is parallel to it.
    """
    slot = bisect.bisect_right(range_starts, value) - 1
    return labels[slot], range_starts[slot]

def get_all_positions(index, marked_positions, seq_dict, seq_concat):
    """Expand motif placements to cover every base of the concatenated sequence.

    Bases not covered by any marked motif are emitted as single-base entries;
    '$' separator characters are dropped from the output.

    :param index: end index of each sequence within seq_concat (as from prepare_suffix_string)
    :param marked_positions: list of (position, motif) tuples in concatenated coordinates
    :param seq_dict: original sequences; key order is assumed to match `index`
    :param seq_concat: the concatenated sequence string (with '$' separators)
    :return: list of (start, end, motif, sequence_name) in per-sequence coordinates
    """
    #positions covered by a marked motif vs. all positions -> leftover single bases
    all_marked_positions = set(range(len(seq_concat)))
    marked = []
    for t in marked_positions:
        marked += list(range(t[0], t[0] + len(t[1])))
    marked = set(marked)
    single_bp = all_marked_positions - marked

    all_motifs = []
    for t in single_bp:
        all_motifs += [(t, t+1, seq_concat[t])]
    
    for t in marked_positions:
        all_motifs += [(t[0], t[0]+len(t[1]), t[1])]
    all_motifs = sorted(all_motifs, key=lambda x: x[0])

    #sequence start offsets: prepend 0 and drop the final end index
    starts = np.append(index, [0])
    starts.sort()
    starts = starts[:-1]
    ids = list(seq_dict.keys())
    #NOTE(review): starts is sorted in place just above, so this re-pairing branch
    #appears unreachable; if it ever fired, zipping sorted starts with insertion-order
    #ids could mismatch names and offsets — confirm intent
    if not all(starts[i] <= starts[i + 1] for i in range(len(starts) - 1)):
        combined = sorted(zip(starts, ids))
        starts, ids = zip(*combined)

    # Create the labeled values list using a list comprehension
    # ('$' separators are excluded; coordinates are shifted to each sequence's start)
    all_motif_seq = [
    (m[0] - start, m[1] - start, m[2], label) 
    for m in all_motifs 
    if m[2] != "$"
    for label, start in [get_label(m[0], starts, ids)]
    ]

    return all_motif_seq

def write_compressed_seq(grouped_positions, output):
    """Write a run-length compressed FASTA-like representation of each sequence.

    Each motif placement is written as "<motif><count>"; single-base stretches
    are expanded to one "<base>1" token per base. Output goes to
    ``<output>_compressed_representation.fa``.

    :param grouped_positions: {sequence name: list of MotifPlacement-like objects
                              with .motif, .count and .singlebase attributes}
    :param output: output file prefix
    """
    # FIX: the file handle previously shadowed the filename variable; dead
    # commented-out format variants removed.
    path = output + "_compressed_representation" + ".fa"
    with open(path, 'w') as handle:
        for seq in grouped_positions:
            handle.write(">" + seq + "\n")
            tokens = []
            for ma in grouped_positions[seq]:
                if ma.singlebase:
                    tokens.append(' '.join([f"{m}1" for m in ma.motif]))
                else:
                    tokens.append(f"{ma.motif}{ma.count}")
            handle.write(" ".join(tokens) + "\n")

def write_summary(all_motif_seq, output):
    """Write per-sequence and global motif statistics to ``<output>_motif_summary.txt``.

    Each row is tab-separated: sequence (or 'all'), motif, total bases covered,
    occurrence count. Rows are sorted by covered length, descending, within each
    sequence and within the global section.
    """
    motif_counts = defaultdict(lambda: defaultdict(int))
    global_length_count = defaultdict(int)
    global_total_count = defaultdict(int)
    seq_order = []
    all_motifs = set()

    # tally occurrences per sequence and globally, preserving first-seen sequence order
    for _start, _end, motif, seq in all_motif_seq:
        motif_counts[seq][motif] += 1
        global_length_count[motif] += len(motif)
        global_total_count[motif] += 1
        all_motifs.add(motif)
        if seq not in seq_order:
            seq_order.append(seq)

    rows = []
    for seq in seq_order:
        per_seq = [
            (seq, motif, len(motif) * count, count)
            for motif, count in motif_counts[seq].items()
        ]
        per_seq.sort(key=lambda row: row[2], reverse=True)
        rows.extend(per_seq)

    totals = [
        ('all', motif, global_length_count[motif], global_total_count[motif])
        for motif in all_motifs
    ]
    totals.sort(key=lambda row: row[2], reverse=True)

    with open(output + "_motif_summary.txt", 'w') as f:
        for row in rows + totals:
            f.write('\t'.join(map(str, row)) + '\n')


def get_motif_pairwise_distance(motif_list):
    """Compute a symmetric motif-by-motif DataFrame of pairwise alignment scores."""
    score_dict = {}
    # only compute the upper triangle; mirror each score into the lower one
    for i, motif_1 in enumerate(motif_list):
        for motif_2 in motif_list[i:]:
            score = pairwise_alignment(motif_1, motif_2)
            score_dict[(motif_1, motif_2)] = score
            score_dict[(motif_2, motif_1)] = score

    score_df = pd.DataFrame(score_dict.values(), index=pd.MultiIndex.from_tuples(score_dict.keys())).unstack().astype(float)
    return score_df

def sort_motifs(all_seq_motifs):
    """Rank multi-base motifs by occurrence count.

    Single-base entries are ignored. The most frequent motif receives the
    highest rank; the returned DataFrame has columns 'motif' and
    'dimension_reduction' (the rank) and is sorted by rank, ascending.
    """
    counts = defaultdict(int)
    for entry in all_seq_motifs:
        if len(entry[2]) > 1:
            counts[entry[2]] += 1

    # most frequent first; stable sort keeps first-seen order for ties
    ordered = sorted(dict(counts).items(), key=lambda kv: kv[1], reverse=True)

    # highest count gets the largest rank value
    top = len(ordered)
    ranks = {motif: top - position for position, (motif, _count) in enumerate(ordered)}

    frame = pd.DataFrame(list(ranks.items()), columns=['motif', 'dimension_reduction'])
    return frame.sort_values(by='dimension_reduction', ascending=True).reset_index(drop=True)


def run_umap(dm, method='UMAP', rank=0.5, norm=True):
    """Maps motif distance matrix 'dm' to a 1-dimensional space using UMAP or MDS.
       :dm: pandas DataFrame of pairwise motif scores (index = motifs).
       :method: string, either 'UMAP' or 'MDS'.
       :rank: float, between 0 and 1. Transforms between original embedding (0.0) and a fully ranked embedding (1.0, i.e. only the ordering of the motifs is preserved).
       :norm: bool, whether to normalize the distance matrix by the geometric mean of the sequence lengths.
       :return: DataFrame with columns 'dimension_reduction' (float) and 'motif'.
    """
    #degenerate cases: 1 or 2 motifs need no embedding
    #NOTE(review): dm[0].columns indexes column 0 of dm — confirm dm has that
    #column here; dm.index may have been intended
    if dm.shape[0] == 1:
        X_transform_L2 = pd.DataFrame(columns = ["dimension_reduction", "motif"])
        X_transform_L2["motif"] = dm[0].columns
        X_transform_L2["dimension_reduction"] = 0
    elif dm.shape[0] == 2:
        X_transform_L2 = pd.DataFrame(columns = ["dimension_reduction", "motif"])
        X_transform_L2["motif"] = dm[0].columns
        X_transform_L2["dimension_reduction"] = [0,1]

    else:
        n_neighbors = min(10,max(2,dm.shape[0]/2)) # 10 or half of the number of motifs, whichever is smaller, with a minimum of 2
        n_neighbors = int(min(n_neighbors,dm.shape[0])) # cannot be larger than the number of motifs

        #scores are <= 0 (higher = more similar); negate to get distances
        data = (-dm).to_numpy(copy=True,dtype=np.float32)

        if norm:
            #divide by the geometric mean of the two motif lengths
            dnorm = [len(e) for e in dm.index]
            snorm = np.sqrt(dnorm)
            data = data / (snorm[:,np.newaxis] * snorm[np.newaxis,:])

    
        if method == 'UMAP':
            #UMAP has a lot of overhead for small N.
            #This is much faster than builtin method for finding nearest neighbors for small N
            if n_neighbors * 5 < dm.shape[0]: #top-k sort for very large motif sets
                idx = np.argpartition(data,np.arange(n_neighbors),axis=1)[:,:n_neighbors]
                dist = np.take_along_axis(data,idx,axis=1)
            else:
                idx = np.argsort(data,axis=1)[:,:n_neighbors]
                dist = np.take_along_axis(data,idx,axis=1)
            
            #NN-descent only needed for transform of new data (https://github.com/lmcinnes/umap/issues/848)
            import pynndescent
            class DummyNNDescent(pynndescent.NNDescent):
                def __init__(self):
                    return
            precomputed_knn = (idx,dist, DummyNNDescent())
        
            manifold_f = umap.UMAP(n_components = 1, metric = "precomputed", n_neighbors = n_neighbors, min_dist = 0.5, random_state = 0, precomputed_knn=precomputed_knn, force_approximation_algorithm=True)


        elif method=="MDS":
            from sklearn.manifold import MDS
            manifold_f = MDS(n_components=1, n_init=50, metric=False, dissimilarity='precomputed')
        else:
            raise RuntimeError(f'Unknown manifold method {method}. Choose UMAP or MDS.')
    
        result = manifold_f.fit_transform(data) 
        
        if rank:
            #interpolate between the raw embedding (rank=0) and evenly spaced
            #ranks (rank=1): adjacent gaps are blended with the uniform spacing
            result = result.ravel()
            idx = np.argsort(result)
            df =np.diff(result[idx])
            r = np.max(result) - np.min(result)
            w = np.cumsum(df * (1 - rank) + rank * (r / (dm.shape[0] - 1)))
            result[idx[1:]] = w
            result[idx[0]] = 0.0

        X_transform_L2 = pd.DataFrame(result)
        X_transform_L2.columns = ["dimension_reduction"]
        X_transform_L2["motif"] = dm.index
        X_transform_L2 = X_transform_L2.astype({"dimension_reduction": float})
    
    
    return X_transform_L2

def map_grouped_score_to_alignment(grouped_motif_seq, X_transform_L2):
    """Assign every motif placement a color-scale score from the 1-D embedding.

    Parameters
    ----------
    grouped_motif_seq : dict
        Maps sequence name -> list of motif placement objects; each
        placement must expose ``set_score(dict)``.
    X_transform_L2 : pd.DataFrame
        Must contain columns ``dimension_reduction`` and ``motif``.

    A motif's score is its 0-based rank after sorting by the embedding
    coordinate (ascending); the literal key ``'nan'`` maps to NaN so
    unassigned positions stay unscored.
    """
    # Rank motifs by their embedding coordinate.
    X_transform_L2 = X_transform_L2.sort_values(by='dimension_reduction', ascending=True).reset_index(drop=True)

    motif_dict = {str(motif): float(pos) for pos, motif in enumerate(X_transform_L2['motif'])}
    motif_dict['nan'] = np.nan  # sentinel for missing/unassigned motifs

    # Only the placement lists are needed; sequence keys are irrelevant here,
    # so iterate values() instead of items() (the key was previously unused).
    for motifs in grouped_motif_seq.values():
        for motif in motifs:
            motif.set_score(motif_dict)


#######################


parser = argparse.ArgumentParser(description='MotifScope')


def _cli_bool(value):
    """argparse ``type`` converter: map the literal strings 'True'/'False'
    to a real bool, rejecting anything else at parse time."""
    if value == 'True':
        return True
    if value == 'False':
        return False
    raise argparse.ArgumentTypeError(f"expected 'True' or 'False', got {value!r}")


parser.add_argument('-i', '--input', default = None, dest='input_fasta_to_count', required = True, 
                    metavar = "input.fa", type = str,
                    help ='Input fasta file to analyze')

parser.add_argument('-reverse', dest='reverse', type = str, choices=['True', 'False'], default = "False",
                    help ='Use reverse complement of input fasta')

parser.add_argument('-mink', '--min_kmer', default = 2, dest = 'min_kmer_size',
                    type = int, metavar = 2,
                    help ='Minimum length of motif to characterize')

parser.add_argument('-maxk', '--max_kmer', default = 10, dest='max_kmer_size',
                    type=int, metavar = 10,
                    help ='Maximum length of motif to characterize')

parser.add_argument('-t', '--title', default = None, dest='title',
                    metavar = "title", type=str,
                    help ='Title of the plot, default: None')

parser.add_argument('-msa', '--msa', dest = 'run_msa', type = str, choices=['True', 'False', 'POAMotif', 'POANucleotide', 'Levenshtein'], default = "Levenshtein",
                    help = 'MSA setting. Either POAMotif, POANucleotide or Levenshtein. For backwards compatibility, True/False can be used, which maps to POANucleotide/Levenshtein. Default: Levenshtein.')

parser.add_argument('-motif', '--motif', dest = 'ref_motifs', type = str, required = False,
                    help = 'Use guided motif mode, using established motifs for analysis by taking a file with known motifs separated by tabs.', metavar = 'motifs.txt')

parser.add_argument('-p', '--sample_class', dest = 'sample_class', required = False, 
                    metavar = "metadata.txt", type=str,
                    help = 'Path to a tab separated file with a column for sequence names and a column for class labels. If provided, an extra column will be added to the output file with the class labels.')

parser.add_argument('-c', '--cluster', dest = 'cluster', type = str, choices=['True', 'False'], default = "True",
                    help = 'Cluster sequences. Boolean (True/False). Default: True')

# BUG FIX: this flag used to be stored as the *string* 'True'/'False'
# (default bool False), but both downstream use sites test `if args.profile:`,
# so `-prof False` was truthy and silently enabled profiling. Converting to
# a real bool at parse time makes every truthiness check correct; invalid
# values are still rejected by the converter.
parser.add_argument('-prof', '--profile', type = _cli_bool, dest = 'profile', default = False,
                     help = 'Enable profiling (stored in stats.txt)')

parser.add_argument('-e', '--embed_motif_method', default = 'UMAP', dest= 'embed_motif_method', choices=['UMAP', 'MDS', 'random'],
                     help = 'Embedding method for motif color scale (option: MDS, UMAP or random), default: UMAP')

parser.add_argument('-r', '--motif_rank_embed', default = 0.5, dest = 'motif_rank_embed', type = float,
                     help = 'Hold to original embedding (value=0.0) or only preserve order and place motifs equidistant on color map (value=1.0). Default: 0.5')

parser.add_argument('-format', '--format', default = 'pdf', dest = 'format', choices=['png', 'pdf'],
                    help = 'Image output format (png, pdf). Default: pdf')

parser.add_argument('-figure', '--figure', default = 'True', type = str, dest = 'figure', choices=['True', 'False'],
                    help = 'Outputs figure. Boolean (True/False). Default: True')

parser.add_argument('-o', '--output_prefix', type = str, dest = 'output_prefix', required = True, metavar = "ouput",
                    help = 'Prefix of output files.')

# Reject unrecognized options explicitly instead of ignoring them.
args, unknown = parser.parse_known_args()
if unknown:
    parser.error(f"Unknown arguments: {' '.join(unknown)}")


cfg.args = args

# Map the legacy boolean MSA flags onto their named strategies; any other
# value is already one of the explicit aligner names.
_MSA_COMPAT = {"True": "POANucleotide", "False": "Levenshtein"}
args.run_msa = _MSA_COMPAT.get(args.run_msa, args.run_msa)

# Pull frequently used options into module-level names.
title = args.title
input_fasta_to_count = args.input_fasta_to_count
max_kmer_size = args.max_kmer_size
min_kmer_size = args.min_kmer_size
run_msa = args.run_msa
motif_guided = args.ref_motifs is not None
ref_motifs = args.ref_motifs
output_prefix = args.output_prefix
figure = args.figure
reverse = args.reverse

if args.profile:
    # Lazy imports: the profiling machinery is only needed when requested.
    import cProfile
    import pstats
    from pstats import SortKey
    print('Profiling enabled.')
    pr = cProfile.Profile()
    pr.enable()

# Resolve the output prefix into a directory + basename; relative prefixes
# are anchored at the current working directory.
if os.path.isabs(output_prefix):
    prefix_dir, prefix = os.path.split(output_prefix)
else:
    prefix_dir, prefix = os.getcwd(), output_prefix

if not os.path.exists(prefix_dir):
    print(f"Error: The directory '{prefix_dir}' does not exist.")
    sys.exit(1)

output = os.path.join(prefix_dir, prefix)
output_file = output + "_result.txt"

cfg.output = output

all_seq_dict = parse_fasta(input_fasta_to_count)
# Clustering (and the dendrogram in the figure) needs at least two sequences.
if len(all_seq_dict) == 1 and args.cluster == "True" and figure == "True":
    print("Error: Cannot perform clustering for a single sequence.")
    sys.exit(1)

if reverse == "True":
    all_seq_dict = get_reverse_complements(all_seq_dict)

seq_concat, seq_index, seq_keys = prepare_suffix_string(all_seq_dict)

# Enumerate candidate motifs: either restricted to a user-supplied motif
# list (guided mode) or discovered de novo over the k-mer size range.
if motif_guided:
    with open(ref_motifs, 'r') as ref:
        ref_motifs = ref.readlines()
    # One motif per tab-separated field, flattened across all lines.
    ref_motifs_list = [m for line in ref_motifs for m in line.strip().split("\t")]
    candidate_kmer, masked_postion, masked_seq = select_all_kmer_motif_guided(seq_concat, seq_index, min_kmer_size, max_kmer_size, ref_motifs_list)
else:
    candidate_kmer, masked_postion, masked_seq = select_all_kmer(seq_concat, seq_index, min_kmer_size, max_kmer_size)

masked_postion = mask_all_seq(candidate_kmer, masked_postion, masked_seq)

grouped_positions, unique_motifs = get_grouped_positions(seq_index, seq_keys, masked_postion, seq_concat)
# Single bases always participate alongside the discovered motifs.
unique_motifs_and_sb = ['A', 'C', 'G', 'T'] + unique_motifs

all_positions = get_all_positions(seq_index, masked_postion, all_seq_dict, seq_concat)
write_compressed_seq(grouped_positions, output)
write_summary(all_positions, output)


# Select the alignment backend requested on the command line; run_msa is one
# of POANucleotide, POAMotif or Levenshtein after the compat mapping above.
if run_msa == "POAMotif":
    malign = POAAlignerMotif(cfg)
elif run_msa == "POANucleotide":
    malign = POAAlignerNucleotide(cfg)
else:
    malign = LevenshteinAligner(cfg, all_seq_dict)

grouped_positions, sequence_lengths, all_seq_distance_df = malign.run(grouped_positions, unique_motifs_and_sb)

print("tandem repeat decomposition DONE")
print(all_seq_distance_df)
if figure == "True":
    alignment_score_matrix = get_motif_pairwise_distance(unique_motifs)

    # Project motifs onto a single axis to drive the color scale.
    if args.embed_motif_method == "random":
        dimension_reduction_result = sort_motifs(all_positions)
    elif args.embed_motif_method in ["UMAP", "MDS"]:
        dimension_reduction_result = run_umap(alignment_score_matrix, method=args.embed_motif_method, rank=args.motif_rank_embed)

    dimension_reduction_result = dimension_reduction_result.sort_values(by="dimension_reduction")

    has_classes = args.sample_class is not None
    if has_classes:
        classes_file = args.sample_class
        # Peek at the first line to decide whether the file carries a header row.
        with open(classes_file, 'r') as f:
            header = 0 if f.readline().startswith("sample\t") else None

        class_metadata = pd.read_csv(classes_file, sep="\t", header=header)
        if header is None:
            class_metadata.columns = ["sample", "population"]

    map_grouped_score_to_alignment(grouped_positions, dimension_reduction_result)

    figfile = output + f".{args.format}"
    print("generating figure")
    figure = MotifPlot(cfg, title, figfile)

    figure.load_data(grouped_positions, sequence_lengths, dimension_reduction_result)
    if args.cluster == "True":
        figure.enable_dendrogram(all_seq_distance_df)
    if has_classes:
        figure.enable_classes(class_metadata)
    figure.create_figure()


if args.profile:
    pr.disable()
    print('Storing profiling results in stats.txt...')
    # Cumulative-time ordering surfaces the dominant call chains first.
    with open('stats.txt','w') as f:
        stats = pstats.Stats(pr, stream=f).sort_stats(SortKey.CUMULATIVE)
        stats.print_stats()
        stats.print_callers()
        stats.print_callees()

print('DONE')


