# -*- coding: utf-8 -*-
"""
Created on Wed Nov  2 12:05:39 2022

@author: pspea
"""
import pandas as pd
import numpy as np
from scipy import stats

import plotly.io as pio
# Make fig.show() open plotly figures in the web browser.
pio.renderers.default = "browser"


import plotly.graph_objects as go
import statsmodels.stats.multitest as smt


#output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/CNV_map.pdf')
# Relative DNA depth (copy number) table; first column is the gene ID.
# cn_dict[gene][strain] -> copy-number value for that gene in that strain.
copy_number_filename = ('C:/Gresham/tiny_projects/Project_Grace/relative_depth_DNA_corrected_v3.txt')
df = pd.read_table(copy_number_filename, index_col=0)
cn_dict = df.to_dict('index')

# Strains analysed, fixed explicitly. DGY1657 is the reference strain that every
# DESeq comparison below is made against (DESeq_<type>_DGY1657_<strain>.txt).
# A list previously derived from cn_dict['YKR039W'] was overwritten here
# immediately; that dead lookup has been removed (it could only raise KeyError).
strain_list = ['DGY1657','DGY1728','DGY1734','DGY1736','DGY1740','DGY1744','DGY1747','DGY1751']
# We need to populate the gene list with only those genes that are detected in
# every evolved strain; otherwise genes would be misaligned between strains on
# the global heatmap.
#
# gene_count_strain[gene] -> set of evolved strains whose DESeq 'Obs' table
# contains that gene (only gene IDs starting with 'Y' are considered).

gene_count_strain = {}

for istype in ['Obs']:
    for strain in strain_list:
        if strain != 'DGY1657':
            rnaseq_dict_filename = ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
            # pd.read_table opens and closes the file itself; no separate open() needed.
            df = pd.read_table(rnaseq_dict_filename, index_col=0)
            rnaseq_dict = df.to_dict('index')

            for gene in rnaseq_dict:
                # isinstance guard: an empty or NaN index entry would make gene[0] raise.
                if isinstance(gene, str) and gene.startswith('Y'):
                    gene_count_strain.setdefault(gene, set()).add(strain)

# "Every strain" means every strain except the reference DGY1657
# (7 for the current strain_list; this used to be hard-coded).
n_evolved_strains = sum(1 for s in strain_list if s != 'DGY1657')

complete_gene_list = set()
strain_max = 0  # largest number of strains any single gene was seen in

for gene, strains_seen in gene_count_strain.items():
    strain_max = max(strain_max, len(strains_seen))

    if len(strains_seen) >= n_evolved_strains:
        complete_gene_list.add(gene)
    else:
        # Report genes dropped because at least one strain lacks them.
        print(gene)
                         

# for gene in complete_gene_list:
#     if gene[0] == 'Y':
#         if gene not in gene_count_strain:
#             gene_count_strain[gene] = True
            
#         for strain in strain_list:
#             if strain not in  cn_dict[gene]:
#                 gene_count_strain[gene] = False
#                 1/0
            
    

# Ordered gene axis for the heatmaps: every complete gene containing 'YKL',
# sorted descending, followed by every complete gene containing 'YKR', sorted
# ascending (presumably so the axis runs along chromosome XI from one arm end
# to the other -- confirm).
ykl_list = sorted((g for g in complete_gene_list if 'YKL' in g), reverse=True)
ykr_list = sorted(g for g in complete_gene_list if 'YKR' in g)

yk_list = [*ykl_list, *ykr_list]

# exp = []

# for strain in strain_list:
#     exp_sub = []
#     for gene in yk_list:
#         exp_sub.append(round(cn_dict[gene][strain]))
#     exp.append(exp_sub)
        
        

# fig = go.Figure(data=go.Heatmap(
#                    z=exp,
#                    x=yk_list,
#                    y=strain_list,
#                    hoverongaps = False,
#                    colorscale= 'Reds'))
#fig.show()
#fig.write_image(output_figure_name)

# colorscale= 'RdBu_r'

# output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/Log2FC_map.pdf')
# log2fc_dict[gene][strain] -> log2FoldChange from DESeq_Obs_DGY1657_<strain>.txt
# for every gene in yk_list; genes absent from a strain's table are set to 0.
log2fc_dict = {}

for istype in ['Obs']:
    for strain in strain_list:
        if strain != 'DGY1657':
            rnaseq_dict_filename = ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
            # pd.read_table handles the file itself; the extra open() handle leaked.
            df = pd.read_table(rnaseq_dict_filename, index_col=0)
            rnaseq_dict = df.to_dict('index')

            for gene in yk_list:
                if gene not in log2fc_dict:
                    log2fc_dict[gene] = {}

                if gene in rnaseq_dict:
                    log2fc_dict[gene][strain] = rnaseq_dict[gene]['log2FoldChange']
                else:
                    log2fc_dict[gene][strain] = 0

# exp = []                                                
# for strain in strain_list:
#     exp_sub = []
#     for gene in yk_list:
#         if gene in log2fc_dict:
#             if strain in log2fc_dict[gene]:
#                 exp_sub.append(log2fc_dict[gene][strain])
#             else:
#                 exp_sub.append(0)
#                 print(gene, strain)
#                 1/0
#     exp.append(exp_sub)
    
# fig = go.Figure(data=go.Heatmap(
#                    z=exp,
#                    x=yk_list,
#                    y=strain_list,
#                    hoverongaps = False,
#                    zmin=-5,
#                    zmax=5,
#                    colorscale= 'RdBu_r'))
# fig.show()

# #fig.write_image(output_figure_name)

# output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/Log2FC_Boxplot.pdf')
# import plotly.graph_objects as go

# log2fc_dict = {}

# for istype in ['Obs']:
#     for strain in strain_list:
#         if strain != 'DGY1657':
#     #strain = 'DGY1728'
#             evo_is_high = True
#             rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
#             rnaseq_dict_file = open(rnaseq_dict_filename)
#             df = pd.read_table(rnaseq_dict_filename, index_col=0)
#             rnaseq_dict = df.to_dict('index')
            
#             for gene in rnaseq_dict:
#                 #if rnaseq_dict[gene]['padj'] <= 0.05:
#                     if gene not in log2fc_dict:
#                         log2fc_dict[gene] = {}
                        
#                     #if rnaseq_dict[gene]['padj'] <= 0.05:
#                     log2fc_dict[gene][strain] = rnaseq_dict[gene]['log2FoldChange']
#                     #else:
#                     #    log2fc_dict[gene][strain] = 0

