"""
minjiezhang123@gmail.com    2020-03-16      python 3
This script can be used to plot the ROC curve for 28S rRNA.

First, we need to get the base-pair info from the PDB database (e.g. 4V6X: structure of the human 80S ribosome);
    I have already built the arc file that stores the base-pairing information.
Second, the target genome is split into small bins (5 nt or 10 nt per bin);
Third, the real base-pair info for every pair of bins is stored as a matrix;

"""


#1. input and output setup
################################################################################
################################################################################
#this section sets up the input and output
import sys, argparse, numpy, os, re, itertools, random, math
from datetime import datetime
from sklearn.metrics import roc_curve, auc
import matplotlib as mpl  
import matplotlib.pyplot as plt
from math import floor, ceil
from itertools import product
from Bio.PDB.PDBParser import PDBParser
import seaborn as sns
from collections import Counter

# Six positional arguments are read below (sys.argv[1] .. sys.argv[6]), so the
# usage text is shown whenever fewer than six are supplied; this avoids an
# IndexError further down when only some of them are given.
if len(sys.argv) < 7:
    print("Usage:           python plotROC.py  sam_crssant  basepair_arc_file  cutoff  winbin  DGorReads  outputprefix")
    print("sam_crssant:     sam file generated by crssant.py        ")
    print("basepair_arc_file:    basepair infomation (arc file)     ")
    print("cutoff:          Reads number cutoff of pairwise bins used for analysis, only the pairwise bins containing more the the cutoff value used for analysis   ")
    print("winbin:          window bins                             ")
    print("DGorReads:       Using DG or Reads for plot the figure   ")
    print("outputprefix                                             ")
    sys.exit()

# Both input files are opened here and closed by the sections that consume them.
inputsam = open(sys.argv[1],'r')
arcfile = open(sys.argv[2],'r')
reads_cutoff = int(sys.argv[3])     # minimum read count per bin pair
winbin = int(sys.argv[4])           # window (bin) size in nucleotides
DGorReads =  sys.argv[5]            # "DG" or "Reads"
outputprefix = sys.argv[6]
################################################################################



#2. subfunctions
################################################################################
def plot_roc(labels, predict_prob, outputprefix):
    """Plot a ROC curve, save it as ``<outputprefix>.pdf`` and show it.

    labels:        true class of every sample, passed to sklearn's roc_curve.
    predict_prob:  score of every sample, in the same order as ``labels``.
    outputprefix:  path prefix of the PDF that is written.
    """
    ax = plt.subplot(1, 1, 1)
    fpr, tpr, thresholds = roc_curve(labels, predict_prob)
    # area under the curve, computed from the curve points themselves
    area = auc(fpr, tpr)
    legend_text = 'AUC = %0.4f' % area

    plt.title('ROC')
    plt.plot(fpr, tpr, 'b', label=legend_text)
    plt.legend(loc='lower right')
    # dashed diagonal from (0,0) to (1,1) for visual reference
    plt.plot([0, 1], [0, 1], 'r--')
    plt.ylabel('True Positive Rate(%)')
    plt.xlabel('False Positive Rate(%)')

    pdf_path = outputprefix + '.pdf'
    plt.savefig(pdf_path)
    plt.show()
# labels=[0,0,0,0,1,1,1]
# predict_prob=[0.3,0.2,0.7,0.5,0.4,0.9,0.6]
# plot_roc(labels, predict_prob, 'example_ROC')
    

def timenow():
    """Return the current local time as 'YYYY-MM-DD HH:MM:SS'.

    The timestamp is formatted explicitly instead of slicing 7 characters off
    str(datetime.now()): that string has no fractional part when the
    microsecond field happens to be 0, and the slice then cut into the time.
    """
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')

