import numpy as np
import pandas as pd
    
from scipy import stats

#import plotly.graph_objects as go

import plotly.io as pio
# Make fig.show() open figures in a web browser instead of an inline renderer.
pio.renderers.default = 'browser'


import plotly.graph_objects as go

# Relative copy-number table: the index (first column) holds gene names and the
# remaining columns are strains, so cn_dict[gene][strain] -> copy-number value.
copy_number_filename = ('C:/Gresham/tiny_projects/Project_Grace/relative_depth_DNA_corrected_v3.txt')
df = pd.read_table(copy_number_filename, index_col=0)
cn_dict = df.to_dict('index')

# Explicit list of strains of interest.
strain_list = ['DGY1728', 'DGY1734', 'DGY1736', 'DGY1740', 'DGY1744', 'DGY1747', 'DGY1751']
                      
# Set of essential ORF identifiers, read from a text file with one ORF per line.
essential_filename = ('C:/Gresham/tiny_projects/Project_Grace/insertions/Essential_ORFs.txt')

# Use a context manager so the file handle is closed after reading.
with open(essential_filename) as essential_file:
    # Strip surrounding whitespace and skip blank lines so '' never enters the set.
    essential_set = {line.strip() for line in essential_file if line.strip()}

# Per-gene insertion table: the index holds gene names and the columns are
# strains, so deseq_results[gene][strain] -> insertion value (may be NaN).
deseq_results_filename = ('C:/Gresham/tiny_projects/Project_Grace/insertions/global_normalized_insertionPerGene.txt')
# Alternative input:
#deseq_results_filename = ('C:/Gresham/tiny_projects/Project_Grace/insertions/global_both_eu_normalized_insertionPerGene.txt')
df = pd.read_table(deseq_results_filename, index_col=0)
deseq_results = df.to_dict('index')

complete_gene_list = set(deseq_results)

# Strain names come from the table's columns. Reading them from the columns
# does not depend on any particular gene being present in the table.
insert_strain_list = sorted(df.columns)

# insert_dict[gene][strain] -> insertion value, with NaN entries omitted.
insert_dict = {}
for gene in complete_gene_list:
    gene_row = deseq_results[gene]
    insert_dict[gene] = {
        strain: gene_row[strain]
        for strain in insert_strain_list
        if not np.isnan(gene_row[strain])
    }
                    
# Bin each gene's insertion value by strain and by essentiality
# (istype is 'essential' if the gene is in essential_set, else 'non'):
#   sorted_genes_dict[strain][istype][gene]         -> genes with copy number != 1
#                                                      (every gene for strains whose
#                                                      name contains 'DGY1657', which
#                                                      are not split by copy number)
#   non_cnv_sorted_genes_dict[strain][istype][gene] -> genes with copy number == 1
sorted_genes_dict = {}
non_cnv_sorted_genes_dict = {}

for gene, gene_inserts in insert_dict.items():
    istype = 'essential' if gene in essential_set else 'non'
    gene_copy_numbers = cn_dict.get(gene, {})

    for strain in insert_strain_list:
        cnv_bins = sorted_genes_dict.setdefault(strain, {'essential': {}, 'non': {}})
        cnn_bins = non_cnv_sorted_genes_dict.setdefault(strain, {'essential': {}, 'non': {}})

        if strain not in gene_inserts:
            continue
        value = gene_inserts[strain]

        if 'DGY1657' in strain:
            cnv_bins[istype][gene] = value
            continue

        # A copy number that is missing from the table, or is NaN, cannot be
        # classified. Such genes are skipped. Otherwise they would raise
        # KeyError, or (because NaN != 1) be silently counted as CNVs and
        # poison the later division and statistics with NaN.
        copy_number = gene_copy_numbers.get(strain)
        if copy_number is None or pd.isna(copy_number):
            continue

        if copy_number != 1:
            cnv_bins[istype][gene] = value
        else:
            cnn_bins[istype][gene] = value