# cnv_genes_dict = {}
# noncnv_genes_dict = {}

# for gene in log2fc_dict:
#     for strain in strain_list:
        
#         if strain in log2fc_dict[gene]:
#             if cn_dict[gene][strain] != 1:
#                 if strain not in cnv_genes_dict:
#                     cnv_genes_dict[strain] = []
                    
#                 cnv_genes_dict[strain].append(log2fc_dict[gene][strain])
                
#             if cn_dict[gene][strain] == 1:
#                 if strain not in noncnv_genes_dict:
#                     noncnv_genes_dict[strain] = []
                    
#                 noncnv_genes_dict[strain].append(log2fc_dict[gene][strain])
                
    
# exp = []                                                
# for strain in strain_list:
#     exp_sub = cnv_genes_dict[strain]

#     exp.append(exp_sub)

# fig = go.Figure()

# for i in range(len(exp)):
#     fig.add_trace(go.Box(y=exp[i], 
#                          name=strain_list[i],
#                          #marker_color = 'indianred',
#                          boxpoints='all',
#                          jitter=0.8,
#                          pointpos=-2,
#                          quartilemethod="linear"))

# fig.show()
# #fig.write_image(output_figure_name)

#output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/Log2FC_Exp_Boxplot_Core_Genes.pdf')
# import plotly.graph_objects as go
# 

# rnaseq_results_filename =  ('C:/Gresham/tiny_projects/Project_Grace/total_tpm_df.txt')
# rnaseq_results_file = open(rnaseq_results_filename)
# df = pd.read_table(rnaseq_results_file, index_col=0)
# rnaseq_results = df.to_dict('index')

# rnaseq_dict = {}

# for gene in rnaseq_results:
#     if gene[0] == 'Y':
#         if gene not in rnaseq_dict:
#             rnaseq_dict[gene] = {}
            
#         for strain in strain_list:
#             if strain not in rnaseq_dict[gene]:
#                 rnaseq_dict[gene][strain] = 0
            
#             temp = []
#             for replicate in range(1,4):
#                 rep_name = ('{}.{}').format(strain,replicate)
                
#                 if rep_name in rnaseq_results[gene]:
#                     temp.append(rnaseq_results[gene][rep_name])
 
#             rnaseq_dict[gene][strain] += np.median(temp)
            

# core_list = ["YKR029C","YKR030W","YKR031C","YKR034W","YKR035W-A","YKR036C","YKR037C","YKR038C","YKR039W","YKR041W","YKR042W","YKR043C","YKR044W","YKR045C","YKR046C","YKR048C","YKR049C"] 

# # core_list = ["YER133W","YER134C","YER136W","YER137C"]
# # #,"YER137C-A","YER138C"

# # core_list = ["YER133W","YER134C","YER136W","YER137C"]
# #,"YER137C-A","YER138C"

# exp_min = np.inf
# exp_max = -1*np.inf

# # vis_min = -1.6
# # vis_max = 1.2

# vis_min = -0.6
# vis_max = 2.7

# strain_median = {}
# gene_median = {}

# log2fc_dict = {}

# for core_gene in core_list:
    
#     core_set = set([core_gene])

#     output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/Review_GB/core_set_expression/Log2FC_Exp_Boxplot_Core_Gene_{gene}.pdf').format(
#         gene = core_gene)
    
    
#     for istype in ['Obs']:
#         for strain in strain_list:
#             if strain != 'DGY1657':
#         #strain = 'DGY1728'
#                 # evo_is_high = True
#                 # rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
#                 # rnaseq_dict_file = open(rnaseq_dict_filename)
#                 # df = pd.read_table(rnaseq_dict_filename, index_col=0)
#                 # rnaseq_dict = df.to_dict('index')
                
#                 for gene in core_set:
#                     if gene in rnaseq_dict:
#                         if gene not in log2fc_dict:
#                             log2fc_dict[gene] = {}
                            
#                         if gene in cn_dict:
#                             if strain in cn_dict[gene]:
#                                 cn = cn_dict[gene][strain]
#                                 evo_value = rnaseq_dict[gene][strain]
#                                 anc_value = rnaseq_dict[gene]['DGY1657']
                                
#                                 if anc_value == 0: 
#                                     anc_value = 1
                                
#                                 value = np.log2(evo_value/(anc_value*cn))
                                
                                
                                
#                                 log2fc_dict[gene][strain] = value
                                
#                                 # if strain not in strain_median:
#                                 #     strain_median[strain] = []
                                    
#                                 # strain_median[strain].append(value)
                                
#                                 # if gene not in gene_median:
#                                 #     gene_median[gene] = []
                                    
#                                 # gene_median[gene].append(value)
                                    
                                    
    
#     cnv_genes_dict = {}
#     noncnv_genes_dict = {}
    
#     for gene in log2fc_dict:
#         for strain in strain_list:
            
#             if strain in log2fc_dict[gene]:
#                 if cn_dict[gene][strain] != 1:
#                     if strain not in cnv_genes_dict:
#                         cnv_genes_dict[strain] = []
                        
#                     cnv_genes_dict[strain].append(log2fc_dict[gene][strain])
                    
#                 if cn_dict[gene][strain] == 1:
#                     if strain not in noncnv_genes_dict:
#                         noncnv_genes_dict[strain] = []
                        
#                     noncnv_genes_dict[strain].append(log2fc_dict[gene][strain])
                    
    
#     for strain in cnv_genes_dict:
#         log2FC = cnv_genes_dict[strain]
        
#         t, p = stats.ttest_1samp(log2FC, 0)
        
#         print(strain, t, p)
    
    
#     exp = []    
                                            
#     for strain in strain_list:
#         if strain in cnv_genes_dict:
#             exp_sub = cnv_genes_dict[strain]
            
#             for each_exp in exp_sub:
#                 exp_min = min(exp_min, each_exp)
#                 exp_max = max(exp_max, each_exp)
            
#         else:
#             exp_sub = [np.nan]
    
#         exp.append(exp_sub)
    
#     fig = go.Figure()
    
#     for i in range(len(exp)):
#         fig.add_trace(go.Box(y=exp[i], 
#                              name=strain_list[i],
#                              #marker_color = 'indianred',
#                              boxpoints='all',
#                              jitter=0.8,
#                              pointpos=-2,
#                              quartilemethod="linear"))
#         fig.update_yaxes(range=[vis_min,vis_max])
    
