#!/usr/bin/python
# -*- coding: utf-8 -*-
## Removes the duplicated blocks and the blocks without orthologs in the outgroups

import re
import n11PairsFile,n12Deletion

def blocbloc2bloc(blocs,blocs2):
   """Flatten ``blocs2`` into one gene-ID range per sub-block.

   ``blocs[i][1]`` is read as the orientation flag of block ``i``: when it
   is truthy the sub-blocks of ``blocs2[i]`` are emitted in their stored
   order, otherwise from last to first.  Each sub-block ``[start, end]``
   becomes ``range(start, end+1)``.

   Returns the flat list of those ranges, block after block.
   """
   flat=[]
   for idx,bloc in enumerate(blocs):
      subBlocs=blocs2[idx]
      # a falsy orientation flag means the sub-blocks are walked backwards
      ordered=subBlocs if bloc[1] else subBlocs[::-1]
      flat.extend([range(sb[0],sb[1]+1) for sb in ordered])
   return flat

### clean the different dictionary from deleted genes
def _geneToSubBloc(ranges):
   """Map each gene ID to the index of the first range of ``ranges`` holding it."""
   index={}
   for ind,genes in enumerate(ranges):
      for g in genes:
         index.setdefault(g,ind)
   return index


def _pruneOrthos(dicoOrthos,ownRanges,otherRanges):
   """Keep, for every gene, only the pairs whose partner lies in the matching sub-block.

   The sub-block of a gene is its position in ``ownRanges``; the partner
   (second item of each pair) must belong to ``otherRanges`` at that same
   position.  Raises IndexError when a key of ``dicoOrthos`` is in no range.
   """
   ownIndex=_geneToSubBloc(ownRanges)
   for g in list(dicoOrthos.keys()):
      if g not in ownIndex:
         raise IndexError('gene %s is not covered by any sub-block'%(g,))
      partners=otherRanges[ownIndex[g]]
      # Rebuild the list in place: calling remove() while iterating over the
      # same list skips the element that follows each removed one.
      dicoOrthos[g][:]=[pair for pair in dicoOrthos[g] if pair[1] in partners]


def update(blocs,blocs1,blocs2,old2news1,old2news2,dicoOrthos1,dicoOrthos2,new2old1,new2old2):
   """Purge the dictionaries of every gene that is no longer inside a block.

   All dictionaries are modified in place; nothing is returned.

   blocs        -- block descriptors; only passed on to blocbloc2bloc
   blocs1/2     -- per block, a list of sub-blocks [start, end] of gene IDs
   old2news1/2  -- {oldID: [entry, ...]} where entry[4] is a gene ID; entries
                   whose gene is gone are dropped, as are emptied keys
   dicoOrthos1/2-- {gene: [(sim, partnerGene), ...]}
   new2old1/2   -- {gene: oldID}; every key removed here is also removed
                   from the matching dicoOrthos (KeyError if absent there)

   Empty block lists are accepted: every gene is then considered deleted.
   """
   subBlocs1=[sb for bloc in blocs1 for sb in bloc]
   subBlocs2=[sb for bloc in blocs2 for sb in bloc]
   # sets give constant-time membership tests
   kept1=set(g for sb in subBlocs1 for g in range(sb[0],sb[1]+1))
   kept2=set(g for sb in subBlocs2 for g in range(sb[0],sb[1]+1))

   # update new2old (and the ortholog entry of each removed gene)
   for new2old,dicoOrthos,kept in ((new2old1,dicoOrthos1,kept1),
                                   (new2old2,dicoOrthos2,kept2)):
      # iterate over a snapshot of the keys since entries are deleted
      for g in list(new2old.keys()):
         if g not in kept:
            del new2old[g]
            del dicoOrthos[g]

   # update old2news
   for old2news,kept in ((old2news1,kept1),(old2news2,kept2)):
      for o in list(old2news.keys()):
         old2news[o][:]=[entry for entry in old2news[o] if entry[4] in kept]
         if not old2news[o]:
            del old2news[o]

   # update dicoOrthos: sub-block k of genome 1 is matched with range k of
   # the list produced by blocbloc2bloc for genome 2
   ranges1=[range(sb[0],sb[1]+1) for sb in subBlocs1]
   ranges2=blocbloc2bloc(blocs,blocs2)
   _pruneOrthos(dicoOrthos1,ranges1,ranges2)
   _pruneOrthos(dicoOrthos2,ranges2,ranges1)

