# -*- coding: utf-8 -*-
"""
Created on Tue Jun  6 11:26:06 2023

@author: pspea
"""

# -*- coding: utf-8 -*-
"""
Created on Wed Nov  2 12:05:39 2022

@author: pspea
"""
import numpy as np
import pandas as pd

#import plotly.graph_objects as go

import plotly.io as pio
# Make "browser" the default Plotly renderer, so the fig.show() calls below
# display each figure through that renderer.
pio.renderers.default = "browser"


import plotly.graph_objects as go

# Destination for the copy-number heatmap. Only used if the (currently
# commented-out) fig.write_image call below is re-enabled.
output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/insertions_CNV_map.pdf')

# Copy-number table: the first column is used as the index (gene IDs), the
# remaining columns are looked up by strain name further down.
# cn_dict maps gene -> {column name: value}.
copy_number_filename = ('C:/Gresham/tiny_projects/Project_Grace/relative_depth_DNA_corrected_v3.txt')
df = pd.read_table(copy_number_filename, index_col=0)
cn_dict = df.to_dict('index')

# Strains to plot, in heatmap row order. The list is fixed by hand; the
# previous code also derived a sorted list from the table's columns but
# overwrote it immediately, so that dead computation has been removed.
strain_list = ['DGY1728','DGY1734','DGY1736','DGY1740','DGY1744','DGY1747','DGY1751']
# We need to populate the gene list with only those genes that are detected in every strain -
# otherwise genes will be misaligned between strains on the global heatmap.

# Map each gene ID to the set of strains whose baySeq results table lists it.
gene_count_strain = {}

for istype in ['Obs']:
    for evo_strain in strain_list:
        # DGY1657 is the fixed first strain in every results file name, so
        # there is no DGY1657-vs-DGY1657 table to load.
        if evo_strain == 'DGY1657':
            continue

        deseq_results_filename =  ('C:/Gresham/tiny_projects/Project_Grace/baySeq/DGY1657_{}_{}_bs.tab').format(evo_strain, istype)
        # pandas opens and closes the file itself when given a path; the
        # separate open() call that used to sit here leaked one unused file
        # handle per strain.
        df = pd.read_table(deseq_results_filename, index_col=0)
        deseq_results = df.to_dict('index')

        for gene in deseq_results:
            # Keep only IDs that start with 'Y'. The isinstance check skips
            # empty or non-string index values (e.g. NaN from a blank cell)
            # instead of failing on gene[0].
            if isinstance(gene, str) and gene.startswith('Y'):
                gene_count_strain.setdefault(gene, set()).add(evo_strain)
                    
                    
# Number of strains a gene must be detected in to count as "complete": every
# strain whose results table was loaded above, i.e. every entry of
# strain_list except DGY1657, which is skipped there. This replaces a
# hard-coded 7 that had to be kept in sync with strain_list by hand.
required_strain_count = sum(1 for strain in strain_list if strain != 'DGY1657')

complete_gene_list = set()

# Largest number of strains any single gene was detected in (0 when no gene
# was seen at all). Kept as a diagnostic value.
strain_max = max((len(strains) for strains in gene_count_strain.values()), default=0)

for gene, detected_in in gene_count_strain.items():
    if len(detected_in) >= required_strain_count:
        complete_gene_list.add(gene)
    else:
        # Report genes that are dropped because at least one strain lacks them.
        print(gene)
                         

# for gene in complete_gene_list:
#     if gene[0] == 'Y':
#         if gene not in gene_count_strain:
#             gene_count_strain[gene] = True
            
#         for strain in strain_list:
#             if strain not in  cn_dict[gene]:
#                 gene_count_strain[gene] = False
#                 1/0
            
    

# Pick out the genes whose ID contains 'YKL' or 'YKR'.
# YKL IDs are ordered descending and YKR IDs ascending, then joined with the
# YKL genes first -- presumably so the heatmap's x-axis runs along the
# chromosome from the left arm to the right arm (TODO confirm).
ykl_list = sorted((name for name in complete_gene_list if 'YKL' in name), reverse=True)
ykr_list = sorted(name for name in complete_gene_list if 'YKR' in name)

yk_list = [*ykl_list, *ykr_list]

# Heatmap matrix: one row per strain (y-axis), one column per gene (x-axis),
# each cell being the copy-number value rounded with the built-in round().
exp = [
    [round(cn_dict[gene][strain]) for gene in yk_list]
    for strain in strain_list
]

cnv_heatmap = go.Heatmap(
    z=exp,
    x=yk_list,
    y=strain_list,
    hoverongaps=False,
    colorscale='Reds',
)
fig = go.Figure(data=cnv_heatmap)
fig.show()
#fig.write_image(output_figure_name)

# colorscale= 'RdBu_r'

# Destination for the log2 fold-change heatmap (not written anywhere in the
# code that follows; kept for when figure export is enabled).
output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/Insertions_Log2FC_map.pdf')

# gene -> {strain: log2 fold change}, filled in below.
log2fc_dict = {}

# Per-gene insertion table: the first column is used as the index (gene IDs)
# and the per-strain columns, including 'DGY1657', are read further down.
# pandas opens and closes the file itself when given a path; the separate
# open() call that used to precede it leaked an unused file handle.
deseq_results_filename =  ('C:/Gresham/tiny_projects/Project_Grace/insertions/global_normalized_insertionPerGene.txt')
df = pd.read_table(deseq_results_filename, index_col=0)
deseq_results = df.to_dict('index')

# Abundance of each plotted gene in DGY1657 (the "anc" strain that every
# other strain is compared against); 0 when the gene is absent from the
# insertion table.
anc_adundance = {
    gene: deseq_results[gene]['DGY1657'] if gene in deseq_results else 0
    for gene in yk_list
}

# log2(strain abundance / DGY1657 abundance) for every gene and strain.
for evo_strain in strain_list:
    print(evo_strain)
    for gene in yk_list:
        per_strain = log2fc_dict.setdefault(gene, {})

        if gene not in deseq_results:
            # Gene missing from the insertion table: recorded as "no change".
            per_strain[evo_strain] = 0
        elif anc_adundance[gene] == 0:
            # The fold change is undefined when the DGY1657 value is zero.
            # Store NaN rather than dividing by zero, which would either
            # raise ZeroDivisionError or produce an infinite value.
            per_strain[evo_strain] = np.nan
        else:
            per_strain[evo_strain] = np.log2(deseq_results[gene][evo_strain] / anc_adundance[gene])


# Heatmap matrix: one row per strain (y-axis), one column per gene (x-axis).
exp = []
for strain in strain_list:
    exp_sub = []
    for gene in yk_list:
        # Every gene/strain pair must have a value, otherwise the row would
        # be shorter than yk_list and its cells would line up with the wrong
        # genes. Fail loudly with a descriptive error (this replaces the
        # former `1/0` crash and the silent skip of genes that were missing
        # from log2fc_dict altogether).
        if gene not in log2fc_dict or strain not in log2fc_dict[gene]:
            raise KeyError('no log2 fold change for gene {} in strain {}'.format(gene, strain))
        exp_sub.append(log2fc_dict[gene][strain])
    exp.append(exp_sub)

# Diverging colour scale with the colour range clamped to [-5, 5].
fig = go.Figure(data=go.Heatmap(
                   z=exp,
                   x=yk_list,
                   y=strain_list,
                   hoverongaps = False,
                   zmin=-5,
                   zmax=5,
                   colorscale= 'RdBu_r'))
fig.show()