#     #fig.show()
#     #fig.write_image(output_figure_name)
    
    
# ###

# # core_list = ["YER133W","YER134C","YER136W","YER137C"]
# #,"YER137C-A","YER138C"

# exp_min = np.inf
# exp_max = -1*np.inf

# # vis_min = -1.6
# # vis_max = 1.2

# vis_min = -0.6
# vis_max = 2.7

# strain_median = {}
# gene_median = {}

# log2fc_dict = {}

# for core_gene in core_list:
    
#     core_set = set([core_gene])

#     output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/Log2FC_tpm_Boxplot_Core_Gene_{gene}.pdf').format(
#         gene = core_gene)
    
    
#     for istype in ['Obs']:
#         for strain in strain_list:
#             if strain != 'DGY1657':
#         #strain = 'DGY1728'
#                 # evo_is_high = True
#                 # rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
#                 # rnaseq_dict_file = open(rnaseq_dict_filename)
#                 # df = pd.read_table(rnaseq_dict_filename, index_col=0)
#                 # rnaseq_dict = df.to_dict('index')
                
#                 for gene in core_set:
#                     if gene in rnaseq_dict:
#                         if gene not in log2fc_dict:
#                             log2fc_dict[gene] = {}
                            
#                         if gene in cn_dict:
#                             if strain in cn_dict[gene]:
#                                 cn = cn_dict[gene][strain]
#                                 evo_value = rnaseq_dict[gene][strain]
#                                 anc_value = rnaseq_dict[gene]['DGY1657']
                                
#                                 if anc_value == 0: 
#                                     anc_value = 1
                                
#                                 value = np.log2(evo_value/(anc_value*cn))
                                
                                
                                
#                                 log2fc_dict[gene][strain] = value
                                
#                                 # if strain not in strain_median:
#                                 #     strain_median[strain] = []
                                    
#                                 # strain_median[strain].append(value)
                                
#                                 # if gene not in gene_median:
#                                 #     gene_median[gene] = []
                                    
#                                 # gene_median[gene].append(value)
                                    
                                    
    
#     cnv_genes_dict = {}
#     noncnv_genes_dict = {}
    
#     for gene in log2fc_dict:
#         for strain in strain_list:
            
#             if strain in log2fc_dict[gene]:
#                 if cn_dict[gene][strain] != 1:
#                     if strain not in cnv_genes_dict:
#                         cnv_genes_dict[strain] = []
                        
#                     cnv_genes_dict[strain].append(log2fc_dict[gene][strain])
                    
#                 if cn_dict[gene][strain] == 1:
#                     if strain not in noncnv_genes_dict:
#                         noncnv_genes_dict[strain] = []
                        
#                     noncnv_genes_dict[strain].append(log2fc_dict[gene][strain])
                    
    
#     for strain in cnv_genes_dict:
#         log2FC = cnv_genes_dict[strain]
        
#         t, p = stats.ttest_1samp(log2FC, 0)
        
#         print(strain, t, p)
    
    
#     exp = []    
                                            
#     for strain in strain_list:
#         if strain in cnv_genes_dict:
#             exp_sub = cnv_genes_dict[strain]
            
#             for each_exp in exp_sub:
#                 exp_min = min(exp_min, each_exp)
#                 exp_max = max(exp_max, each_exp)
            
#         else:
#             exp_sub = [np.nan]
    
#         exp.append(exp_sub)
    
#     fig = go.Figure()
    
#     for i in range(len(exp)):
#         fig.add_trace(go.Box(y=exp[i], 
#                              name=strain_list[i],
#                              #marker_color = 'indianred',
#                              boxpoints='all',
#                              jitter=0.8,
#                              pointpos=-2,
#                              quartilemethod="linear"))
#         fig.update_yaxes(range=[vis_min,vis_max])
    
#     #fig.show()
#     #fig.write_image(output_figure_name)
    
# for strain in strain_median:
#     outline = ('{strain}\n'
#                'log2_median:{log2_median}\tlog2_std:{log2_std}\n'
#                'median:{median}\tstd:{std}'
#                ).format(
#         strain = strain, 
#         log2_median = round(np.median(strain_median[strain]),2), 
#         log2_std = round(np.std(strain_median[strain]),2),
#         median = round(2**np.median(strain_median[strain]),2), 
#         std = round(2**np.std(strain_median[strain]),2)
#         )
#     print(outline)
    
# for strain in gene_median:
#     outline = ('{strain}\n'
#                'log2_median:{log2_median}\tlog2_std:{log2_std}\n'
#                'median:{median}\tstd:{std}'
#                ).format(
#         strain = strain, 
#         log2_median = round(np.median(gene_median[strain]),2), 
#         log2_std = round(np.std(gene_median[strain]),2),
#         median = round(2**np.median(gene_median[strain]),2), 
#         std = round(2**np.std(gene_median[strain]),2)
#         )
#     print(outline)
    
# for strain in gene_median:
#     outline = ('{strain}\t'
#                '{log2_median} _ \t{log2_std}\n'
#                ).format(
#         strain = strain, 
#         log2_median = round(np.median(gene_median[strain]),2), 
#         log2_std = round(np.std(gene_median[strain]),2),
#         )
#     print(outline)
    
# for core_gene in core_list:
#     core_set = set([core_gene])
    
#     for istype in ['Obs']:
#         for strain in strain_list:
#             if strain != 'DGY1657':
#         #strain = 'DGY1728'
#                 evo_is_high = True
#                 rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
#                 rnaseq_dict_file = open(rnaseq_dict_filename)
#                 df = pd.read_table(rnaseq_dict_filename, index_col=0)
#                 rnaseq_dict = df.to_dict('index')
                
#                 for gene in core_set:
#                     if gene in rnaseq_dict:
#                         if gene not in log2fc_dict:
#                             log2fc_dict[gene] = {}
                            
#                         if gene in cn_dict:
#                             if strain in cn_dict[gene]:
#                                 cn = cn_dict[gene][strain]
                                
#                                 if rnaseq_dict[gene]['pvalue'] <= 1:
#                                     print(strain, gene, rnaseq_dict[gene]['log2FoldChange'], rnaseq_dict[gene]['pvalue'])
                                    
# strain_log2FC[strain] -> sum of log2FoldChange over every gene in that
# strain's DESeq 'Obs' table (used for the growth plot).
strain_log2FC = {}