def mergeCIGAR(CIGAR):
    """Simplify a CIGAR string by fusing adjacent match runs.

    Only the operations M, N, I, S, H, = and X are recognised (D and P are not
    matched by the pattern and are therefore dropped). After the first
    operation, which is always kept as is, every I, = and X operation is
    removed and an M that directly follows another M is added to it.

    example: 1S2M3N4M5I6M7S -> 1S2M3N10M7S

    Returns the simplified CIGAR string, or "" when CIGAR contains none of
    the recognised operations (e.g. "*" for an unmapped record).
    """
    # raw string: '\d' in a normal string literal is an invalid escape sequence
    ops = re.findall(r'\d+[MNISH=X]', CIGAR)
    if not ops:
        return ""
    merged = [ops[0]]
    for op in ops[1:]:
        kind = op[-1]
        if kind in "I=X":
            continue  # dropped so that the M runs on both sides become adjacent
        if kind == "M" and merged[-1][-1] == "M":
            merged[-1] = str(int(merged[-1][:-1]) + int(op[:-1])) + "M"
        else:
            merged.append(op)
    return "".join(merged)


def getsegs(line, segs): #turn an alignment into segs [(RNAME, L, R, STRAND), ...]
    """Add the matched segments of one SAM record to ``segs`` and return them.

    line:  one SAM alignment line (split on whitespace).
    segs:  dict {read name: sorted list of (RNAME, left, right, STRAND)};
           it is updated in place, duplicates are not stored twice.

    Returns the list stored for this read when it holds exactly two segments,
    otherwise None. The returned list is the very object kept in ``segs``.
    """
    fields = line.split()
    read_id = fields[0]
    RNAME, POS, CIGAR = fields[2], int(fields[3]), fields[5]
    # FLAG bit 0x10 marks a reverse-strand alignment in the SAM specification
    STRAND = '-' if int(fields[1]) & 16 else '+'
    read_segs = segs.setdefault(read_id, [])

    merged = mergeCIGAR(CIGAR)
    match_lens = [int(op[:-1]) for op in re.findall('[0-9]+M', merged)]
    # reference bases consumed by N/D/=/X operations in front of each M run
    gap_lens = []
    for chunk in merged.split('M'):
        between = chunk.rstrip('0123456789')
        gap_lens.append(sum(int(op[:-1]) for op in re.findall('[0-9]+[ND=X]', between)))

    cursor = POS  # reference position right after the previous M run
    for gap, match_len in zip(gap_lens, match_lens):
        left = cursor + gap
        right = left + match_len - 1
        seg = (RNAME, left, right, STRAND)
        if seg not in read_segs:
            read_segs.append(seg)
        cursor = right + 1
    if match_lens:
        read_segs.sort()

    if len(read_segs) == 2:
        return read_segs
    return None
    
    
def normdist(list1,list2):
    """Return the overlap of the normal distributions fitted to two samples.

    Each sample is summarised by its mean and (population) standard deviation
    as computed by numpy; the result is NormalDist.overlap() of the two
    fitted distributions, a value between 0.0 and 1.0.

    Raises statistics.StatisticsError when either sample has zero spread.
    """
    # Imported here because the file does not import NormalDist at the top;
    # the file imports numpy under its full name, so `np` was undefined too.
    from statistics import NormalDist
    mu1, sigma1 = numpy.mean(list1), numpy.std(list1)
    mu2, sigma2 = numpy.mean(list2), numpy.std(list2)
    overlap = NormalDist(mu1, sigma1).overlap(NormalDist(mu2, sigma2))
    return overlap


