# python3.7, Need scipy>=1.6.0
'''
Log:

Modification on Oct 7:
    change to automatically determine the size of motif seq
'''
'''
Used to cluster the significant motif seqs identified by identifyMotif.py

Logic:
Calculate the distance (a 1bp mismatch equals distance=1) between each pair of motif seqs;
based on the distance matrix, cluster the motif seqs;
determine the number of clusters (p_number) used to cut the dendrogram tree;
report the motif seqs in each cluster and count the base composition at each position;
the returned motif seqs can be used for motif logo calling using weblogo.

Usage:
$python clusterMotif.py --input Ecoli_SignificantMotif.txt --name Ecoli --number 5 > report.txt
The output will be saved in report.txt

--input:
File generated by identifyMotif.py containing the significant motif seqs, e.g.
Name    Sample  Reference       pvalue  corrected_alpha Significance
GCCCAGGT        157     285     0.0     6.105378838756945e-09   True
CACCAGGC        282     429     0.0     6.105378838756945e-09   True
...

--name: Name prefix for the heatmap and dendrogram images and the cluster seq files, not required
name_Motif_cluster_heatmap.pdf, this heatmap helps to determine p_number
name_Motif_cluster_dengram.png, dendrogram of the cut tree, in which the number of clusters is determined by p_number
name_motifcluster1, name_motifcluster2 ... These files contain the motif seqs that can be used to call motif logos.

--number: p for hierarchy.dendrogram, determining the number of clusters generated by cut tree

Returning:
The number of motif seqs in the cluster1 is: 64
   Pos1  Pos2  Pos3  Pos4  Pos5  Pos6  Pos7  Pos8
A    16    16     0     0    64     0     0    16
T    16    16     0     0     0     0     0    16
C    16    16    64    64     0     0     0    16
G    16    16     0     0     0    64    64    16
The number at each position is the number of seqs in the cluster that have this base at this position.
So the above motif is NNCCAGGN.
The returned motif seqs can be used to call motif logos using weblogo.

'''

# Standard library modules are imported outside the guard: they cannot be
# missing, and sys is needed to report a failed third-party import.
import argparse
import re
import sys

try:
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import scipy.cluster.hierarchy as hierarchy # dendrogram, leaves_list, Need scipy>=1.6.0
except ImportError as err:
    # Name the missing module on stderr and exit with a non-zero status, so that
    # a redirected report (stdout) is not polluted and the failure is visible.
    sys.exit('Module Error: {}'.format(err))

sys.setrecursionlimit(1500) # default setting may raise error: maximum recursion depth exceeded while getting the str of an object