#for growth plot
for istype in ['Obs']:
    for strain in strain_list:
        if strain != 'DGY1657':
            deseq_dict_filename = ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
            df = pd.read_table(deseq_dict_filename, index_col=0)

            if not df.empty:
                # Series.sum skips NaN by default. DESeq2 reports NA fold changes
                # (e.g. for all-zero-count genes), and one NA would otherwise turn
                # the whole strain total into NaN.
                strain_log2FC[strain] = df['log2FoldChange'].sum()

# A bare expression is only echoed in an interactive console; print it so the
# totals are also shown when the file is run as a script.
print(strain_log2FC)

#is it significantly different distributions:
# Pool DESeq 'Exp' log2FoldChange values across evolved strains (excluding
# DGY1657 and DGY1728) into two groups: genes with cn > 1 (cnv_dist) and
# genes with cn == 1 (cnn_dist). Then compare the groups with a Mann-Whitney U test.

cnv_dist = []
cnn_dist = []

for istype in ['Exp']:
    for strain in strain_list:
        if strain != 'DGY1657' and strain != 'DGY1728':
            rnaseq_dict_filename = ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
            df = pd.read_table(rnaseq_dict_filename, index_col=0)
            rnaseq_dict = df.to_dict('index')

            for gene in rnaseq_dict:
                if gene in cn_dict and strain in cn_dict[gene]:
                    cn = cn_dict[gene][strain]
                    l2fc = rnaseq_dict[gene]['log2FoldChange']

                    # Skip NA fold changes: a single NaN makes both medians and
                    # the Mann-Whitney result NaN.
                    if pd.isna(l2fc):
                        continue

                    if cn > 1:
                        cnv_dist.append(l2fc)
                    if cn == 1:
                        cnn_dist.append(l2fc)

# The test result was previously computed and discarded; report it.
print(stats.mannwhitneyu(cnv_dist, cnn_dist))
print(np.median(cnv_dist), np.median(cnn_dist))
print(np.median(cnv_dist) / np.median(cnn_dist))


#is it significantly different distributions paired:
# For every ordered pair of distinct evolved strains (strain_cnv, strain_cnn):
# take the genes with cn > 1 in strain_cnv. Compare their DESeq 'Exp'
# log2FoldChange in strain_cnv against the same genes' log2FoldChange in
# strain_cnn using a paired Wilcoxon signed-rank test. Write one row per pair.

header = ('strain_cnv\tstrain_cnn\tnumb\tcnv_median\tcnn_median\tdiff\tpval\n')

# `with` guarantees the report is flushed and closed even if a test raises.
with open('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/wilcoxon_outfile.tsv', 'w') as wilcoxon_outfile:
    wilcoxon_outfile.write(header)

    for istype in ['Exp']:
        # Read each evolved strain's table once rather than once per pairing.
        # l2fc_by_strain[strain][gene] -> log2FoldChange.
        l2fc_by_strain = {}
        for strain_name in strain_list:
            if strain_name != 'DGY1657':
                rnaseq_dict_filename = ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain_name)
                df = pd.read_table(rnaseq_dict_filename, index_col=0)
                rnaseq_dict = df.to_dict('index')
                l2fc_by_strain[strain_name] = {gene: rnaseq_dict[gene]['log2FoldChange'] for gene in rnaseq_dict}

        for strain_cnv in strain_list:
            if strain_cnv == 'DGY1657':
                continue

            # Genes with cn > 1 in strain_cnv -> their log2FoldChange in strain_cnv.
            # Membership is checked for strain_cnv itself. An earlier version tested
            # the stale `strain` variable left over from a previous loop.
            cnv_dist = {}
            for gene, l2fc in l2fc_by_strain[strain_cnv].items():
                if gene in cn_dict and strain_cnv in cn_dict[gene]:
                    cn = cn_dict[gene][strain_cnv]
                    if cn > 1:
                        cnv_dist[gene] = l2fc

            for strain_cnn in strain_list:
                if strain_cnn == 'DGY1657' or strain_cnn == strain_cnv:
                    continue

                cnn_l2fc = l2fc_by_strain[strain_cnn]

                # NOTE(review): genes are not required to have cn == 1 in
                # strain_cnn, so some comparator values may come from genes that
                # are also amplified there. Confirm this is intended.
                pair_cnv = []
                pair_cnn = []
                for gene, cnv_value in cnv_dist.items():
                    if gene in cnn_l2fc and strain_cnn in cn_dict[gene]:
                        cnn_value = cnn_l2fc[gene]
                        # Drop pairs with an NA fold change; NaN would propagate
                        # into the medians and the test statistic.
                        if pd.notna(cnv_value) and pd.notna(cnn_value):
                            pair_cnv.append(cnv_value)
                            pair_cnn.append(cnn_value)

                if pair_cnv:
                    cnv_median = np.median(pair_cnv)
                    cnn_median = np.median(pair_cnn)
                    try:
                        _w, pval = stats.wilcoxon(pair_cnv, pair_cnn)
                    except ValueError:
                        # e.g. every paired difference is zero
                        pval = np.nan
                else:
                    # No shared genes: record the pair instead of aborting the run.
                    cnv_median = cnn_median = pval = np.nan

                outline = ('{strain_cnv}\t{strain_cnn}\t{numb}'
                           '\t{cnv_median}\t{cnn_median}\t{diff}\t{pval}\n').format(
                    strain_cnv = strain_cnv, strain_cnn = strain_cnn, numb = len(pair_cnv),
                    cnv_median = cnv_median,
                    cnn_median = cnn_median,
                    diff = cnv_median - cnn_median,
                    pval = pval)

                wilcoxon_outfile.write(outline)

#is it attenuated?
# attenuated_dict[gene][strain] -> {'log2FoldChange', 'pvalue', 'padj', 'cn'}.
# It holds every gene in a strain's DESeq 'Exp' table that also has a
# copy-number value for that strain (reference strain DGY1657 excluded).
# (A dead `gene_to_cnv_lookup = {}` and an unused `gene_set` list that used to
# sit here are removed; both names are reassigned further down before use.)

attenuated_dict = {}