## Fusion of the small blocks following each other
def fusion(blocs,blocs1,blocs2):
   """Merge consecutive blocks that follow each other on the second genome.

   Two neighbouring blocks ``i`` and ``i+1`` are merged when they share the
   same first two fields (chromosome and sign, per the original comment),
   they do not overlap on genome 2, and no gene belonging to any block of
   ``blocs2`` lies in the gap between them.  For a sign equal to 1 block
   ``i`` must come first on genome 2; for any other sign block ``i+1`` must.

   The merged descriptor keeps the first two fields, sums the fourth field
   and averages the third one weighted by the fourth.  The three lists are
   modified in place and also returned as ``(blocs, blocs1, blocs2)``.
   Empty lists are returned unchanged.
   """
   # Gene IDs of genome 2 that belong to a block.  They must come from
   # blocs2, because the gap that is tested below is made of genome-2 IDs.
   genes2=set(g for bloc in blocs2 for sb in bloc for g in range(sb[0],sb[1]+1))
   i=0
   while i<len(blocs)-1:
      bi=blocs[i]
      bi1=blocs[i+1]
      if bi[0]!=bi1[0] or bi[1]!=bi1[1]:
         # different chromosome or different sign
         i+=1
         continue
      if bi[1]==1:  # positive: block i precedes block i+1 on genome 2
         first,second=blocs2[i],blocs2[i+1]
      else:         # negative: block i+1 precedes block i on genome 2
         first,second=blocs2[i+1],blocs2[i]
      lastEnd=first[-1][1]
      nextStart=second[0][0]
      if lastEnd>nextStart:
         mergeable=False
      else:
         mergeable=not any(g in genes2 for g in range(lastEnd+1,nextStart))
      if not mergeable:
         i+=1
         continue
      # Merge and stay on index i so the merged block can absorb its next
      # neighbour as well.
      total=bi[3]+bi1[3]
      blocs[i:i+2]=[[bi[0],bi[1],(bi[2]*bi[3]+bi1[2]*bi1[3])/total,total]]
      blocs1[i:i+2]=[blocs1[i]+blocs1[i+1]]
      blocs2[i:i+2]=[first+second]
   return blocs,blocs1,blocs2

## For the old genes with several new genes we keep one!
def erase_double(old2news1,blocs1,blocs2,ortho1,blocs):
   """Delete one of the two new genes of every old gene that has two of them.

   For each key of ``old2news1`` holding at least two entries, the gene IDs
   ``entry[4]`` of the first two entries are compared:
     * if exactly one of them is a block extremity (start of the first
       sub-block or end of the last sub-block of a block of ``blocs1``),
       that one is deleted;
     * otherwise the one whose first ``ortho1`` pair has the smaller first
       item is deleted (the second gene when they are equal).
   Deletion is delegated to deleteD, which edits blocs, blocs1 and blocs2
   in place.  A message is printed when a key holds more than two entries.
   Returns None.
   """
   tooMany=[news for news in old2news1.values() if len(news)>2]
   if tooMany:
      print('ERROR IN ERASE DOUBLE %s'%(tooMany,))
   toChoose1=[(o,news[0][4],news[1][4])
              for o,news in old2news1.items() if len(news)>=2]
   # Extremities are computed once, before any deletion is performed.
   extremities1=set()
   for bloc in blocs1:
      extremities1.add(bloc[0][0])
      extremities1.add(bloc[-1][1])
   for (o,d1,d2) in toChoose1:
      d1AtEnd=d1 in extremities1
      d2AtEnd=d2 in extremities1
      if d1AtEnd and not d2AtEnd:
         victim=d1
      elif d2AtEnd and not d1AtEnd:
         victim=d2
      elif ortho1[d1][0][0]<ortho1[d2][0][0]:
         victim=d1
      else:
         victim=d2
      deleteD(victim,blocs1,blocs2,ortho1,blocs)

