import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats as stats
import matplotlib.image as mpimg

import matplotlib.patches as mplpatches

plt.style.use('BME163')



def _gtf_transcript_id(attributes):
    """Return the ``transcript_id`` value from a GTF attribute column, or None if absent."""
    parts = attributes.split('transcript_id "', 1)
    if len(parts) < 2:
        return None
    return parts[1].split('"', 1)[0]


def parse_genome(input_file,panel_chromosome,left_bound,right_bound,heights=None):
    """Collect 'basic'-tagged GTF transcripts overlapping a window as plot records.

    Only lines containing the substring 'basic' on ``panel_chromosome`` are
    considered. 'transcript' rows define each transcript's span and strand;
    'exon' and 'CDS' rows become its drawn blocks (in file order). The file is
    read once, so feature rows may appear before or after their transcript row.

    Args:
        input_file: path to a tab-separated GTF file.
        panel_chromosome: chromosome name compared against the first column.
        left_bound, right_bound: genomic window; a transcript is kept when
            ``start < right_bound and end > left_bound``.
        heights: mapping from feature type ('exon'/'CDS') to drawn block
            height. Defaults to the module-level ``height_dict``.

    Returns:
        List of 12-element records in the layout shared with ``parse_reads``:
        [start, end, blocksizes, blockstarts, block_heights, [], strand,
        transcript_id, 'gene_+', 0, '', ''], sorted by end coordinate.
        Each overlapping transcript is also printed (debug output).
    """
    if heights is None:
        heights = height_dict

    spans = {}     # transcript_id -> (start, end, strand) from its 'transcript' row
    features = {}  # transcript_id -> [(type, start, end, strand), ...] in file order
    with open(input_file) as gtf:
        for line in gtf:
            if 'basic' not in line:
                continue
            a = line.strip().split('\t')
            # The attribute column (index 8) is needed for every row we use.
            if len(a) < 9 or a[0] != panel_chromosome:
                continue
            feature_type = a[2]
            if feature_type != 'transcript' and feature_type not in ('exon', 'CDS'):
                continue
            transcript_id = _gtf_transcript_id(a[8])
            if transcript_id is None:
                continue
            if feature_type == 'transcript':
                spans[transcript_id] = (int(a[3]), int(a[4]), a[6])
            else:
                features.setdefault(transcript_id, []).append((feature_type, a[3], a[4], a[6]))

    read_list = []
    for transcript_id, (start, end, strand) in spans.items():
        # Interval-overlap test; also keeps features touching/covering the bounds exactly.
        if not (start < right_bound and end > left_bound):
            continue
        entries = features.get(transcript_id, [])
        print(transcript_id, start, end, entries)
        blockstarts = [int(entry[1]) for entry in entries]
        blocksizes = [int(entry[2]) - int(entry[1]) for entry in entries]
        read_heights = [heights[entry[0]] for entry in entries]
        # Strand comes from the transcript row, so transcripts without blocks still get their own.
        read_list.append([start, end, blocksizes, blockstarts, read_heights, [], strand,
                          transcript_id, 'gene_+', 0, '', ''])

    return sorted(read_list, key=lambda record: record[1])


def conv_dict(orig,dir):
    """Re-express strand ``dir`` relative to reference strand ``orig``.

    If ``orig`` is '+', ``dir`` is returned unchanged. Otherwise the strand is
    flipped: '+' becomes '-', and any other value becomes '+'.
    """
    if orig != '+':
        return '-' if dir == '+' else '+'
    return dir