def getDGspan(DGspan):
    """Collapse the reads of every DG into one span per arm (in place).

    DGspan maps a DG id to a list of reads, each read being a pair of
    segments (RNAME, start, end, STRAND), for example
    DGspan['12S,12S,1'] = [[('hs12S', 8, 20, '+'), ('hs12S', 108, 137, '+')],
                           [('hs12S', 6, 24, '+'), ('hs12S', 106, 145, '+')]]

    Each value is replaced by
        [[RNAME_left, min start, max end, n_reads],
         [RNAME_right, min start, max end, n_reads]]
    where the names are taken from the first read. Coordinates may be ints or
    numeric strings. A DG with an empty read list has no span and is removed
    from the dict instead of raising IndexError.

    Returns the same (modified) dict.
    """
    for DG in list(DGspan):  # list(): entries may be deleted while looping
        reads = DGspan[DG]
        if not reads:
            del DGspan[DG]
            continue
        nreads = len(reads)
        arms = []
        for side in (0, 1):
            segments = [read[side] for read in reads]
            arms.append([
                segments[0][0],
                min(int(seg[1]) for seg in segments),
                max(int(seg[2]) for seg in segments),
                nreads,
            ])
        DGspan[DG] = arms
    return DGspan
    
    
def getOverlap(a, b):
    """Tell whether interval ``b`` overlaps interval ``a`` well enough.

    a, b: sequences whose items 1 and 2 are start and end (ints or numeric
          strings).

    The overlap is min(end) - max(start). The answer is the string "Ture"
    (spelled like that; callers compare against this exact text) when the
    overlap is larger than 1 and at least half of b's length (end - start),
    otherwise the string "False".
    """
    a_start, a_end = int(a[1]), int(a[2])
    b_start, b_end = int(b[1]), int(b[2])
    shared = min(a_end, b_end) - max(a_start, b_start)
    half_of_b = (b_end - b_start) / 2
    return "Ture" if shared > 1 and shared >= half_of_b else "False"


def coveragenorm(dglist,dgspan,cutoff):
    """Trim both arms of a DG to the part that is well covered by its reads.

    dglist: reads of the DG, each a pair of segments, e.g.
            [[chr1,1,15,+],[chr1,31,45,+]], [[chr1,2,13,+],[chr1,33,43,+]], ...
    dgspan: the two arm spans, e.g. [[chr1,1,25,25],[chr1,30,50,25]]
    cutoff: fraction of the arm's peak coverage a position must reach, e.g. 0.8

    For every arm the per-position coverage is counted from the segments that
    getOverlap() accepts; the trimmed arm runs from the first to the last
    position whose coverage / peak coverage is >= cutoff.

    Returns [[name, str(start), str(end), arm[3]], ...], one entry per arm.
    When no segment was counted for an arm (peak coverage 0) the arm is
    returned untrimmed; when coverage exists but no position reaches the
    cutoff, start and end are both '0', as before.
    """
    dgspannew = []
    for side, arm in enumerate(dgspan):
        arm_start, arm_end = int(arm[1]), int(arm[2])
        poscoverage = dict.fromkeys(range(arm_start, arm_end + 1), 0)

        for seg in dglist:
            piece = seg[side]
            if getOverlap(arm, piece) != "Ture":
                continue
            # clip to the arm so a segment reaching outside cannot raise KeyError
            first = max(int(piece[1]), arm_start)
            last = min(int(piece[2]), arm_end)
            for pos in range(first, last + 1):
                poscoverage[pos] += 1

        # computed once; it used to be recomputed for every position
        peak = max(poscoverage.values(), default=0)
        if peak == 0:
            # nothing to normalise by (this used to be a ZeroDivisionError)
            dgstart, dgend = arm_start, arm_end
        else:
            dgstart = dgend = 0
            # keys were inserted in ascending order, so `passing` is ascending;
            # unlike the old backward scan, the arm's first position is also
            # eligible as the end
            passing = [pos for pos, depth in poscoverage.items()
                       if depth / peak >= cutoff]
            if passing:
                dgstart, dgend = passing[0], passing[-1]
        dgspannew.append([arm[0], str(dgstart), str(dgend), arm[3]])
    return dgspannew