def deleteD(d1,blocs1,blocs2,ortho1,blocs):
   """Remove gene ``d1`` from the first block of ``blocs1`` that contains it.

   When the block descriptor says the block holds a single gene
   (``blocs[i][3]==1``) the whole block is dropped from blocs, blocs1 and
   blocs2.  Otherwise the gene count is decremented, the sub-block holding
   ``d1`` is shrunk, split or removed, and the partner gene
   ``ortho1[d1][0][1]`` is removed from ``blocs2[i]`` through deleteD2.
   A message is printed when ``d1`` has more than one pair in ``ortho1``.

   All lists are edited in place; returns None.
   Raises IndexError when ``d1`` belongs to no sub-block.
   """
   for i,subBlocs in enumerate(blocs1):
      for j,ssb in enumerate(subBlocs):
         # bounds comparison instead of building the whole range of IDs
         if not (ssb[0]<=d1<=ssb[1]):
            continue
         if blocs[i][3]==1:
            del blocs[i]
            del blocs1[i]
            del blocs2[i]
            return
         blocs[i]=[blocs[i][0],blocs[i][1],blocs[i][2],blocs[i][3]-1]
         if ssb[0]==ssb[1]:
            del subBlocs[j]
         elif d1==ssb[0]:
            subBlocs[j]=[ssb[0]+1,ssb[1]]
         elif d1==ssb[1]:
            subBlocs[j]=[ssb[0],ssb[1]-1]
         else:
            # d1 is strictly inside: split the sub-block around it
            subBlocs[j:j+1]=[[ssb[0],d1-1],[d1+1,ssb[1]]]
         if len(ortho1[d1])>1:
            print('ERROR len(ortho1[d1])>1 %s %s'%(d1,ortho1[d1]))
         deleteD2(ortho1[d1][0][1],i,blocs2)
         return
   raise IndexError('gene %s not found in any block of blocs1'%(d1,))

def deleteD2(d2,i,blocs2):
   """Remove gene ``d2`` from the first sub-block of ``blocs2[i]`` holding it.

   A one-gene sub-block disappears, a gene at either end shrinks the
   sub-block, and a gene strictly inside splits it in two.  ``blocs2[i]``
   is edited in place; nothing happens when no sub-block contains ``d2``.
   Returns None.
   """
   subBlocs=blocs2[i]
   for pos,ssb in enumerate(subBlocs):
      lo,hi=ssb[0],ssb[1]
      if d2 not in range(lo,hi+1):
         continue
      if lo==hi:
         pieces=[]
      elif d2==lo:
         pieces=[[lo+1,hi]]
      elif d2==hi:
         pieces=[[lo,hi-1]]
      else:
         pieces=[[lo,d2-1],[d2+1,hi]]
      # swap the matched sub-block for what is left of it
      subBlocs[pos:pos+1]=pieces
      break


