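# Example usage: the differential peak regions for all cell lines were generated with a command such as the one below
# (the trailing "1 0" are the optional regression slope and intercept overrides):
# python score_comp_region.py 2000 2000 AHC4760 AHT254 ATAC_AHC4760_hg38_300_noproms_nc30.bed  ATAC_AHT254_hg38_300_noproms_nc30.bed ATAC_AHC4760_hg38_filtered_dedup.bigwig ATAC_AHT254_hg38_filtered_dedup.bigwig 1 0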

import os, sys, os.path

def make_union_bed_label(label1,label2,region1,region2):
    """Concatenate two BED files into one, naming each region after its source.

    Every input line contributes one output line made of its first three
    whitespace-separated fields (written tab-separated) plus a fourth column
    ``<label>_<k>``, where ``k`` is the 1-based line number of that line
    within its own input file.  All lines of ``region1`` come first (tagged
    with ``label1``), followed by all lines of ``region2`` (tagged with
    ``label2``).

    The result is written to ``union_<label1>_<label2>_hg38.bed`` in the
    current working directory, replacing any existing file of that name.

    Args:
        label1: name prefix for the regions read from ``region1``.
        label2: name prefix for the regions read from ``region2``.
        region1: path of the first BED file.
        region2: path of the second BED file.

    Raises:
        IndexError: if an input line has fewer than three fields (blank
            lines included).
        OSError: if a file cannot be opened.
    """
    genome = 'hg38'
    ofn = "union_" + label1 + "_" + label2 + "_" + genome + ".bed"

    # Context managers guarantee that every handle is closed, including the
    # second input file and the output file when a malformed line raises.
    with open(ofn, 'w') as ofile:
        for label, region in ((label1, region1), (label2, region2)):
            with open(region, 'r') as ifile:
                for k, line in enumerate(ifile, start=1):
                    f = line.split()
                    ofile.write("{0}\t{1}\t{2}\t{3}\n".format(f[0], f[1], f[2], label + '_' + str(k)))

import sys, os, math

def mean(dict1):    #dict
    """Return the arithmetic mean of the values held in ``dict1``.

    The values are added one by one, in the dict's iteration order, to an
    initial 0 and the total is divided by the number of entries.  An empty
    dict raises ZeroDivisionError.
    """
    total = 0
    for value in dict1.values():
        total = total + value
    return total / len(dict1)

def meanv(x):         #list
    """Return the arithmetic mean of the items of sequence ``x``.

    Items are accumulated left to right starting from 0, then divided by
    ``len(x)``.  An empty sequence raises ZeroDivisionError.
    """
    running = 0
    for item in x:
        running = running + item
    return running / len(x)

def std(x):
    """Return the sample standard deviation of ``x`` (divisor ``len(x) - 1``).

    The mean comes from meanv(); squared deviations are accumulated left to
    right.  A single-element input raises ZeroDivisionError because the
    divisor is zero.
    """
    centre = meanv(x)
    squared_dev = 0
    for value in x:
        squared_dev = squared_dev + (value - centre) ** 2
    variance = squared_dev / (len(x) - 1)
    return math.sqrt(variance)

def corr(dict1,dict2):     #keys must match
    """Return the Pearson correlation between the values of two dicts.

    Values are paired by key: the keys of ``dict1`` drive the pairing and
    each one is looked up in ``dict2`` (a missing key raises KeyError).
    Standard deviations come from std() and means from meanv(); the summed
    standardised cross-products are divided by ``n - 1``.
    """
    xs = [dict1[key] for key in dict1]
    ys = [dict2[key] for key in dict1]
    sd_x = std(xs)
    sd_y = std(ys)
    mean_x = meanv(xs)
    mean_y = meanv(ys)
    acc = 0
    for x_val, y_val in zip(xs, ys):
        acc = acc + (x_val - mean_x) * (y_val - mean_y) / sd_y / sd_x
    return acc / (len(xs) - 1)

def linreg(xdict,ydict):
    """Fit ``y = slope * x + intercept`` by ordinary least squares.

    Observations are paired by key: for every key of ``xdict`` the x value
    is ``xdict[key]`` and the y value is ``ydict[key]``.  Pairing by key
    (rather than by the two dicts' independent iteration orders) keeps the
    fit correct when the dicts were filled in different orders, and matches
    how corr() pairs its inputs.  Keys present only in ``ydict`` are ignored.

    Args:
        xdict: mapping key -> x value.
        ydict: mapping key -> y value; must contain every key of ``xdict``.

    Returns:
        Tuple ``(slope, intercept)``.

    Raises:
        KeyError: if a key of ``xdict`` is missing from ``ydict``.
        ZeroDivisionError: if ``xdict`` is empty or all x values are equal.
    """
    keys = list(xdict)
    x = [xdict[k] for k in keys]
    y = [ydict[k] for k in keys]
    mx = sum(x) / len(x)
    my = sum(y) / len(y)
    n1 = 0      # sum of (x - mx) * (y - my)
    d1 = 0      # sum of (x - mx) ** 2
    for xv, yv in zip(x, y):
        n1 = n1 + (xv - mx) * (yv - my)
        d1 = d1 + (xv - mx) * (xv - mx)
    slope = n1 / d1
    return slope, my - slope * mx

def _parse_score(token):
    """Convert a numeric token to int when it looks like one, else to float.

    Replaces eval() on file contents.  Integer-looking tokens stay ints so
    that scores echoed into the output BED files keep the same text form.
    """
    try:
        return int(token)
    except ValueError:
        return float(token)


def _read_scores(path):
    """Read a scores file into a dict: first field -> numeric fifth field."""
    scores = {}
    with open(path, 'r') as ifile:
        for line in ifile:
            f = line.split()
            scores[f[0]] = _parse_score(f[4])
    return scores


def make_diff_sig(n1,n2,label1,label2,sfn1,sfn2,bfn1,bfn2,a0,b0):
    """Select the regions whose signal deviates most from the s1-vs-s2 trend.

    Scores from ``sfn1`` (s1) and ``sfn2`` (s2) are read into dicts keyed by
    region name.  A line ``s2 = a*s1 + b`` is fitted with linreg(), or taken
    from ``a0``/``b0`` when they are not both zero.  Using each region's
    vertical distance from that line, two cutoffs are chosen so that

      * n1 regions of ``bfn1`` lie below the line by more than cutoff t1, and
      * n2 regions of ``bfn2`` lie above the line by more than cutoff t2

    (ties at a cutoff can change the counts; the actual counts are printed).
    Region ``i`` (1-based line number) of ``bfn1``/``bfn2`` is looked up in
    the score dicts under the name ``<label1>_<i>`` / ``<label2>_<i>``.

    Files written to the current directory:
      * ``<label1>_excl_vs_<label2>_hg38_300.bed`` - selected bfn1 regions
        with their s1 score in column 4,
      * ``<label2>_excl_vs_<label1>_hg38_300.bed`` - selected bfn2 regions
        with their s2 score in column 4,
      * ``<label1>_<label2>_hg38_corr.out`` - the value returned by corr(),
      * ``<label1>_<label2>_hg38_chip.R`` - an R script, then executed with
        ``Rscript``, that plots the scatter into ``..._chip.pdf``.

    Args:
        n1: number of ``bfn1`` regions to select (0 <= n1 < regions in bfn1).
        n2: number of ``bfn2`` regions to select (0 <= n2 < regions in bfn2).
        label1, label2: labels used in region names and output file names.
        sfn1, sfn2: score files; field 1 is the region name, field 5 the score.
        bfn1, bfn2: BED files listing the regions of each label.
        a0, b0: slope and intercept overriding the regression; pass 0 and 0
            to fit the line from the data.

    Raises:
        IndexError: if n1 or n2 is outside the range given above.
        KeyError: if a region name is missing from a score file.
    """
    genome = 'hg38'
    lfn = label1 + "_" + label2 + "_" + genome

    s1 = _read_scores(sfn1)
    m1 = mean(s1)
    print(label1, 'mean', m1)
    s2 = _read_scores(sfn2)
    m2 = mean(s2)
    print(label2, 'mean', m2)

    if a0 == 0 and b0 == 0:
        a, b = linreg(s1, s2)
    else:
        a = a0
        b = b0
        print("override regression", a, b)

    with open(bfn1, 'r') as ifile:
        lines1 = ifile.readlines()
    with open(bfn2, 'r') as ifile:
        lines2 = ifile.readlines()

#   the idea of this part of the algorithm is to find a cutoff t1 and t2
#   such that there are n1 elements with s2<a*s1+b+t1
#                   and n2 elements with s2>a*s1+b+t2
#   this is done by listing all vertical distances from the regression line,
#   sorting them, and reading off the value at the position matching n

    def sorted_residuals(label, count):
        # vertical distance of every <label>_<i> region from the line, ascending
        dists = []
        for i in range(count):
            lab = label + '_' + str(i + 1)
            dists.append(s2[lab] - a * s1[lab] - b)
        dists.sort()
        return dists

    t = sorted_residuals(label1, len(lines1))
    if not 0 <= n1 < len(t):
        raise IndexError("n1={0} must satisfy 0 <= n1 < {1} (regions in {2})".format(n1, len(t), bfn1))
    t1 = t[n1]                          #t1 will most likely be negative; n1 values lie strictly below it
    print(t1, n1, '{0} excl'.format(label1))

    t = sorted_residuals(label2, len(lines2))
    if not 0 <= n2 < len(t):
        raise IndexError("n2={0} must satisfy 0 <= n2 < {1} (regions in {2})".format(n2, len(t), bfn2))
    # The cutoff is the (n2+1)-th largest distance, so exactly n2 values lie
    # strictly above it.  Indexing with t[-n2] left only n2-1 values above the
    # cutoff, and for n2 == 0 picked the smallest value (selecting nearly all).
    t2 = t[len(t) - 1 - n2]             #t2 will most likely be positive
    print(t2, n2, '{0} excl'.format(label2))

    with open('{0}_excl_vs_{1}_{2}_300.bed'.format(label1,label2,genome),'w') as ofile1, \
         open('{1}_excl_vs_{0}_{2}_300.bed'.format(label1,label2,genome),'w') as ofile2:

        k = 0
        for i, line in enumerate(lines1):
            f = line.split()
            lab = label1 + '_' + str(i + 1)
            if s2[lab] < a * s1[lab] + b + t1:
                k = k + 1
                ofile1.write("{0}\t{1}\t{2}\t{3}\n".format(f[0], f[1], f[2], s1[lab]))
        print(t1, k, '{0} excl'.format(label1))

        k = 0
        for i, line in enumerate(lines2):
            f = line.split()
            lab = label2 + '_' + str(i + 1)
            if s2[lab] > a * s1[lab] + b + t2:
                k = k + 1
                ofile2.write("{0}\t{1}\t{2}\t{3}\n".format(f[0], f[1], f[2], s2[lab]))
        print(t2, k, '{0} excl'.format(label2))

    r = corr(s1, s2)
    with open(lfn+'_corr.out','w') as ofile:
        print('corr:', r)
        ofile.write(str(r)+'\n')

    # R script reproducing the selection on the full score tables and plotting it.
    with open(lfn+'_chip.R','w') as ofile:
        ofile.write('a<-{0}\n'.format(a))
        ofile.write('b<-{0}\n'.format(b))
        ofile.write('x<-read.table(\"{0}\")\n'.format(sfn1))
        ofile.write('y<-read.table(\"{0}\")\n'.format(sfn2))
        ofile.write('ind<-which((y[[5]]< {0}*x[[5]]+{1}+{2}))\n'.format(a,b,t1))
        ofile.write('ind1<-which((y[[5]]> {0}*x[[5]]+{1}+{2}))\n'.format(a,b,t2))
        ofile.write('pdf(\"{0}\",5,5)\n'.format(lfn+'_chip.pdf'))
        ofile.write('par(mar=c(3,3,1,1)+0.1, mgp=c(1.6,0.6,0.0))\n')
        ofile.write('plot(x[[5]],y[[5]],pch=19,cex=.2,xlab=\"{0}\",ylab=\"{1}\",main=\"differential signal\")\n'.format(label1,label2))
#        ofile.write('plot(x[[5]],y[[5]],pch=19,cex=.2,xlab=\"{0}\",ylab=\"{1}\",main=\"differential signal\",ylim=c(0,100),xlim=c(0,100))\n'.format(label1,label2))
        ofile.write('points(x[[5]][ind],y[[5]][ind],pch=19,cex=.2,col=\"blue\")\n')
        ofile.write('points(x[[5]][ind1],y[[5]][ind1],pch=19,cex=.2,col=\"red\")\n')
        # R's abline takes (intercept, slope); here b is the intercept and a the slope
        ofile.write('abline({0},{1})\n'.format(b,a))
        ofile.write('dev.off()\n')
        ofile.write('cat(paste(\"{0}_excl8:\",length(ind),\"\\n\"))\n'.format(label1))
        ofile.write('cat(paste(\"{0}_excl8:\",length(ind1),\"\\n\"))\n'.format(label2))
        ofile.write('cat(\'corr:\',cor(x[[5]],y[[5]]),\'\n\')')
    os.system('Rscript {0}'.format(lfn+'_chip.R'))
#    os.system('gs '+lfn+'_chip.pdf')

def _parse_cli_number(token):
    """Convert a command-line token to int when possible, else to float."""
    try:
        return int(token)
    except ValueError:
        return float(token)


def main(argv=sys.argv):
    """Command-line entry point: build the union BED, score it, pick regions.

    Expected arguments (after the program name):
        n1 n2 label1 label2 bedfile1 bedfile2 wigfile1 wigfile2 [a b]

    Steps:
      1. write the labelled union of both BED files (make_union_bed_label),
      2. score the union against each signal file with bigWigAverageOverBed,
         reusing a scores file that already exists,
      3. call make_diff_sig() with the optional slope/intercept override
         ``a b`` (0 and 0 when not given, which means "fit the line").

    The two scores files are told apart by the first ``_``-separated token
    at which the two signal file names differ.

    Exits with a non-zero status on a usage error, when the signal file names
    provide no distinguishing token, or when bigWigAverageOverBed fails.
    """
    # 8 positional arguments are required; the override needs both a and b,
    # so exactly 9 arguments after the program name is also a usage error.
    if len(argv) < 9 or len(argv) == 10:
        print("usage:  {0} n1 n2 label1 label2 bedfile1 bedfile2 wigfile1 wigfile2 [a b]".format(argv[0]))
        sys.exit(1)

    n1 = int(argv[1])
    n2 = int(argv[2])

    label1 = argv[3]
    label2 = argv[4]

    region1 = argv[5]
    region2 = argv[6]

    signal1 = argv[7]
    signal2 = argv[8]

    if len(argv) > 10:
        a = _parse_cli_number(argv[9])
        b = _parse_cli_number(argv[10])
        print("override regression", a, b)
    else:
        # 0, 0 tells make_diff_sig to fit the regression itself
        a = 0
        b = 0

    # make union of regions

    genome = 'hg38'

    ufn = "union_" + label1 + "_" + label2 + "_" + genome + ".bed"

    # find the first "_"-separated token where the signal file names differ
    f = signal1.split("_")
    g = signal2.split("_")
    ind = 0
    while ind < len(f) and ind < len(g) and f[ind] == g[ind]:
        ind = ind + 1
    if ind >= len(f) or ind >= len(g):
        sys.exit("error: cannot tell {0} and {1} apart by a '_'-separated token".format(signal1, signal2))

    make_union_bed_label(label1, label2, region1, region2)

    score_files = []
    for signal, token in ((signal1, f[ind]), (signal2, g[ind])):
        sfn = "union_" + label1 + "_" + label2 + "_" + genome + "_" + token + "_scores.out"
        if os.path.isfile(sfn):
            print(sfn, ' exists, reusing.  remove to recalculate.')
        else:
            status = os.system('~/work/scripts/bigWigAverageOverBed {0} {1} {2}'.format(signal, ufn, sfn))
            if status != 0:
                # stop here: a missing or partial scores file would otherwise
                # break the analysis now or be silently reused by a later run
                sys.exit("error: bigWigAverageOverBed failed (status {0}); remove {1} if it exists before rerunning".format(status, sfn))
        score_files.append(sfn)
    sfn1, sfn2 = score_files

    make_diff_sig(n1, n2, label1, label2, sfn1, sfn2, region1, region2, a, b)


# run only when executed as a script, not when imported
if __name__ == "__main__":
    main()