for istype in ['Exp']:
    for strain in strain_list:
        if strain != 'DGY1657':
            rnaseq_dict_filename = ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
            # pd.read_table handles the file itself; the extra open() handle leaked.
            df = pd.read_table(rnaseq_dict_filename, index_col=0)
            rnaseq_dict = df.to_dict('index')

            for gene in rnaseq_dict:
                # NOTE(review): this only inserts empty dicts into log2fc_dict;
                # nothing here reads them. Kept so log2fc_dict ends up unchanged.
                if gene not in log2fc_dict:
                    log2fc_dict[gene] = {}

                if gene in cn_dict and strain in cn_dict[gene]:
                    row = rnaseq_dict[gene]
                    cn = cn_dict[gene][strain]
                    l2fc = row['log2FoldChange']
                    pval = row['pvalue']
                    padj = row['padj']

                    # Each (gene, strain) pair is visited once, so a plain
                    # assignment is all that is needed.
                    attenuated_dict.setdefault(gene, {})[strain] = {'log2FoldChange':l2fc,
                                                                    'pvalue':pval,
                                                                    'padj':padj,
                                                                    'cn':cn}


# gene_to_cnv_lookup[gene][bucket] -> set of strains, where bucket combines:
#   copy class: 'cnv' if cn > 1, 'cnn' if 0 < cn <= 1
#   direction:  'low' if log2FoldChange < 0, otherwise 'high'
# Strains whose cn is not > 0 (including NaN) go into no bucket.
gene_to_cnv_lookup = {}

for gene, per_strain in attenuated_dict.items():
    buckets = {'cnv_low': set(),
               'cnn_low': set(),
               'cnv_high': set(),
               'cnn_high': set(),
               }
    gene_to_cnv_lookup[gene] = buckets

    for strain, record in per_strain.items():
        cn = record['cn']
        l2fc = record['log2FoldChange']

        # `not (cn > 0)` (rather than `cn <= 0`) also skips NaN copy numbers.
        if not (cn > 0):
            continue

        copy_class = 'cnv' if cn > 1 else 'cnn'
        direction = 'low' if l2fc < 0 else 'high'
        buckets[copy_class + '_' + direction].add(strain)

#attenuated_file =  open('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/DESeq_Attenuated_v3.txt', 'w')

# Classify genes and tally (gene, strain) observations from attenuated_dict and
# gene_to_cnv_lookup. DGY1728 is skipped in this loop. Results:
#   cnv_dependant   - genes with more cn > 1 strains low than high, at least two
#                     low, and fewer 0 < cn <= 1 strains low than high
#   cnv_independant - genes with more 0 < cn <= 1 strains low than high and at
#                     least two of them low
#   cnv_set/cnn_set - genes observed with cn > 1 / cn == 1 in a non-DGY1728 strain
#   fet_dict        - (gene, strain) counts split by high/low and cn > 1 vs not
# ('low'/'high' in gene_to_cnv_lookup mean log2FoldChange < 0 / otherwise.)
# NOTE(review): the gene_to_cnv_lookup counts include DGY1728, which this loop
# otherwise excludes -- confirm that is intended.

cnv_dependant = set()
cnv_independant = set()

cnv_set = set()
cnn_set = set()

fet_dict = {'cnv_low':0,
            'cnv_high':0,
            'cnn_low':0,
            'cnn_high':0}

for gene in attenuated_dict:
    for strain in attenuated_dict[gene]:
        if strain != 'DGY1728':
            cn = attenuated_dict[gene][strain]['cn']
            l2fc = attenuated_dict[gene][strain]['log2FoldChange']
            pval = attenuated_dict[gene][strain]['pvalue']
            padj = attenuated_dict[gene][strain]['padj']
            # Number of strains (DGY1728 included) that have an entry for this gene.
            less_than_n = len(attenuated_dict[gene])
            cnv_n_low = len(gene_to_cnv_lookup[gene]['cnv_low'])
            cnn_n_low= len(gene_to_cnv_lookup[gene]['cnn_low'])
            cnv_n_high = len(gene_to_cnv_lookup[gene]['cnv_high'])
            cnn_n_high = len(gene_to_cnv_lookup[gene]['cnn_high'])
                        
            # NOTE(review): this tally uses `l2fc > 0` for "high", so l2fc == 0
            # or NaN counts as low, whereas gene_to_cnv_lookup counts those as
            # high. The `else` copy branch also counts cn <= 0 or NaN as "cnn",
            # which gene_to_cnv_lookup excludes. Confirm which convention is wanted.
            if l2fc > 0:
                if cn > 1:
                    fet_dict['cnv_high']+=1
                else:
                    fet_dict['cnn_high']+=1
            else:
                if cn > 1:
                    fet_dict['cnv_low']+=1
                else:
                    fet_dict['cnn_low']+=1
            
                        
            # Report row; only consumed by the commented-out attenuated_file write.
            outline = ('{gene}\t{strain}\t{l2fc}\t{pval}\t{padj}\t{cn}\t{less_than_n}'
                       '\t{cnv_n_low}\t{cnn_n_low}\t{cnv_n_high}\t{cnn_n_high}\n').format(
                gene = gene,
                strain = strain,
                l2fc = l2fc,
                pval = pval,
                padj = padj,
                cn = cn,
                less_than_n = less_than_n,
                cnv_n_low = cnv_n_low,
                cnn_n_low = cnn_n_low,
                cnv_n_high = cnv_n_high,
                cnn_n_high = cnn_n_high
                
                )
            
            # Gene-level condition; it depends only on `gene`, so it is re-evaluated
            # (with the same result) for each non-DGY1728 strain.
            if cnv_n_low > cnv_n_high and cnv_n_low > 1 and cnn_n_low < cnn_n_high:
                cnv_dependant.add(gene)
                #attenuated_file.write(outline)
                
            # NOTE(review): cnn_set requires cn == 1 exactly, while the "cnn"
            # buckets above use 0 < cn <= 1, so cnv_independant genes need not be
            # members of cnn_set.
            if cn > 1:
                cnv_set.add(gene)
            if cn == 1:
                cnn_set.add(gene)
                            
            if cnn_n_low > cnn_n_high and cnn_n_low > 1:
                cnv_independant.add(gene)
                
#attenuated_file.close()

# Gene counts per class. These were bare `len(...)` expressions, which are only
# echoed in an interactive console and discarded when run as a script.
print(len(cnv_dependant))
print(len(cnv_independant))