# get a representative arm length per side: the value at the 0.8 position of
# the sorted arm lengths, rounded down to a multiple of 5
def get_armlen(newDGspan):
    """Return (leftlen, rightlen) summarising the arm lengths of all DGs.

    newDGspan: {'1HR2,1HR2,0': [['1HR2', '2', '9', '+'], ['1HR2', '28', '33', '+']]}

    The length of an arm is end - start + 1. For each side the lengths are
    sorted, the element at index int(n * 0.8) is taken and rounded down to a
    multiple of 5. Raises IndexError when newDGspan is empty.
    """
    larmlen = sorted(int(arms[0][2]) - int(arms[0][1]) + 1 for arms in newDGspan.values())
    rarmlen = sorted(int(arms[1][2]) - int(arms[1][1]) + 1 for arms in newDGspan.values())
    leftlen = int(larmlen[int(len(larmlen)*0.8)]/5)*5
    # the right value used to be read from larmlen, duplicating the left result
    rightlen = int(rarmlen[int(len(rarmlen)*0.8)]/5)*5
    return leftlen,rightlen



## test whether a bed interval lies completely inside one of the regions:
def bedinregions(bed, regions):
    """Return 'True' if ``bed`` is fully contained in any region, else 'False'.

    bed = ['hs45S', '3998', '4012', '+']
    regions = [['hs45S', '3654', '5523'], ['hs45S', '6600', '6757'], ['hs45S', '7924', '12994']]

    Names are compared as strings, coordinates as ints (boundaries inclusive).
    The result is the string 'True' or 'False', not a bool; callers compare
    against these strings.
    """
    for targetbed in regions:
        if str(bed[0]) != str(targetbed[0]):
            continue
        if int(bed[1]) >= int(targetbed[1]) and int(bed[2]) <= int(targetbed[2]):
            # first containing region settles it; the old trailing `continue`
            # never ended the scan early
            return 'True'
    return 'False'

 
## get newDGspan with two arms located in target regions:
# newDGspan:    {'1HR2,1HR2,0': [['1HR2', '2', '9', '+'], ['1HR2', '28', '33', '+']]}
def get_targetDG(newDGspan, targetregion):
    """Keep only the DGs whose two arms both lie inside the target regions.

    newDGspan:    {DG id: [left arm, right arm]}
    targetregion: comma separated 'name:start:end' items, e.g.
                  hs45S:3654:5523,hs45S:6600:6757,hs45S:7924:12994

    Returns a new dict holding the entries that pass; the values are the same
    objects as in newDGspan.
    """
    regions = [spec.split(':') for spec in targetregion.split(',')]
    newDGspan_filter = {}
    for dg_id, arms in newDGspan.items():
        if bedinregions(arms[0], regions) != 'True':
            continue
        if bedinregions(arms[1], regions) != 'True':
            continue
        newDGspan_filter[dg_id] = arms
    return newDGspan_filter
        
        
## classfy DGs to DG_leftarm, DG_rightarm, DG_LRarm
# DG_leftarm:  DGs Left arm contain ss nts at target region
# DG_rightarm: DGs right arm contain ss nts at target region
# DG_LRarm:    DGs left && right arm contaiin ss nts 
def select_DG(newDGspan, bplist, regions, regione):
    """Sort DGs by which of their arms has a 0 entry of ``bplist`` near its end.

    newDGspan: {'1HR2,1HR2,0': [['1HR2', '2', '9', '+'], ['1HR2', '28', '33', '+']]}
    bplist:    indexable by position; a value equal to 0 is what is looked for.
    regions, regione: offsets from an arm's end; the positions probed are
               end-regions+1, end-regions, ... down to end-regione+1.

    Only DGs whose two arms are both on 'hs45S' are probed.
    Returns (DG_leftarm, DG_rightarm, DG_LRarm): DGs with a 0 in the left
    window, in the right window, and in both.
    """
    DG_leftarm, DG_rightarm, DG_LRarm = {}, {}, {}
    for dg_id, arms in newDGspan.items():
        if str(arms[0][0]) == 'hs45S' and str(arms[1][0]) == 'hs45S':
            for side, bucket in ((0, DG_leftarm), (1, DG_rightarm)):
                arm_end = int(arms[side][2])
                # every position of the window is read, as in a plain scan
                probed = [bplist[pos]
                          for pos in range(arm_end - regions + 1, arm_end - regione, -1)]
                if any(value == 0 for value in probed):
                    bucket[dg_id] = arms
        if dg_id in DG_leftarm and dg_id in DG_rightarm:
            DG_LRarm[dg_id] = arms
    return DG_leftarm, DG_rightarm, DG_LRarm


