# -*- coding: utf-8 -*-
"""
Created on Wed Nov  2 12:05:39 2022

@author: pspea
"""
import pandas as pd
import numpy as np
from scipy import stats

import plotly.io as pio
pio.renderers.default = "browser"


import plotly.graph_objects as go
import statsmodels.stats.multitest as smt


#output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/figures/CNV_map.pdf')

# Copy-number table: the first column becomes the row label, and the table is
# converted to nested dicts that the rest of the script reads as
# cn_dict[gene][strain].
copy_number_filename = 'C:/Gresham/tiny_projects/Project_Grace/relative_depth_DNA_corrected_v3.txt'
df = pd.read_table(copy_number_filename, index_col=0)
cn_dict = df.to_dict('index')

# Column labels of one reference row, in sorted order ...
strain_list = sorted(cn_dict['YKR039W'])

# ... immediately replaced by the fixed set of strains analysed below.
strain_list = ['DGY1657','DGY1728','DGY1734','DGY1736','DGY1740','DGY1744','DGY1747','DGY1751']
# The gene list must only contain genes that are detected in every strain;
# otherwise genes would be misaligned between strains on the global heatmap.

# gene -> set of strains whose DESeq table contains a row for that gene
gene_count_strain = {}

for istype in ['Obs']:
    for strain in strain_list:
        # DGY1657 appears in every comparison file name and has no table of
        # its own (presumably it is the ancestor/reference strain).
        if strain == 'DGY1657':
            continue

        rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
        # pandas opens and closes the file itself when given a path; the
        # separate open() call that used to sit here leaked one unused file
        # handle per strain.
        df = pd.read_table(rnaseq_dict_filename, index_col=0)
        rnaseq_dict = df.to_dict('index')

        for gene in rnaseq_dict:
            # Only row labels starting with 'Y' are counted; startswith()
            # also copes with an empty label, where gene[0] would raise.
            if gene.startswith('Y'):
                gene_count_strain.setdefault(gene, set()).add(strain)
                    
                    
# Genes that have a row in the DESeq table of every compared strain.
complete_gene_list = set()
# Largest number of strains in which any single gene was seen.
strain_max = 0

# Number of strains that have a DESeq table: every entry of strain_list except
# DGY1657. Derived from strain_list instead of the previous hard-coded 7 so
# the threshold cannot drift out of sync with the strain list.
required_strain_count = sum(1 for each_strain in strain_list if each_strain != 'DGY1657')

for gene in gene_count_strain:
    n_detected = len(gene_count_strain[gene])
    strain_max = max(strain_max, n_detected)

    if n_detected >= required_strain_count:
        complete_gene_list.add(gene)
    else:
        # Report genes that are missing from at least one strain.
        print(gene)
                         



rnaseq_results_filename =  ('C:/Gresham/tiny_projects/Project_Grace/total_tpm_df.txt')
# The context manager closes the handle as soon as the table is parsed; it was
# previously left open for the rest of the run.
with open(rnaseq_results_filename) as rnaseq_results_file:
    df = pd.read_table(rnaseq_results_file, index_col=0)
rnaseq_results = df.to_dict('index')

# gene -> strain -> median over the replicate columns '<strain>.1' .. '<strain>.3'
# (presumably TPM values, going by the file name - confirm).
rnaseq_dict = {}

for gene in rnaseq_results:
    # Only row labels starting with 'Y' are kept.
    if not gene.startswith('Y'):
        continue

    row = rnaseq_results[gene]
    per_strain = rnaseq_dict.setdefault(gene, {})

    for strain in strain_list:
        # Collect whichever of the three replicate columns exist for this strain.
        temp = []
        for replicate in range(1,4):
            rep_name = ('{}.{}').format(strain,replicate)

            if rep_name in row:
                temp.append(row[rep_name])

        # With no replicate columns the median is NaN, exactly as before, but
        # without triggering numpy's empty-slice RuntimeWarning. np.float64 is
        # used so the stored type matches what np.median returns.
        median_value = np.median(temp) if temp else np.float64('nan')
        per_strain[strain] = per_strain.get(strain, 0) + median_value
            

# Genes processed one at a time by the loop below.
core_list = ["YKR029C","YKR030W","YKR031C","YKR034W","YKR035W-A","YKR036C","YKR037C","YKR038C","YKR039W","YKR041W","YKR042W","YKR043C","YKR044W","YKR045C","YKR046C","YKR048C","YKR049C"] 

# Alternative gene set, kept for reference:
# core_list = ["YER133W","YER134C","YER136W","YER137C"]
# #,"YER137C-A","YER138C"

# Running extremes over every value placed in a figure.
exp_min = np.inf
exp_max = -np.inf

# Fixed y-axis limits used for every figure (alternative pair: -1.6 / 1.2).
vis_min = -0.6
vis_max = 2.7

# Initialised empty; nothing in this section appends to them.
strain_median = {}
gene_median = {}