def parse_reads(input_file,read_height,panel_chromosome,left_bound,right_bound,direction_file):
    """Load PSLX alignments overlapping a window as plot records.

    Columns read (per the PSL/PSLX layout): 8 strand, 9 query name,
    13 target name, 15/16 target start/end, 18 block sizes, 20 target block
    starts, 21/22 query/target block sequences. Lines with fewer than 23
    tab-separated fields (blank lines, the psLayout header) are skipped.

    Args:
        input_file: path to the .pslx file.
        read_height: drawn height assigned to every block of every read.
        panel_chromosome: chromosome to keep; matches either the target name
            itself or 'chr' + target name.
        left_bound, right_bound: genomic window; a read is kept when
            ``start < right_bound and end > left_bound``.
        direction_file: optional (falsy to disable) tab-separated file of
            ``name<TAB>strand``. When the part of a read name before the first
            '_' appears there, the read's strand is re-expressed with
            ``conv_dict``.

    Returns:
        List of 12-element records [start, end, blocksizes, blockstarts,
        block_heights, [], strand, name, 'gene_+', 0, read_seqs, genome_seqs],
        sorted by end coordinate.
    """
    direction_dict = {}
    if direction_file:
        with open(direction_file) as handle:
            for line in handle:
                a = line.strip().split('\t')
                if len(a) < 2:
                    continue  # blank / malformed line
                direction_dict[a[0]] = a[1]

    read_list = []
    with open(input_file) as handle:
        for line in handle:
            a = line.strip().split('\t')
            if len(a) < 23:
                continue  # header, blank, or truncated line
            if panel_chromosome not in (a[13], 'chr' + a[13]):
                continue
            start = int(a[15])
            end = int(a[16])
            # Interval-overlap test; also keeps reads touching/covering the bounds exactly.
            if not (start < right_bound and end > left_bound):
                continue

            direction = a[8]
            name = a[9]
            name_key = name.split('_')[0]
            if name_key in direction_dict:
                direction = conv_dict(direction_dict[name_key], direction)

            # Comma-separated lists end with a trailing comma, hence [:-1].
            blocksizes = np.array(a[18].split(',')[:-1], dtype=int)
            blockstarts = np.array(a[20].split(',')[:-1], dtype=int)
            rSeq = np.array(a[21].split(',')[:-1], dtype=str)
            gSeq = np.array(a[22].split(',')[:-1], dtype=str)
            read_list.append([start, end, blocksizes, blockstarts, [read_height] * len(blocksizes),
                              [], direction, name, 'gene_+', 0, rSeq, gSeq])

    return sorted(read_list, key=lambda record: record[1])

# Fill colours keyed by a record's strand/type field (record index 6).
_STRAND_COLORS = {
    '1': (84/255, 99/255, 225/255),
    '2': (255/255, 60/255, 60/255),
    'gene_+': (0/255, 175/255, 120/255),
    'gene_-': (255/255, 175/255, 0/255),
    '+': 'white',
    '-': (200/255, 200/255, 200/255),
}

# Fill colours for mismatching read bases; unknown characters use the 'N' colour.
_BASE_COLORS = {'A': 'orange', 'T': 'green', 'C': 'blue', 'G': 'purple', 'N': 'grey'}


def _draw_read(panel, read, y_pos):
    """Add one record's backbone line, aligned blocks and mismatch ticks to ``panel`` at row ``y_pos``."""
    start, end = read[0], read[1]
    blocksizes, blockstarts, read_heights = read[2], read[3], read[4]
    fill = _STRAND_COLORS[read[6]]
    rSeq, gSeq = read[10], read[11]

    # Thin backbone spanning the whole alignment.
    panel.add_patch(mplpatches.Rectangle((start, y_pos - 0.05), end - start, 0.1,
                                         facecolor=fill, edgecolor='black', linewidth=0.02))
    for pos in range(len(blocksizes)):
        if blocksizes[pos] <= 1:
            continue
        height = read_heights[pos]
        bottom = y_pos - height / 2
        panel.add_patch(mplpatches.Rectangle((blockstarts[pos], bottom), blocksizes[pos], height,
                                             facecolor=fill, edgecolor='black', linewidth=0.1))
        if len(rSeq) == 0:
            continue  # annotation records carry no sequences
        # zip() stops at the shorter sequence instead of indexing past gSeq's end.
        for offset, (r_base, g_base) in enumerate(zip(rSeq[pos], gSeq[pos])):
            if r_base != g_base:
                shade = _BASE_COLORS.get(r_base.upper(), _BASE_COLORS['N'])
                panel.add_patch(mplpatches.Rectangle((blockstarts[pos] + offset, bottom), 1, height,
                                                     facecolor=shade, edgecolor=shade, linewidth=0.2))