output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/Fig2B_Insertion_Essential_CNV_versus_CNV_Boxplot.pdf')


def _copy_number_scaled(genes_to_values, strain):
    """Return each gene's insertion value divided by its copy number in *strain*."""
    return [value / cn_dict[gene][strain] for gene, value in genes_to_values.items()]


# For each evolved strain, compare copy-number-scaled insertion values of
# essential genes in CNVs ('cnv') against essential genes at copy number 1
# ('cnn'). Print the strain, the ratio of group means, and the Mann-Whitney
# U p-value.
ess_dict = {}

for strain in sorted_genes_dict:
    # Strains whose name contains 'DGY1657' are not split by copy number when
    # binned (same substring test), so there is no CNV vs non-CNV comparison.
    if 'DGY1657' in strain:
        continue

    non_strain_cnv = _copy_number_scaled(sorted_genes_dict[strain]['non'], strain)
    ess_strain_cnv = _copy_number_scaled(sorted_genes_dict[strain]['essential'], strain)
    non_strain_cnn = _copy_number_scaled(non_cnv_sorted_genes_dict[strain]['non'], strain)
    ess_strain_cnn = _copy_number_scaled(non_cnv_sorted_genes_dict[strain]['essential'], strain)

    ess_dict[strain] = {'cnv': ess_strain_cnv, 'cnn': ess_strain_cnn}

    # Mann-Whitney and the mean ratio are undefined when a group is empty.
    if not ess_strain_cnv or not ess_strain_cnn:
        print('{strain}\tskipped: an essential-gene group is empty'.format(strain=strain))
        continue

    # Two-sided test stated explicitly: older scipy defaulted to a one-sided p-value.
    _w, pval = stats.mannwhitneyu(ess_strain_cnv, ess_strain_cnn, alternative='two-sided')

    outline = ('{strain}\t{mean}\t{pval}\n').format(strain=strain,
                                                    mean=(np.mean(ess_strain_cnv) / np.mean(ess_strain_cnn)),
                                                    pval=pval)
    print(outline)
        
# Draw one box (with all points overlaid) per strain and group. The 'cnn' box
# comes before the 'cnv' box. The y values are the log2 of the entries stored
# in ess_dict. Strains follow insert_strain_list order; strains absent from
# ess_dict are not drawn.
fig = go.Figure()

plotted_strains = [s for s in insert_strain_list if s in ess_dict]
for strain_name in plotted_strains:
    per_group = ess_dict[strain_name]
    for group in ('cnn', 'cnv'):
        box = go.Box(
            y=np.log2(per_group[group]),
            name=strain_name + '_' + group,
            boxpoints='all',
            jitter=0.5,
            pointpos=-2,
            marker_size=2,
            line_width=1,
            quartilemethod="linear",
        )
        fig.add_trace(box)

fig.show()

fig.write_image(output_figure_name)
    

from scipy.stats import mannwhitneyu
    
# Within each strain, compare insertion values of non-essential versus
# essential genes with a two-sided Mann-Whitney U test. This is done twice:
# once for the genes binned in sorted_genes_dict (copy number != 1, or every
# gene for 'DGY1657' strains), and once for non_cnv_sorted_genes_dict
# (copy number == 1).
for label, bins_by_strain in (
        ('copy-number-altered genes (all genes for DGY1657 strains)', sorted_genes_dict),
        ('copy-number-neutral genes', non_cnv_sorted_genes_dict)):
    print(label)
    for strain in bins_by_strain:
        # The bins map gene -> value; the test needs the values, not the dicts.
        non = list(bins_by_strain[strain]['non'].values())
        essential = list(bins_by_strain[strain]['essential'].values())

        # e.g. 'DGY1657' strains never receive copy-number-neutral entries.
        if not non or not essential:
            print(strain, 'skipped: empty group')
            continue

        U1, p = mannwhitneyu(non, essential, alternative='two-sided')

        print(strain, p)