# gene -> strain -> log2(evolved / (ancestor * copy number)).
# NOTE(review): this dict is created once, outside the loop, so the per-strain
# lists built in each iteration contain every core gene processed so far, not
# just the current one - confirm that the cumulative behaviour is intended.
log2fc_dict = {}

for core_gene in core_list:

    core_set = {core_gene}

    output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/Review_GB/core_set_expression/Log2FC_Exp_Boxplot_Core_Gene_{gene}.pdf').format(
        gene = core_gene)

    for istype in ['Obs']:
        for strain in strain_list:
            # DGY1657 is the denominator of every ratio, so it is never compared with itself.
            if strain == 'DGY1657':
                continue

            for gene in core_set:
                if gene not in rnaseq_dict:
                    continue

                if gene not in log2fc_dict:
                    log2fc_dict[gene] = {}

                # A copy-number entry is required for this gene/strain pair.
                if gene not in cn_dict or strain not in cn_dict[gene]:
                    continue

                cn = cn_dict[gene][strain]
                evo_value = rnaseq_dict[gene][strain]
                anc_value = rnaseq_dict[gene]['DGY1657']

                # Avoid dividing by a zero ancestral value.
                # NOTE(review): a zero evolved value or zero copy number is not
                # guarded and would yield -inf / inf / nan here.
                if anc_value == 0:
                    anc_value = 1

                value = np.log2(evo_value/(anc_value*cn))
                log2fc_dict[gene][strain] = value

    # Split the accumulated values per strain by whether the copy number equals 1.
    cnv_genes_dict = {}
    noncnv_genes_dict = {}

    for gene, per_strain in log2fc_dict.items():
        for strain in strain_list:
            if strain not in per_strain:
                continue

            target = cnv_genes_dict if cn_dict[gene][strain] != 1 else noncnv_genes_dict
            target.setdefault(strain, []).append(per_strain[strain])

    # One-sample t-test of each strain's copy-number-variant values against 0.
    for strain, log2FC in cnv_genes_dict.items():
        t, p = stats.ttest_1samp(log2FC, 0)

        print(strain, t, p)

    # One list of values per strain, in strain_list order; strains without
    # any copy-number-variant value get a single NaN placeholder.
    exp = []

    for strain in strain_list:
        exp_sub = cnv_genes_dict.get(strain)

        if exp_sub is None:
            exp_sub = [np.nan]
        else:
            exp_min = min([exp_min, *exp_sub])
            exp_max = max([exp_max, *exp_sub])

        exp.append(exp_sub)

    fig = go.Figure()

    for strain_name, exp_values in zip(strain_list, exp):
        fig.add_trace(go.Box(y=exp_values,
                             name=strain_name,
                             #marker_color = 'indianred',
                             boxpoints='all',
                             jitter=0.8,
                             pointpos=-2,
                             quartilemethod="linear"))

    # The axis range is the same for every trace, so it is applied once.
    fig.update_yaxes(range=[vis_min,vis_max])

    #fig.show()
    #fig.write_image(output_figure_name)
    
    
###

# core_list = ["YER133W","YER134C","YER136W","YER137C"]
#,"YER137C-A","YER138C"

# NOTE(review): this section appears to repeat the preceding core-gene section
# with the same inputs; only the output path differs, and the figure is not
# written because the write_image call is commented out. Confirm whether both
# passes are needed.

# Running extremes over every value placed in a figure.
exp_min = np.inf
exp_max = -np.inf

# Fixed y-axis limits used for every figure (alternative pair: -1.6 / 1.2).
vis_min = -0.6
vis_max = 2.7

# Initialised empty; nothing in this section appends to them.
strain_median = {}
gene_median = {}

# gene -> strain -> log2(evolved / (ancestor * copy number)).
# Created once, outside the loop, so each iteration's per-strain lists hold
# every core gene processed so far.
log2fc_dict = {}