# get PDB average distance based on winbin
def get_PDB_aver_distance(matrix_PDB_distance, start, end, winbin):
    """Average a per-nucleotide distance matrix over winbin-sized windows.

    matrix_PDB_distance: 2-D array indexed [pos1 - 1, pos2 - 1] (positions
                         are 1-based).
    start, end:          first and last position of the region (inclusive).
    winbin:              window size in nucleotides.

    Returns a square array with ceil((end - start + 1) / winbin) rows. For two
    different windows the cell holds the mean of all entries > 0 between
    them; when there is no such entry it is -1 if any entry equals -1 and 0
    otherwise. Cells on the diagonal stay 0.

    The last window is clipped at ``end``, so positions beyond the requested
    region are never read.
    """
    start, end = int(start), int(end)
    matrix = numpy.asarray(matrix_PDB_distance)
    nwin = ceil((end - start + 1) / winbin)
    matrix_PDB_aver_distance_win = numpy.zeros((nwin, nwin))
    # 0-based half-open index range of every window
    bounds = [(start - 1 + k * winbin, min(start - 1 + (k + 1) * winbin, end))
              for k in range(nwin)]
    for i, (lo1, hi1) in enumerate(bounds):
        for j, (lo2, hi2) in enumerate(bounds):
            if i == j:
                continue
            block = matrix[lo1:hi1, lo2:hi2]
            positive = block[block > 0]
            if positive.size:
                matrix_PDB_aver_distance_win[i, j] = positive.mean()
            elif (block == -1).any():
                matrix_PDB_aver_distance_win[i, j] = -1
    return matrix_PDB_aver_distance_win
"""
start = 8225
end = 8242
winbin = 5
matrix_aver_distance_win = get_PDB_aver_distance(matrix_PDB_distance, start, end, winbin)
"""
################################################################################


#3 start processing the sam file.
################################################################################
print(timenow()+" Collecting alignments into a dictionary ...")
dgdict = {} # {dgid: [[(RNAME, ls, le, STRAND), (RNAME, rs, re, STRAND)], ...]}
DGspan = {}; DGs=[];  segs={}; newDGspan = {}; newDGspan_filter = {}

# NOTE(review): minreads and DGcommoncutoff are used by the "DG" branch but
# were never defined in this script, so that branch stopped with a NameError.
# The values below are placeholders to be confirmed: 1 keeps every DG that has
# at least one usable read, 0.8 is the example cutoff quoted in coveragenorm.
minreads = 1
DGcommoncutoff = 0.8

## only DGs/reads with both arms inside this region are kept:
#target_region = 'hs45S:3654:5523,hs45S:6600:6757,hs45S:7924:12994'
target_region = 'hs45S:7924:12994'

if DGorReads not in ("DG", "Reads"):
    # any other value used to fall through silently and give an empty result
    inputsam.close()
    sys.exit("Error, DGorReads must be 'DG' or 'Reads'")