###############
## We delete blocks without homo in outgroups
# old2news1={oldID:name,newchr,+/-,newID/chr,newID,same sign or not} from the ancestor
# sign1={oldID:'+'ou'-'} from the actual KLLA.def
# new2old1={newIDinsynthon:oldID}
def _readOrthoPairs(pathPart1,name,name3,old2news,sign,sign3,checkUnique):
   """Read the ortholog pairs file shared by genomes ``name`` and ``name3``.

   The file name puts the two genome names in increasing string order, and
   the space-separated columns 4 and 9 hold the gene IDs in that same order.
   Only genes of ``name`` that are keys of ``old2news`` are kept.

   Returns {geneID: [(geneID3, s), ...]} where ``s`` tells whether
   "the two signs are equal" has the same truth value as item 5 of the
   first ``old2news`` entry of the gene.
   When ``checkUnique`` is true a message is printed for every kept gene
   whose ``old2news`` list does not hold exactly one entry.
   """
   if name<name3:
      fileName=pathPart1+name+'.'+name3+'.orth.pairs'
      colOwn,colOther=4,9
   else:
      fileName=pathPart1+name3+'.'+name+'.orth.pairs'
      colOwn,colOther=9,4
   dico={}
   # the with-block guarantees the file is closed, even on a parsing error
   with open(fileName,'r') as filePairs:
      for ligne in filePairs:
         ligneList=re.split(" ",ligne)
         gname=int(ligneList[colOwn])
         if gname not in old2news:
            continue
         if checkUnique and len(old2news[gname])!=1:
            print('ERRRRRRROR')
         newg=old2news[gname][0]
         gname3=int(ligneList[colOther])
         s=((sign[gname]==sign3[gname3]) == newg[5])
         dico.setdefault(gname,[]).append((gname3,s))
   return dico


def syntheseHomo(group,pathPart1,shNameList,old2news1,old2news2,sign1,sign2,new2old1,new2old2):
   """Collect, for each outgroup, the orthologs of the genes of both genomes.

   ``shNameList[0]`` and ``shNameList[1]`` are the two compared genomes;
   every further name is an outgroup.  For each outgroup its sign table is
   read with n11PairsFile.oldChRead from
   '../../<group>/01Genomes/<outgroup>.def', then the two pairs files
   (genome 1 / outgroup and genome 2 / outgroup) are parsed.

   Returns ``(dicosList1, dicosList2)``: one dictionary per outgroup, in
   the order of ``shNameList``, each of the form
   {oldID: [(oldID3, s), ...]} as built by _readOrthoPairs.
   ``new2old1`` and ``new2old2`` are accepted but not used.
   """
   name1=shNameList[0]
   name2=shNameList[1]
   dicosList1=[]
   dicosList2=[]
   for name3 in shNameList[2:]:
      ## sign3={oldID:'+' or '-'}
      sign3=n11PairsFile.oldChRead('../../'+group+'/01Genomes/'+name3+'.def')
      # only genome 1 is checked for old genes having several new genes
      dicosList1.append(_readOrthoPairs(pathPart1,name1,name3,old2news1,sign1,sign3,True))
      dicosList2.append(_readOrthoPairs(pathPart1,name2,name3,old2news2,sign2,sign3,False))
   return dicosList1,dicosList2


## 3 ## deletion bloc without ortho
def delBlocdeBlocs(listRemove,blocs,blocs1,blocs2,ortho1,ortho2):
   """Delete the blocks whose indices are in ``listRemove``.

   For every gene pair of a deleted block the pair is removed from
   ``ortho1[g1]`` and ``ortho2[g2]``.  Gene ``g1`` at offset ``j`` of
   sub-block ``k`` of genome 1 is paired with offset ``j`` of sub-block
   ``k`` of genome 2 when ``blocs[i][1]`` is truthy, and otherwise with
   offset ``j`` counted backwards from the end of the mirrored sub-block
   (``len-1-k``).

   Indices are handled from the highest to the lowest so earlier pops do
   not shift the remaining ones; duplicated indices are processed once.
   ``listRemove`` itself is left untouched.  blocs, blocs1, blocs2, ortho1
   and ortho2 are edited in place.

   Returns ``(blocs, blocs1, blocs2)``.
   Raises IndexError when an expected pair is missing from ortho1/ortho2.
   """
   for i in sorted(set(listRemove),reverse=True):
      sameOrientation=blocs[i][1]
      lastSub2=len(blocs2[i])-1
      for k,sb1 in enumerate(blocs1[i]):
         for j in range(sb1[1]-sb1[0]+1):
            g1=sb1[0]+j
            if sameOrientation:
               g2=blocs2[i][k][0]+j
            else:
               g2=blocs2[i][lastSub2-k][1]-j
            # a list comprehension (not filter) so indexing works everywhere
            simg2=[pair for pair in ortho1[g1] if pair[1]==g2][0]
            ortho1[g1].remove(simg2)
            simg1=[pair for pair in ortho2[g2] if pair[1]==g1][0]
            ortho2[g2].remove(simg1)
      blocs.pop(i)
      blocs1.pop(i)
      blocs2.pop(i)
   return blocs,blocs1,blocs2