for core_gene in core_list:

    core_set = {core_gene}

    output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/Log2FC_tpm_Boxplot_Core_Gene_{gene}.pdf').format(
        gene = core_gene)

    for istype in ['Obs']:
        for strain in strain_list:
            # DGY1657 is the denominator of every ratio, so it is never compared with itself.
            if strain == 'DGY1657':
                continue

            for gene in core_set:
                if gene not in rnaseq_dict:
                    continue

                if gene not in log2fc_dict:
                    log2fc_dict[gene] = {}

                # A copy-number entry is required for this gene/strain pair.
                if gene not in cn_dict or strain not in cn_dict[gene]:
                    continue

                cn = cn_dict[gene][strain]
                evo_value = rnaseq_dict[gene][strain]
                anc_value = rnaseq_dict[gene]['DGY1657']

                # Avoid dividing by a zero ancestral value.
                # NOTE(review): a zero evolved value or zero copy number is not
                # guarded and would yield -inf / inf / nan here.
                if anc_value == 0:
                    anc_value = 1

                value = np.log2(evo_value/(anc_value*cn))
                log2fc_dict[gene][strain] = value

    # Split the accumulated values per strain by whether the copy number equals 1.
    cnv_genes_dict = {}
    noncnv_genes_dict = {}

    for gene, per_strain in log2fc_dict.items():
        for strain in strain_list:
            if strain not in per_strain:
                continue

            target = cnv_genes_dict if cn_dict[gene][strain] != 1 else noncnv_genes_dict
            target.setdefault(strain, []).append(per_strain[strain])

    # One-sample t-test of each strain's copy-number-variant values against 0.
    for strain, log2FC in cnv_genes_dict.items():
        t, p = stats.ttest_1samp(log2FC, 0)

        print(strain, t, p)

    # One list of values per strain, in strain_list order; strains without
    # any copy-number-variant value get a single NaN placeholder.
    exp = []

    for strain in strain_list:
        exp_sub = cnv_genes_dict.get(strain)

        if exp_sub is None:
            exp_sub = [np.nan]
        else:
            exp_min = min([exp_min, *exp_sub])
            exp_max = max([exp_max, *exp_sub])

        exp.append(exp_sub)

    fig = go.Figure()

    for strain_name, exp_values in zip(strain_list, exp):
        fig.add_trace(go.Box(y=exp_values,
                             name=strain_name,
                             #marker_color = 'indianred',
                             boxpoints='all',
                             jitter=0.8,
                             pointpos=-2,
                             quartilemethod="linear"))

    # The axis range is the same for every trace, so it is applied once.
    fig.update_yaxes(range=[vis_min,vis_max])

    #fig.show()
    #fig.write_image(output_figure_name)
    
# Per-strain summary: median and standard deviation of the collected log2
# values, followed by 2 raised to each of those statistics.
for strain, strain_values in strain_median.items():
    log2_centre = np.median(strain_values)
    log2_spread = np.std(strain_values)

    outline = ('{strain}\n'
               'log2_median:{log2_median}\tlog2_std:{log2_std}\n'
               'median:{median}\tstd:{std}'
               ).format(
        strain = strain,
        log2_median = round(log2_centre,2),
        log2_std = round(log2_spread,2),
        median = round(2**log2_centre,2),
        std = round(2**log2_spread,2)
        )
    print(outline)

# The same summary keyed by gene; the label is still passed to the template
# under its 'strain' placeholder.
for gene, gene_values in gene_median.items():
    log2_centre = np.median(gene_values)
    log2_spread = np.std(gene_values)

    outline = ('{strain}\n'
               'log2_median:{log2_median}\tlog2_std:{log2_std}\n'
               'median:{median}\tstd:{std}'
               ).format(
        strain = gene,
        log2_median = round(log2_centre,2),
        log2_std = round(log2_spread,2),
        median = round(2**log2_centre,2),
        std = round(2**log2_spread,2)
        )
    print(outline)

# Compact tab-separated variant of the per-gene summary (log2 statistics only).
for gene, gene_values in gene_median.items():
    outline = ('{strain}\t'
               '{log2_median} _ \t{log2_std}\n'
               ).format(
        strain = gene,
        log2_median = round(np.median(gene_values),2),
        log2_std = round(np.std(gene_values),2),
        )
    print(outline)
    
# Print the DESeq log2 fold change and p-value of every core gene in every
# compared strain, in gene-then-strain order.

# (istype, strain) -> DESeq table as nested dicts. Each table is parsed the
# first time it is needed and then reused, instead of being re-read from disk
# once per core gene.
deseq_tables = {}

for core_gene in core_list:
    core_set = set([core_gene])

    for istype in ['Obs']:
        for strain in strain_list:
            # DGY1657 is the reference of every comparison and has no table of its own.
            if strain == 'DGY1657':
                continue

            table_key = (istype, strain)
            if table_key not in deseq_tables:
                rnaseq_dict_filename =  ('C:/Gresham/tiny_projects/Project_Grace/DESeq_{}_DGY1657_{}.txt').format(istype, strain)
                # pandas opens and closes the file itself when given a path;
                # the separate open() call that used to sit here leaked an
                # unused file handle on every iteration.
                df = pd.read_table(rnaseq_dict_filename, index_col=0)
                deseq_tables[table_key] = df.to_dict('index')

            rnaseq_dict = deseq_tables[table_key]

            for gene in core_set:
                if gene not in rnaseq_dict:
                    continue

                # Keep registering the gene in log2fc_dict, as before.
                log2fc_dict.setdefault(gene, {})

                # Only genes with a copy-number entry for this strain are reported.
                if gene in cn_dict and strain in cn_dict[gene]:
                    # The comparison is False for a NaN p-value, so rows
                    # without a usable p-value are skipped.
                    if rnaseq_dict[gene]['pvalue'] <= 1:
                        print(strain, gene, rnaseq_dict[gene]['log2FoldChange'], rnaseq_dict[gene]['pvalue'])
                                    