seen_dgids = set()
with inputsam:
    for line in inputsam:
        if line.startswith('@') or not line.strip(): continue
        if DGorReads == "DG":
            # the DG tag is expected in the second to last tab separated field
            align = line.split('\t')
            if len(align) < 2 or align[-2][:2] != 'DG':
                sys.exit("Error, DG ID not found in column 19")
            dgid = align[-2].split(':')[-1]
            if dgid not in seen_dgids:
                seen_dgids.add(dgid); DGs.append(dgid)
            # getsegs is called once per line; entries are only created for
            # usable reads, so no DG is left with an empty read list
            pair = getsegs(line, segs)
            if pair:
                dgdict.setdefault(dgid, []).append(pair)
                DGspan.setdefault(dgid, []).append(pair)
        else:
            # "Reads" mode: every read with two segments is its own entry
            pair = getsegs(line, segs)
            if pair: newDGspan[line.split('\t')[0]] = pair

if DGorReads == "DG":
    ## get DGspan of raw data
    DGspan = getDGspan(DGspan)

    ## get the common region of each DG
    for DG in DGspan:
        if len(dgdict[DG]) >= minreads:  # filter DG based on the number of reads
            newDGspan[DG] = coveragenorm(dgdict[DG],DGspan[DG],DGcommoncutoff)

## filter newDGspan that located in target regions:
newDGspan_filter = get_targetDG(newDGspan, target_region)
################################################################################



#4. Get base pairing and noBP information:
################################################################################
print(timenow(), "reading bedpe arc file ...")
# NOTE(review): 13357 is hard-coded; presumably the length of the reference
# the arc file refers to - confirm.
numcol = 13357; numrow = 13357
# Symmetric 0/1 matrix, 1 = the two positions are base paired. int8 is enough
# for a flag; the float64 default would reserve about 1.4 GB for this shape.
rRNA_BP_dict = numpy.zeros((numrow, numcol), dtype=numpy.int8)

## read dsRNA information (bp_arc_file)
bpinfo = {} #[[1HR2, 6, 112], [1HR2, 7, 111]...]
with arcfile:
    next(arcfile, None) # remove header
    for line in arcfile:
        #track graphType=arc
        #1HR2    6       112     stemloop1
        fields = line.rstrip('\n').split('\t')
        if len(fields) < 3: continue  # blank or incomplete line
        bp_left, bp_right = int(fields[1]), int(fields[2])
        # positions in the file are 1-based, matrix indices 0-based
        rRNA_BP_dict[bp_left-1, bp_right-1] = 1
        rRNA_BP_dict[bp_right-1, bp_left-1] = 1
lines = []


def get_PDB_aver_distance(matrix_PDB_distance, start, end, winbin):
    """Mark the pairs of windows that contain at least two base pairs.

    Despite its name this definition does not average distances: it counts
    the entries equal to 1 between two windows.

    matrix_PDB_distance: 2-D array indexed [pos1 - 1, pos2 - 1] (positions
                         are 1-based); 1 marks a base pair.
    start, end:          first and last position of the region (inclusive).
    winbin:              window size in nucleotides.

    Returns a square array with ceil((end - start + 1) / winbin) rows holding
    1 where two different windows share two or more base pairs and 0
    elsewhere (including the diagonal).

    The last window is clipped at ``end``, so positions beyond the requested
    region are never read.
    """
    start, end = int(start), int(end)
    matrix = numpy.asarray(matrix_PDB_distance)
    nwin = ceil((end - start + 1) / winbin)
    matrix_PDB_aver_distance_win = numpy.zeros((nwin, nwin))
    # 0-based half-open index range of every window
    bounds = [(start - 1 + k * winbin, min(start - 1 + (k + 1) * winbin, end))
              for k in range(nwin)]
    for i, (lo1, hi1) in enumerate(bounds):
        for j, (lo2, hi2) in enumerate(bounds):
            if i == j:
                continue
            bpcount = numpy.count_nonzero(matrix[lo1:hi1, lo2:hi2] == 1)
            if bpcount >= 2:
                matrix_PDB_aver_distance_win[i, j] = 1
    return matrix_PDB_aver_distance_win