def plot_reads(read_list,panel,order,max_reads=1000,max_rows=50):
    """Randomly subsample records and pack them into stacked rows on ``panel``.

    Each record is kept with probability ``max_reads / len(read_list)`` (one
    ``np.random.random()`` draw per record). Kept records are sorted by end
    coordinate when ``order == 'right'``, otherwise by start, then placed
    greedily: each row takes, in sorted order, every unplaced record whose
    start lies beyond the end of the previous record placed in that row.
    Packing stops once all records are placed or ``max_rows + 1`` rows exist.

    Note: sets ``read[9] = 1`` on every placed record, mutating the caller's
    records.

    Args:
        read_list: records as produced by ``parse_reads`` / ``parse_genome``.
        panel: matplotlib Axes that receives the patches.
        order: 'right' to sort by end coordinate, anything else by start.
        max_reads: expected number of records kept by the subsampling.
        max_rows: row limit (the loop breaks once the row counter exceeds it).

    Returns:
        (y_min, y_max) = (0, rows_used + 1), suitable for ``panel.set_ylim``;
        (0, 1) when nothing is drawn.
    """
    if len(read_list) == 0:
        return 0, 1  # nothing to draw; also avoids dividing by zero below

    keep_fraction = max_reads / len(read_list)
    sampled = [read for read in read_list if np.random.random() < keep_fraction]
    sort_index = 1 if order == 'right' else 0
    sampled.sort(key=lambda read: read[sort_index])

    rows_used = 0
    placed = 0
    while placed < len(sampled) and rows_used <= max_rows:
        rows_used += 1
        previous_end = 0
        for read in sampled:
            if read[9] != 0 or read[0] <= previous_end:
                continue
            read[9] = 1
            placed += 1
            previous_end = read[1]
            _draw_read(panel, read, rows_used)

    return 0, rows_used + 1





# Drawn block height per GTF feature type (used by parse_genome).
height_dict = {
    'exon': 0.25,
    'CDS': 0.5,
}

# Genomic windows to plot, as (chromosome, left_bound, right_bound).
# Alternative windows, currently disabled:
#   ('chr17', 7667160, 7667360)
#   ('chr22', 23311630, 23311930)
#   ('chr7', 55181220, 55181530)
#   ('chr7', 55198574, 55198874)
#   ('chr7', 55174642, 55174916)
#   ('chr7', 55181178, 55181577)
#   ('chr7', 55191702, 55191941)
targets = [
    ('chr5', 112839516, 112839816),
    ('chr8', 127738617, 127738917),
    ('chr9', 136502290, 136502590),
    ('chr11', 102110778, 102111078),
    ('chr10', 87863806, 87864106),
    ('chr12', 68839159, 68839459),
    ('chr17', 7667110, 7667410),
    ('chr12', 120994164, 120994464),
]

# Directory holding every input file used below.
DATA_DIR = '/Users/vollmers/data/BWN/Enrichment/'

for panel_chromosome,left_bound,right_bound in targets:
    print(panel_chromosome)
    read_list3 = parse_reads(DATA_DIR + 'H1650.trimmed.pslx', 0.5, panel_chromosome,
                             left_bound, right_bound, DATA_DIR + 'original_direction.tsv')
    read_list4 = parse_reads(DATA_DIR + 'H1650.trimmed.ReadNumbers.pslx', 0.5, panel_chromosome,
                             left_bound, right_bound, '')
    # NOTE: read_list5 is not drawn (its panel is disabled); parse_genome still
    # prints the annotated transcripts overlapping this window.
    read_list5 = parse_genome(DATA_DIR + 'lrgasp_gencode_v38_sirvs.gtf', panel_chromosome,
                              left_bound, right_bound)

    fig_1 = plt.figure(figsize=(1,4))
    panel1 = plt.axes([0.05, 0.00, 0.9, 0.04], frameon=False)  # title strip
    panel3 = plt.axes([0.05, 0.05, 0.9, 0.45], frameon=True)
    panel4 = plt.axes([0.05, 0.52, 0.9, 0.45], frameon=True)

    # plot_reads draws random numbers, so keep the panel3-then-panel4 call order.
    read_panels = [
        (panel3, plot_reads(read_list3, panel3, 'right')),
        (panel4, plot_reads(read_list4, panel4, 'right')),
    ]
    for _, (y_bottom, y_top) in read_panels:
        print(y_bottom, y_top)

    for panel, (y_bottom, y_top) in read_panels:
        panel.set_xlim(left_bound, right_bound)
        panel.set_ylim(y_bottom, y_top)
        panel.tick_params(bottom=False, right=False, left=False, top=False,
                          labelbottom=False, labelleft=False, labeltop=False, labelright=False)
        panel.set_xticks([])
        panel.set_xticklabels([])

    panel1.set_xlim(0,1)
    panel1.set_ylim(0,1)
    panel1.set_yticks([])
    panel1.set_xticks([])
    panel1.text(0.5,1,panel_chromosome+':'+str(left_bound)+'-'+str(right_bound),fontsize=4,va='top',ha='center')
    plt.savefig(panel_chromosome+'_'+str(left_bound)+'OG_direction.png',dpi=1200)
    # Release the figure; pyplot otherwise keeps every figure (and its patches) alive.
    plt.close(fig_1)