# Gene-level 2x2 Fisher's exact test:
#   row 1: CNV-dependent genes vs the remaining genes in cnv_set
#   row 2: CNV-independent genes vs the remaining genes in cnn_set
# NOTE(review): cnv_independant is not restricted to cnn_set (which requires
# cn == 1), so the row-2 "remaining" count can be understated or even negative,
# in which case fisher_exact raises. Confirm the intended denominator.
data = [[len(cnv_dependant), len(cnv_set)-len(cnv_dependant)],
        [len(cnv_independant), len(cnn_set)-len(cnv_independant)]]
stat, pval = stats.fisher_exact(data)
print('FET_statistic p-value')
print(stat, pval)

# Observation-level counts (one per gene/strain) of low vs high log2FoldChange,
# for cn > 1 (first line) versus all other copy numbers (second line).
print(fet_dict['cnv_low'], fet_dict['cnv_high'])

print(fet_dict['cnn_low'], fet_dict['cnn_high'])

# Observation-level 2x2 Fisher's exact test on the same counts.
data = [[fet_dict['cnv_low'], fet_dict['cnv_high']],
        [fet_dict['cnn_low'], fet_dict['cnn_high']]]
stat, pval = stats.fisher_exact(data)
print('FET_statistic p-value')
print(stat, pval)


#calc essential genes
# Run Fisher's exact test on each hard-coded 2x2 contingency table, in the
# order listed, printing the odds ratio and p-value for each.
# NOTE(review): [[5, 8], [9, 92]] appears twice - confirm whether one of the
# two was meant to hold different counts.
for data in ([[12, 15-12], [42, 111-42]],
             [[14, 22], [17, 322]],
             [[5, 8], [9, 92]],
             [[8, 15], [13, 88]],
             [[5, 8], [9, 92]],
             [[4, 13], [8, 285]],
             [[4, 14], [8, 156]]):
    stat, pval = stats.fisher_exact(data)
    print('FET_statistic p-value')
    print(stat, pval)

        

### Essential gene attenuation
# Collect DESeq statistics ('Exp' comparison, DGY1657 vs. each other strain)
# for a fixed gene set, keyed as attenuated_dict[gene][strain] ->
# {'log2FoldChange', 'pvalue', 'padj', 'cn'}.
attenuated_dict = {}

gene_set = set(["YKL004W","YKL012W","YKL013C","YKL014C","YKL018W","YKL019W","YKL021C","YKR002W","YKR004C","YKR008W","YKR022C","YKR025W","YKR037C","YKR038C","YKR062W","YKR063C","YKR068C","YKR071C","YKR079C","YKR081C","YKR083C","YKR086W","YKL006C-A"])

for istype in ['Exp']:
    for strain in strain_list:
        if strain != 'DGY1657':
            rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
            # read_table opens and closes the file itself; the bare open()
            # that used to sit here was never used or closed.
            df = pd.read_table(rnaseq_dict_filename, index_col=0)
            rnaseq_dict = df.to_dict('index')

            for gene in gene_set:
                # Need a copy-number value for this gene/strain ...
                if gene not in cn_dict or strain not in cn_dict[gene]:
                    continue
                # ... and a row in this strain's DESeq table (previously a
                # missing row raised KeyError and aborted the script).
                if gene not in rnaseq_dict:
                    continue

                strain_records = attenuated_dict.setdefault(gene, {})
                if strain not in strain_records:
                    deseq_row = rnaseq_dict[gene]
                    l2fc = deseq_row['log2FoldChange']
                    pval = deseq_row['pvalue']
                    padj = deseq_row['padj']
                    cn = cn_dict[gene][strain]

                    strain_records[strain] = {'log2FoldChange':l2fc,
                                              'pvalue':pval,
                                              'padj':padj,
                                              'cn':cn}


# For every gene, bucket its strains by copy-number class and by the sign of
# the log2 fold change:
#   'cnv_*' when cn > 1, 'cnn_*' when 0 < cn <= 1
#   '*_low' when log2FoldChange < 0, '*_high' otherwise
# Strains whose cn is not > 0 are left out of every bucket.
gene_to_cnv_lookup = {}

for gene, strain_records in attenuated_dict.items():
    buckets = gene_to_cnv_lookup.setdefault(gene, {'cnv_low': set(),
                                                   'cnn_low': set(),
                                                   'cnv_high': set(),
                                                   'cnn_high': set(),
                                                   })

    for strain, record in strain_records.items():
        cn = record['cn']
        l2fc = record['log2FoldChange']

        if not cn > 0:
            continue

        copy_class = 'cnv' if cn > 1 else 'cnn'
        direction = 'low' if l2fc < 0 else 'high'
        buckets[copy_class + '_' + direction].add(strain)

#attenuated_file =  open('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/DESeq_Attenuated_v3.txt', 'w')

# Classify genes from the strain buckets in gene_to_cnv_lookup:
#   cnv_dependant   - more amplified strains with negative than non-negative
#                     log2 fold change
#   cnv_independant - more copy-neutral strains negative than non-negative,
#                     while amplified strains are more often non-negative
# and tally every (gene, strain) observation by copy class x direction.
cnv_dependant = set()
cnv_independant = set()

cnv_set = set()
cnn_set = set()

fet_dict = dict.fromkeys(('cnv_low', 'cnv_high', 'cnn_low', 'cnn_high'), 0)

for gene in attenuated_dict:
    for strain in attenuated_dict[gene]:
        # DGY1728 is left out of this tally.
        if strain == 'DGY1728':
            continue

        record = attenuated_dict[gene][strain]
        cn = record['cn']
        l2fc = record['log2FoldChange']
        pval = record['pvalue']
        padj = record['padj']
        less_than_n = len(attenuated_dict[gene])

        buckets = gene_to_cnv_lookup[gene]
        cnv_n_low = len(buckets['cnv_low'])
        cnn_n_low = len(buckets['cnn_low'])
        cnv_n_high = len(buckets['cnv_high'])
        cnn_n_high = len(buckets['cnn_high'])

        # NOTE(review): here l2fc == 0 (or NaN) counts as "low" and any
        # cn <= 1 (including 0) as copy-neutral, whereas gene_to_cnv_lookup
        # counts l2fc == 0/NaN as "high" and ignores cn <= 0 - confirm which
        # convention is intended.
        copy_class = 'cnv' if cn > 1 else 'cnn'
        direction = 'high' if l2fc > 0 else 'low'
        fet_dict[copy_class + '_' + direction] += 1

        # Tab-separated summary row; only consumed by the commented-out
        # attenuated_file.write below.
        outline = ('{gene}\t{strain}\t{l2fc}\t{pval}\t{padj}\t{cn}\t{less_than_n}'
                   '\t{cnv_n_low}\t{cnn_n_low}\t{cnv_n_high}\t{cnn_n_high}\n').format(
            gene=gene, strain=strain, l2fc=l2fc, pval=pval, padj=padj, cn=cn,
            less_than_n=less_than_n, cnv_n_low=cnv_n_low, cnn_n_low=cnn_n_low,
            cnv_n_high=cnv_n_high, cnn_n_high=cnn_n_high)

        if cnv_n_low > cnv_n_high:
            #and cnn_n_low < cnn_n_high:
            cnv_dependant.add(gene)
            #attenuated_file.write(outline)

        if cn > 1:
            cnv_set.add(gene)
        if cn == 1:
            cnn_set.add(gene)

        if cnn_n_low > cnn_n_high and cnv_n_low < cnv_n_high:
            cnv_independant.add(gene)