class motif:
    '''
    Cluster the significant motif seqs by their pairwise mismatch distance.

    motif_file: path of the file generated by identifyMotif.py; the first line is
                a header, the first column of every other line is a motif seq
    prefix: optional name prefix for the figures and the cluster seq files
    '''

    def __init__(self, motif_file, prefix):
        self.motif_file = motif_file
        self.prefix = prefix # name prefix for heatmap and dendrogram figures

    def _OutputName(self, name):
        '''Return name prepended with "<prefix>_" when a prefix was given, else name unchanged'''
        if not self.prefix:
            return name
        return '{}_{}'.format(self.prefix, name)

    def _ReadMotifs(self):
        '''Return the motif seqs (first column) of the motif file, skipping the header and blank lines'''
        with open(self.motif_file) as f:
            f.readline() # skip header
            return [line.split()[0] for line in f if line.strip()]

    def _DistanceMatrix(self, ls_motif):
        '''
        Return a square dataframe (index and columns are the motif seqs) holding
        the number of mismatches between each pair of motif seqs.
        Raise ValueError when the motif seqs do not all have the same length.
        '''
        sizes = {len(item) for item in ls_motif}
        if len(sizes) > 1:
            raise ValueError('Motif seqs must all have the same length, found lengths: {}'.format(sorted(sizes)))
        # build the whole matrix first: one DataFrame construction instead of n*n df.loc assignments
        matrix = [[self.CountDistance(x1, x2) for x2 in ls_motif] for x1 in ls_motif]
        return pd.DataFrame(matrix, index=ls_motif, columns=ls_motif)

    def CountDistance(self, x1, x2):
        '''
        Return the number of different bases between two motif seq
        Every position of x1 is compared, so x2 must be at least as long as x1
        (IndexError otherwise).
        '''
        return sum(1 for i, base in enumerate(x1) if base != x2[i])

    def CountBase(self, ls_motif, size):
        '''
        Count the number of A,T,C,G at each position for motif in list
        size: 8bp motif size=8
        return a dataframe with rows A, T, C, G and columns Pos1 ... Pos<size>
        '''
        bases = ['A', 'T', 'C', 'G']
        columns = ['Pos{}'.format(i) for i in range(1, size+1)]
        df = pd.DataFrame(0, index=bases, columns=columns)

        for pos, column in enumerate(columns):
            ls_base = [item[pos] for item in ls_motif] # bases found at this position, built once per position
            for base in bases:
                df.loc[base, column] = ls_base.count(base)
        return df

    def CreateMotifDf(self):
        '''
        generate a dataframe and calculate the distance between each pair of seq
        generate heatmap of distance to help to decide the p_number
        return a dataframe having distance that is used for hierarchy.linkage
        Raise ValueError when the file has fewer than two motif seqs or the motif
        seqs differ in length.
        '''
        ls_motif = self._ReadMotifs()
        print("There are {} motifs for cluster.".format(len(ls_motif)))
        if len(ls_motif) < 2:
            raise ValueError('At least two motif seqs are needed for clustering, found {} in {}'.format(len(ls_motif), self.motif_file))

        df = self._DistanceMatrix(ls_motif)

        # plot clustermap to help decide the number of clusters
        # clustermap creates its own figure, so no plt.figure() is needed here
        cg = sns.clustermap(df, method='single', metric='euclidean',
                            row_cluster=True, col_cluster=True,
                            cbar_pos=(0.1, .2, .03, .4)) # (left, bottom, width, height)

        # cluster but not show the dendrogram
        cg.ax_row_dendrogram.set_visible(False)
        cg.ax_col_dendrogram.set_visible(False)

        # save the clustermap figure explicitly instead of relying on the current pyplot figure
        cg.savefig(self._OutputName('Motif_cluster_heatmap.pdf'), bbox_inches='tight')
        return df

    def ClusterMotif(self, df, p_number):
        '''
        generate dengram for cluster
        df is the dataframe having distance returned by CreateMotifDf().
        p_number: p for hierarchy.dendrogram (truncate_mode='lastp')

        Save the dendrogram figure, write the motif seqs of each cluster to
        [<prefix>_]motifcluster<i> and print the base count table of each cluster.
        '''
        # the number of clusters is determined by p
        # Choose p for which the number of motifs in the largest cluster is around or just below 256. N^4=256.
        # NOTE(review): df is a square distance matrix, but it is passed as a 2-D array, which
        # linkage presumably treats as observation vectors (euclidean distance between the rows)
        # rather than as precomputed distances - confirm this is intended.
        Z = hierarchy.linkage(df, method='single', metric='euclidean')
        plt.figure()
        dendrogram = hierarchy.dendrogram(
            Z,
            truncate_mode='lastp', p=p_number)  # show only the last p merged clusters
        plt.savefig(self._OutputName('Motif_cluster_dengram.png'), bbox_inches='tight')

        # extract the motifs in each cluster
        # dendrogram['ivl'], e.g. ['(66)', '(64)', '(64)', '283', '(108)'], here 283 means one motif in this cluster, the (66) means 66 seq in cluster
        print('The dengram results: {}'.format(dendrogram['ivl']))
        pattern = re.compile(r'\((\d+)\)') # raw string: '\(' is an invalid escape in a normal string literal
        ls_NumberInCluster = []
        for item in dendrogram['ivl']:
            found = pattern.search(item)
            ls_NumberInCluster.append(int(found.group(1)) if found else 1)
        # ls_NumberInCluster, e.g. [4, 12, 6, 2], 4 items in cluster 1, 12 items in cluster 2...
        # the items are ordered as: leaves_list(Z)

        # convert to index list for np.split, [4, 12, 6, 2] returns [4, 16, 22]
        ls_IndexForCluster = np.cumsum(ls_NumberInCluster)[:-1].tolist()

        # the motif seqs are the column labels of df, so the motif size is taken
        # from them instead of reading the motif file a second time
        ls_seq = df.columns.values
        size = len(ls_seq[0]) # size of motif seq, 8bp motif size=8

        # extract the motif seq in each cluster, the number of seq in each cluster corresponds to the Motif_cluster_dengram.png
        # The returning seq in ls can be used for weblogo plot: https://weblogo.berkeley.edu/logo.cgi
        print('====The following is the cluster result===')
        clusters = np.split(hierarchy.leaves_list(Z), ls_IndexForCluster, axis=0)
        for i, cluster in enumerate(clusters, start=1):
            ls = [ls_seq[index] for index in cluster]
            print('The number of motif seqs in the cluster{} is: {}'.format(i, len(ls)))
            # without a prefix the file is named motifcluster<i> (was None_motifcluster<i>)
            with open(self._OutputName('motifcluster{}'.format(i)), 'w') as output:
                print('\n'.join(ls), end='\n', file=output)
            print(self.CountBase(ls, size))
            print('\n')


##-------Parser
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # --input is mandatory: without it the motif file cannot be opened at all
    parser.add_argument('--input', help='Significant Motif file generated by identifyMotif.py', dest='input_file', required=True)
    parser.add_argument('--number', help='cluster number for dendrogram', dest='p_number', type=int)
    parser.add_argument('--name', help='name prefix for the heatmap and dengram figures and cluster seq files', dest='prefix', required=False)

    args = parser.parse_args()


    temp = motif(args.input_file, args.prefix)
    motif_df = temp.CreateMotifDf()
    if args.p_number is None:
        # The heatmap has been written by CreateMotifDf() above and is meant to help
        # choosing --number; stop with a clear message (on stderr, non-zero status)
        # instead of failing inside the dendrogram call.
        sys.exit('--number was not given: only the heatmap was generated, rerun with --number to cluster the motifs.')
    temp.ClusterMotif(motif_df, args.p_number)