## get bp info based on winbin
# The name get_PDB_aver_distance refers to the definition directly above (the
# later def replaces the earlier one of the same name), i.e. the version that
# flags window pairs by their base-pair count. The region passed here is
# positions 7925-12996, the same values assigned to start/end further down.
rRNA_BP_winbin_dict = get_PDB_aver_distance(rRNA_BP_dict, 7925, 12996, winbin)
################################################################################


# process the data coverage
################################################################################
def getPos(bed1, bed2, winbin,start):
    chr1, start1, end1 = bed1[0], int(bed1[1])-int(start)+1, int(bed1[2])-int(start)+1
    chr2, start2, end2 = bed2[0], int(bed2[1])-int(start)+1, int(bed2[2])-int(start)+1
    pos1 = []; pos2 = []; pos1new = []; pos2new = []
    for i in range(ceil(start1/winbin), ceil(end1/winbin)+1, 1):
        pos1.append(i)
    for i in range(ceil(start2/winbin), ceil(end2/winbin)+1, 1):
        pos2.append(i)
    return list(itertools.product(pos1,pos2))

def GetMiddleMatchedPos(bed1, bed2, winbin,start):
    """Pair the bins of two intervals in opposite directions, centred.

    bed1, bed2: intervals whose items 1 and 2 are start and end.
    winbin:     bin size.
    start:      position that maps to relative coordinate 1.

    Bins are derived as ceil(relative coordinate / winbin). The longer bin
    list is reversed, (length difference) // 2 of its leading bins are
    skipped, and it is zipped with the shorter list; on equal length the
    bed2 list is the one reversed. Returns a list of (bin1, bin2) tuples.
    """
    shift = int(start) - 1

    def _bins(bed):
        first = ceil((int(bed[1]) - shift) / winbin)
        last = ceil((int(bed[2]) - shift) / winbin)
        return list(range(first, last + 1))

    bins1, bins2 = _bins(bed1), _bins(bed2)
    surplus = abs(len(bins2) - len(bins1)) // 2
    if len(bins1) <= len(bins2):
        return list(zip(bins1, bins2[::-1][surplus:]))
    return list(zip(bins1[::-1][surplus:], bins2))
    
    
start = 7925; end = 12996
numcol=ceil((int(end)-int(start)+1)/winbin)
numrow=ceil((int(end)-int(start)+1)/winbin)
# read count per pair of bins, indexed [bin of first arm - 1, bin of second arm - 1]
matrix_data_bp_win = numpy.zeros((numcol,numrow))

print(timenow()+" getting base pair information ...")
for DG in newDGspan_filter:
    left_arm, right_arm = newDGspan_filter[DG][0], newDGspan_filter[DG][1]
    has_bp_bin = False
    # First choice: credit the DG to every bin pair it touches that is
    # flagged as base paired.
    for bin1, bin2 in getPos(left_arm, right_arm, winbin, start):
        # Bins are 1-based. A bin outside 1..numcol / 1..numrow is skipped:
        # bin 0 (an arm starting one position before `start`) would become
        # index -1 and be written to the last row/column.
        if not (1 <= bin1 <= numcol and 1 <= bin2 <= numrow): continue
        if rRNA_BP_winbin_dict[bin1-1, bin2-1] == 1:
            has_bp_bin = True
            matrix_data_bp_win[bin1-1, bin2-1] += 1
    # Otherwise: credit it to the centred, antiparallel pairing of its bins.
    if not has_bp_bin:
        for bin1, bin2 in GetMiddleMatchedPos(left_arm, right_arm, winbin, start):
            if not (1 <= bin1 <= numcol and 1 <= bin2 <= numrow): continue
            matrix_data_bp_win[bin1-1, bin2-1] += 1