#attenuated_file.close()

print(fet_dict['cnv_low'], fet_dict['cnv_high'])
print(fet_dict['cnn_low'], fet_dict['cnn_high'])

# 2x2 Fisher's exact test: rows = amplified / copy-neutral, columns = low / high.
data = [[fet_dict['cnv_low'], fet_dict['cnv_high']],
        [fet_dict['cnn_low'], fet_dict['cnn_high']]]
stat, pval = stats.fisher_exact(data)
print('FET_statistic p-value')
print(stat, pval)

        
# Fisher's exact test on three hard-coded 2x2 tables, printed in order.
for data in ([[12, 8], [6, 16]],
             [[37, 3], [125, (167-125)]],
             [[7, 32], [318, 1121]]):
    stat, pval = stats.fisher_exact(data)
    print('FET_statistic p-value')
    print(stat, pval)
#DSG_tpm

# Per-replicate TPM table: one row per gene; the lookups below assume columns
# labelled '<strain>.<replicate>'.
rnaseq_results_filename =  ('C:/Gresham/tiny_projects/Project_Grace/total_tpm_df.txt')
# Pass the path so pandas opens *and closes* the file; the previous bare
# open() handle was never closed.
df = pd.read_table(rnaseq_results_filename, index_col=0)
rnaseq_results = df.to_dict('index')


# Median ancestral (DGY1657) TPM over replicates 1-3 for each gene whose ID
# starts with 'Y'. np.median of an empty list yields NaN, so genes with no
# ancestral replicate column get NaN.
ancseq_dict = {}
strain = 'DGY1657'

for gene in rnaseq_results:
    if gene.startswith('Y'):
        temp = []
        for replicate in range(1, 4):
            rep_name = '{}.{}'.format(strain, replicate)
            if rep_name in rnaseq_results[gene]:
                temp.append(rnaseq_results[gene][rep_name])

        ancseq_dict[gene] = np.median(temp)
                

# For every 'Y' gene and every strain in strain_list (the ancestor DGY1657
# included), store each replicate's TPM divided by
# (copy number floored at 1) x (ancestral median TPM). Replicates are skipped
# when that denominator is not positive. Genes amplified (cn > 1) in any
# strain are collected in gene_dsg_set.
rnaseq_dict = {}

gene_dsg_set = set()
for gene in rnaseq_results:
    if gene[0] != 'Y':
        continue

    per_strain = rnaseq_dict.setdefault(gene, {})

    for strain in strain_list:
        ratios = per_strain.setdefault(strain, [])

        cn = max(cn_dict[gene][strain], 1)
        if cn > 1:
            gene_dsg_set.add(gene)

        for replicate in range(1, 4):
            rep_name = '{}.{}'.format(strain, replicate)
            if rep_name not in rnaseq_results[gene]:
                continue

            expected = cn * ancseq_dict[gene]
            if expected > 0:
                rel_cn_tpm = rnaseq_results[gene][rep_name] / expected
                ratios.append(rel_cn_tpm)

# Per-gene Mann-Whitney U test of the copy-normalised TPM ratios between
# amplified (cn >= 2) and copy-neutral (cn == 1) strains. Rows are written to
# tpm_mwu_outfile.tsv. The gene, median-ratio and raw p-value of every tested
# gene are kept in gene_list/diff_list/pval_list, and genes with p <= 0.05
# are kept in tpm_mwu_sig.
# The header was previously built but never written, and it mislabelled the
# raw p-value column as 'adj.pvalue'.
header = ('gene\tratio_median_cnv_cnn\tpvalue\n')

gene_list = []
diff_list = []
pval_list = []

tpm_mwu_sig = {}

with open('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/tpm_mwu_outfile.tsv', 'w') as tpm_mwu_outfile:
    tpm_mwu_outfile.write(header)

    # Sorted so the file and the result lists have a reproducible order
    # (set iteration order of strings varies between runs).
    for gene in sorted(gene_dsg_set):
        cnv_tpm = []
        cnn_tpm = []

        for strain in rnaseq_dict[gene]:
            cn = cn_dict[gene][strain]
            tpm_list = rnaseq_dict[gene][strain]

            # NOTE(review): strains with 1 < cn < 2 or cn < 1 join neither group.
            if cn == 1:
                cnn_tpm += tpm_list
            if cn >= 2:
                cnv_tpm += tpm_list

        if cnv_tpm and cnn_tpm:
            _s, pval = stats.mannwhitneyu(cnv_tpm, cnn_tpm)
            diff = (np.median(cnv_tpm)/np.median(cnn_tpm))

            gene_list.append(gene)
            diff_list.append(diff)
            pval_list.append(pval)

            outline = ('{gene}\t{diff}\t{pval}\n').format(
                gene = gene, diff = diff, pval = pval)
            tpm_mwu_outfile.write(outline)

            if pval <= 0.05:
                tpm_mwu_sig[gene] = diff

# For each gene in tpm_mwu_sig whose ID starts with 'YKR', print its ID, its
# per-strain ratio lists and its median ratio, one per line.
for gene in rnaseq_dict:
    if gene not in tpm_mwu_sig or gene[0:3] != 'YKR':
        continue
    print(gene, rnaseq_dict[gene], tpm_mwu_sig[gene], sep='\n')


#TPM essential enrichment
data = [[6, 25],
        [54, 179]]
stat, pval = stats.fisher_exact(data)
print('FET_statistic p-value')
print(stat, pval)