def delWHomo(dicosList1,dicosList2,blocs,blocs1,blocs2,new2old1,new2old2,ortho1,ortho2):
   """Drop every block that lacks an outgroup hit on either genome.

   A block is kept only if at least one of its genome-1 genes (translated
   through ``new2old1``) is a key of some dictionary of ``dicosList1`` AND
   at least one of its genome-2 genes (through ``new2old2``) is a key of
   some dictionary of ``dicosList2``.  The other blocks are handed to
   delBlocdeBlocs, whose result is returned.
   """
   def hasOutgroupHit(subBlocs,bounds,new2old,dicos):
      # Every gene is translated, even once a hit is known, so a gene
      # missing from new2old still raises KeyError.
      hit=0
      for sb in subBlocs:
         lo,hi=bounds(sb)
         for g in range(lo,hi+1):
            og=new2old[g]
            if any(og in dico for dico in dicos):
               hit=1
      return hit

   # genome-1 sub-blocks are read as stored; genome-2 ones are normalised
   asStored=lambda sb:(sb[0],sb[1])
   ordered=lambda sb:(min(sb[0],sb[1]),max(sb[0],sb[1]))

   removeList=[]
   for i in range(len(blocs)):
      inDico1=hasOutgroupHit(blocs1[i],asStored,new2old1,dicosList1)
      inDico2=hasOutgroupHit(blocs2[i],ordered,new2old2,dicosList2)
      if not (inDico1 and inDico2):
         removeList.append(i)
   return delBlocdeBlocs(removeList,blocs,blocs1,blocs2,ortho1,ortho2)

def rewriteDef(pathMicro,name1,name2,new2old1,new2old2):
   """Write '<name>.postdel' from '<name>.predel' without the deleted genes.

   For each of the two genomes the whitespace-separated '.predel' file is
   read; column 1 is parsed as an int (compared between neighbouring lines
   to decide whether they share a chromosome), column 5 as the gene ID,
   columns 6 and 7 as two scores.  A line is dropped when its gene ID is
   not a key of the matching ``new2old`` dictionary.  When a line is
   dropped, the smaller of its two scores replaces the second score of the
   previous line and the first score of the next line, provided they have
   the same column-1 value.

   Surviving lines are written tab-separated: the first six columns
   unchanged, then the two scores formatted with two decimals.
   Returns None.
   """
   for name,dico in zip([name1,name2],[new2old1,new2old2]):
      # read once; the with-block closes the file even if parsing fails
      with open(pathMicro+name+'.predel','r') as defFile:
         rows=[line.split() for line in defFile]
      chNewScore=[[int(ls[1]),int(ls[5]),float(ls[6]),float(ls[7])] for ls in rows]
      i=0
      while i<len(chNewScore):
         (ch,new,scPrev,scNext)=chNewScore[i]
         if new in dico:
            i+=1
            continue
         # the deleted gene links its two neighbours with its weakest score
         bridge=min(scPrev,scNext)
         if i!=0 and chNewScore[i-1][0]==ch:
            chNewScore[i-1][3]=bridge
         if i!=len(chNewScore)-1 and chNewScore[i+1][0]==ch:
            chNewScore[i+1][2]=bridge
         del chNewScore[i]
      new2score={}
      for (ch,new,scPrev,scNext) in chNewScore:
         new2score[new]=(scPrev,scNext)
      with open(pathMicro+name+'.postdel','w') as newdef:
         for ls in rows:
            new=int(ls[5])
            if new in new2score:
               newdef.write('\t'.join(ls[:6])+'\t%.2f\t%.2f\n'%new2score[new])
  
  