# plot ROC curve
################################################################################            
"""
score=[]; classify = [] # True or False;       
for i in range(0, numcol, 1):
    for j in range(i, numrow, 1):
        if matrix_data_bp_win[i,j] >=reads_cutoff:
            if rRNA_BP_winbin_dict[i, j] == 0:   classify.append(0);   score.append(matrix_data_bp_win[i,j])
            elif rRNA_BP_winbin_dict[i, j] == 1: classify.append(1);   score.append(matrix_data_bp_win[i,j])

# plot ROC curve
plot_roc(classify, score, outputprefix+'_ROC')
################################################################################
"""

# plot heatmap
################################################################################
def _grey_cdict(stops):
    # Build a segment dict in which red, green and blue share the same
    # (position, level) stops, i.e. a ramp of greys.
    return {channel: [(pos, level, level) for pos, level in stops]
            for channel in ('red', 'green', 'blue')}


# white (at 0.0) to black (at 1.0)
cdict1 = _grey_cdict([(0.0, 1.0), (0.4, 0.7), (0.7, 0.3), (1.0, 0.0)])
cmp1 = mpl.colors.LinearSegmentedColormap('name', cdict1)

# black (at 0.0) to white (at 1.0)
cdict2 = _grey_cdict([(0.0, 0.0), (0.8, 0.3), (0.9, 0.5), (1.0, 1.0)])
cmp2 = mpl.colors.LinearSegmentedColormap('name', cdict2)


def plot_heatmap(matrix_aver_distance_win, matrix_aver_DG_coverage, start, end, maxvalue, minvalue=50):
    """Draw the base-pair matrix and the coverage matrix on one axes and show it.

    matrix_aver_distance_win: window matrix drawn with colour limits 0..1; its
                              upper triangle (diagonal included) is masked.
    matrix_aver_DG_coverage:  window matrix drawn with colour limits
                              minvalue..maxvalue.
    start, end:               region covered by the matrices (inclusive).
    maxvalue, minvalue:       colour limits of the coverage layer.

    Reads the module level ``winbin`` and ``cmp1``. Grid lines are drawn every
    int(800 / winbin) bins and a diagonal is added; ticks are hidden.
    """
    axes = plt.subplot(1,1,1)
    span = int(end) - int(start) + 1
    # Number of bins per side. The grid lines and the diagonal now share this
    # value; the lines used to end at int(ceil(span) / winbin), which rounds
    # down instead of up.
    nbins = ceil(span / winbin)
    # at least 1, so range() below never receives a zero step (winbin > 800)
    grid_step = max(1, int(800 / winbin))

    mask = numpy.zeros_like(matrix_aver_distance_win, dtype=bool)
    mask[numpy.triu_indices_from(mask)] = True
    sns.heatmap(matrix_aver_distance_win,cmap=cmp1, mask=mask, vmin=0, vmax=1, zorder=10)
    # NOTE(review): this layer is not masked and is drawn after the first one
    # with the same zorder, so it looks like it paints over it - confirm that
    # this is intended.
    sns.heatmap(matrix_aver_DG_coverage,cmap=cmp1, vmin=minvalue, vmax=maxvalue, zorder=10)

    grid_style = dict(color = "silver", linestyles = "solid", linewidth=1, zorder=15)
    for x in range(0, ceil(span / winbin + 800 / winbin), grid_step):
        plt.vlines(x, 0, nbins, **grid_style)
        plt.hlines(x, 0, nbins, **grid_style)
    plt.plot([0, nbins],[0, nbins], linewidth=1, color = 'silver', zorder=15)

    sns.despine(top=False, right=False, left=False, bottom=False)
    axes.xaxis.set_ticklabels([])
    axes.yaxis.set_ticklabels([])
    plt.xticks([])
    plt.yticks([])
    plt.show()
    #plt.savefig(outputname+"A.pdf")
    #plt.close()
    return


# Show the heatmap for positions 7925-12996: the windowed base-pair matrix
# plus the per-window read counts; 400 is passed as maxvalue, the upper colour
# limit of the read-count layer.
plot_heatmap(rRNA_BP_winbin_dict,matrix_data_bp_win, 7925, 12996, 400)