# FDR correction (statsmodels default method 'indep', i.e. Benjamini-Hochberg)
# of the per-gene Mann-Whitney p-values.
bool_results, adj_pval_list = smt.fdrcorrection(pval_list)

# Number of nominally significant genes before correction. This was a bare
# expression whose value was discarded when run as a script.
print('nominal p <= 0.05:', sum(1 for x in pval_list if x <= 0.05))

# Report genes that stay significant after correction.
for gene, diff, adj_pval in zip(gene_list, diff_list, adj_pval_list):
    if adj_pval <= 0.05:
        print(gene, diff, adj_pval)
        
#         outline = ('{gene}\t{diff}\t{adj_pval}\n').format(
#             gene = gene, diff = diff, adj_pval = adj_pval)
        
#         tpm_mwu_outfile.write(outline)
# tpm_mwu_outfile.close()
            

# Scratch selections kept from interactive use; neither `gene` nor `gene_list`
# below is read by the code that follows.
gene = 'YKL119C'
attenuated_dict ={}

gene_dsg_set = set()

gene_list = (['YKL166C','YKL042W'])

# Load DESeq ('Exp') results for every strain except the ancestor, keyed as
# attenuated_dict[gene][strain] -> stats, and note genes with cn > 1 in any
# strain. (The former `log2fc_dict[gene] = {}` bookkeeping was dropped:
# log2fc_dict is not defined in this part of the script and was never read.)
for istype in ['Exp']:
    for strain in strain_list:
        if strain != 'DGY1657':
            rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
            # read_table opens/closes the file itself (the old bare open()
            # leaked a handle per strain).
            df = pd.read_table(rnaseq_dict_filename, index_col=0)
            rnaseq_dict = df.to_dict('index')

            for gene in rnaseq_dict:
                if gene in cn_dict and strain in cn_dict[gene]:
                    strain_records = attenuated_dict.setdefault(gene, {})

                    if strain not in strain_records:
                        l2fc = rnaseq_dict[gene]['log2FoldChange']
                        pval = rnaseq_dict[gene]['pvalue']
                        padj = rnaseq_dict[gene]['padj']
                        cn = cn_dict[gene][strain]

                        if cn > 1:
                            gene_dsg_set.add(gene)

                        strain_records[strain] = {'log2FoldChange':l2fc,
                                                  'pvalue':pval,
                                                  'padj':padj,
                                                  'cn':cn}


# For each gene amplified somewhere, compare log2 fold changes of amplified
# (cn > 1) vs. copy-neutral (cn == 1) strains with a Mann-Whitney U test.
# Sorted for reproducible output order.
for gene in sorted(gene_dsg_set):
    cnv_l2fc = []
    cnn_l2fc = []

    for record in attenuated_dict[gene].values():
        if record['cn'] == 1:
            cnn_l2fc.append(record['log2FoldChange'])
        if record['cn'] > 1:
            cnv_l2fc.append(record['log2FoldChange'])

    if cnv_l2fc and cnn_l2fc:
        _s, pval = stats.mannwhitneyu(cnv_l2fc, cnn_l2fc)
        cnn_median = np.median(cnn_l2fc)
        cnv_median = np.median(cnv_l2fc)
        # Linear-scale ratio of the median fold changes (2**x is already
        # positive, so the former abs() was redundant).
        diff = 2**(cnv_median - cnn_median)
        print(gene, pval)

        if pval <= 0.05:
            print(gene)
            print(2**cnn_median, 2**cnv_median, diff)

        if diff > 10:
            print(gene)
            print(2**cnn_median, 2**cnv_median, diff)

        
    
#Is expression of the duplicated genes correlated with fitness?

# Reload the per-replicate TPM table (path passed directly so pandas closes the
# file; the previous bare open() handle leaked).
rnaseq_results_filename =  ('C:/Gresham/tiny_projects/Project_Grace/total_tpm_df.txt')
df = pd.read_table(rnaseq_results_filename, index_col=0)
rnaseq_results = df.to_dict('index')

# tpm_per_gene[gene][strain] = median TPM over replicates <strain>.1..3, for
# genes that have a copy-number entry for the strain and a TPM row.
tpm_per_gene = {}

for strain in strain_list:
    for gene in cn_dict:
        if (strain in cn_dict[gene]) and gene in rnaseq_results:
            # Collect all replicates *before* taking the median. Previously
            # the list was reset inside the replicate loop and the value was
            # stored on the first pass, so only replicate 1 was ever used.
            rep_temp = []
            for rep in range(1,4):
                rep_name = ('{strain}.{rep}').format(strain = strain, rep = rep)
                if rep_name in rnaseq_results[gene]:
                    rep_temp.append(rnaseq_results[gene][rep_name])

            strain_tpm = tpm_per_gene.setdefault(gene, {})
            if strain not in strain_tpm:
                # NaN when no replicate column exists (np.median([]) would give
                # NaN too, but with a RuntimeWarning).
                strain_tpm[strain] = np.median(rep_temp) if rep_temp else np.nan
                        
# For each strain other than the ancestor DGY1657, sum the per-gene TPM
# difference (strain minus DGY1657) over all genes in tpm_per_gene, and
# separately over the genes whose copy number in that strain is > 1.
strain_difference_dict = {}
strain_cnv_differ_dict = {}

for strain in strain_list:
    if strain == 'DGY1657':
        continue

    strain_difference_dict.setdefault(strain, 0)
    strain_cnv_differ_dict.setdefault(strain, 0)

    for gene, tpm_by_strain in tpm_per_gene.items():
        diff = tpm_by_strain[strain] - tpm_by_strain['DGY1657']
        strain_difference_dict[strain] += diff
        if cn_dict[gene][strain] > 1:
            strain_cnv_differ_dict[strain] += diff
            
# Write one row per evolved strain with its summed TPM difference over the
# amplified (cn > 1) genes, taken from strain_cnv_differ_dict. The `with`
# block guarantees the file is closed even if a write fails.
with open('C:/Gresham/tiny_projects/Project_Grace/Review_GB/cnv_expression_difference.tsv', 'w') as tpm_exp_outfile:
    header = ('strain\ttpm_diff\n')
    tpm_exp_outfile.write(header)

    for strain in strain_cnv_differ_dict:
        outline = ('{}\t{}\n').format(strain, strain_cnv_differ_dict[strain])
        tpm_exp_outfile.write(outline)