# -*- coding: utf-8 -*-
"""
Created on Tue Jun  6 13:09:44 2023

@author: pspea
"""

import plotly.io as pio
pio.renderers.default = "browser"

import plotly.graph_objects as go

import numpy as np
import pandas as pd 
from scipy import stats

# Strains analysed by this script. 'DGY1657' is used further down as the
# reference strain and is excluded from the CNV-vs-CNN comparisons.
strain_list = ['DGY1657','DGY1728','DGY1734','DGY1736','DGY1740','DGY1744','DGY1747','DGY1751']

# Expression table (file name suggests TPM values - TODO confirm).
# The first column is used as the index; the rest of the script looks up
# columns named '<strain>.<replicate>'.
rnaseq_results_filename = ('C:/Gresham/tiny_projects/Project_Grace/total_tpm_df.txt')
# Read inside a context manager so the handle is closed as soon as pandas has
# parsed the table (previously it stayed open for the life of the process).
with open(rnaseq_results_filename) as rnaseq_results_file:
    df = pd.read_table(rnaseq_results_file, index_col=0)
# Nested mapping: {row index: {column name: value}}.
rnaseq_results = df.to_dict('index')

# Copy-number table, indexed the same way; looked up below as
# cn_dict[gene][strain].
copy_number_filename = ('C:/Gresham/tiny_projects/Project_Grace/relative_depth_DNA_corrected_v3.txt')
df = pd.read_table(copy_number_filename, index_col=0)
cn_dict = df.to_dict('index')


# Reference expression level per gene: the median over whichever of the
# replicate columns 'DGY1657.1' .. 'DGY1657.3' exist for that gene.
# Only genes whose name starts with 'Y' are included.
ancseq_dict = {}
strain = 'DGY1657'

reference_columns = ['{}.{}'.format(strain, rep_number) for rep_number in range(1, 4)]

for gene, per_sample in rnaseq_results.items():
    if gene[0] != 'Y':
        continue

    present_values = [per_sample[column]
                      for column in reference_columns
                      if column in per_sample]

    # With no replicate column present this stores numpy's median of an
    # empty list (nan), exactly as before.
    ancseq_dict[gene] = np.median(present_values)
                

# rnaseq_dict[gene][strain]    -> list of replicate expression values divided
#                                 by (copy number * reference median)
# strain_tpm_abundance[strain] -> raw replicate values that passed the same
#                                 filter, pooled over all genes
# gene_cnv_set / gene_cnn_set  -> genes with copy number > 1 / <= 1 in at
#                                 least one strain (a gene can be in both)
rnaseq_dict = {}

strain_tpm_abundance = {}

gene_cnv_set = set()
gene_cnn_set = set()

for gene, per_sample in rnaseq_results.items():
    if gene[0] != 'Y':
        continue

    ratios_by_strain = rnaseq_dict.setdefault(gene, {})

    for strain in strain_list:
        pooled_abundance = strain_tpm_abundance.setdefault(strain, [])
        strain_ratios = ratios_by_strain.setdefault(strain, [])

        # Copy numbers below 1 are raised to 1 before normalising.
        cn = max(cn_dict[gene][strain], 1)
        (gene_cnv_set if cn > 1 else gene_cnn_set).add(gene)

        replicate_columns = ['{}.{}'.format(strain, rep_number)
                             for rep_number in range(1, 4)]
        observed = [per_sample[column]
                    for column in replicate_columns
                    if column in per_sample]
        if not observed:
            continue

        # Expected expression if it scaled linearly with copy number.
        # Zero or nan expectations are skipped (nothing to divide by).
        expected = cn * ancseq_dict[gene]
        if expected > 0:
            for value in observed:
                strain_ratios.append(value / expected)
                pooled_abundance.append(value)

###
# Diagnostic print-out for a fixed set of YKR genes: for every gene and
# strain, print the rounded median of the normalised expression ratios next
# to the rounded median of the strain's pooled raw expression values.
common_core = {'YKR029C','YKR030W','YKR031C','YKR034W','YKR035W-A','YKR036C','YKR037C','YKR038C','YKR039W','YKR041W','YKR042W','YKR043C','YKR044W','YKR045C','YKR046C','YKR048C','YKR049C'}


def _round_or_na(value):
    """Return *value* rounded to an int, or 'NA' when it is nan or infinite."""
    # round() raises on nan/inf, so non-finite values are reported as 'NA'.
    return round(value) if np.isfinite(value) else 'NA'


# The per-strain median does not depend on the gene, so compute it once.
strain_abundance_median = {}
for strain in strain_list:
    pooled = strain_tpm_abundance.get(strain, [])
    strain_abundance_median[strain] = (
        _round_or_na(np.percentile(pooled, 50)) if len(pooled) else 'NA')

# Sorted so the output order is the same on every run (iteration order of a
# set of strings varies between interpreter sessions).
for gene in sorted(common_core):
    print(gene)
    # A gene absent from the expression table is reported as 'NA' instead of
    # aborting the script with a KeyError.
    ratios_by_strain = rnaseq_dict.get(gene, {})
    for strain in strain_list:
        print(strain)
        ratios = ratios_by_strain.get(strain, [])
        # An empty ratio list used to crash here: the median of [] is nan and
        # round(nan) raises ValueError.
        ratio_median = _round_or_na(np.median(ratios)) if len(ratios) else 'NA'
        print(ratio_median, strain_abundance_median[strain])

# Per-gene Mann-Whitney U test comparing normalised expression ratios from
# strains where the gene is amplified (cn >= 2) against strains where it is
# single copy (cn == 1). The reference strain DGY1657 is excluded.
#
# Outputs:
#   * tpm_mwu_outfile.tsv  - one row per testable gene
#   * gene_list / diff_list / pval_list - the same rows as parallel lists
#   * tpm_mwu_sig          - {gene: median ratio} for genes with p <= 0.05
#
# The p-value written here is the raw test p-value; no multiple-testing
# correction is applied, so the column is labelled 'pvalue'.
header = ('gene\tratio_median_cnv_cnn\tpvalue\n')

gene_list = []
diff_list = []
pval_list = []

tpm_mwu_sig = {}

# The context manager closes the file even if a test raises part-way through.
with open('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/tpm_mwu_outfile.tsv', 'w') as tpm_mwu_outfile:
    # The header was previously built but never written to the file.
    tpm_mwu_outfile.write(header)

    # Sorted so that the row order is reproducible between runs.
    for gene in sorted(gene_cnv_set):
        cnv_tpm = []
        cnn_tpm = []

        for strain, tpm_list in rnaseq_dict[gene].items():
            if strain == 'DGY1657':
                continue

            cn = cn_dict[gene][strain]

            # NOTE(review): strains whose copy number is below 1 or strictly
            # between 1 and 2 fall into neither group, whereas gene_cnv_set
            # above was built with `cn > 1` - confirm this is intended.
            if cn == 1:
                cnn_tpm += tpm_list

            if cn >= 2:
                cnv_tpm += tpm_list

        # The test needs at least one observation on each side.
        if not cnv_tpm or not cnn_tpm:
            continue

        # Two-sided is scipy's documented default; spelled out so the result
        # does not depend on the installed scipy version.
        _s, pval = stats.mannwhitneyu(cnv_tpm, cnn_tpm, alternative='two-sided')
        # NOTE(review): a zero CNN median makes this ratio inf or nan.
        diff = (np.median(cnv_tpm)/np.median(cnn_tpm))

        gene_list.append(gene)
        diff_list.append(diff)
        pval_list.append(pval)

        outline = ('{gene}\t{diff}\t{pval}\n').format(
            gene = gene, diff = diff, pval = pval)

        tpm_mwu_outfile.write(outline)

        if pval <= 0.05:
            tpm_mwu_sig[gene] = diff

# One box plot per significant YKR gene, with one box per non-reference
# strain. Each box is labelled '<strain>_<rounded copy number>' and the
# figure is written to a PDF named after the gene.
for gene, ratios_by_strain in rnaseq_dict.items():
    if gene not in tpm_mwu_sig or gene[0:3] != 'YKR':
        continue

    print(gene)
    print(ratios_by_strain)
    print(tpm_mwu_sig[gene])

    output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/Log2FC_Exp_Boxplot_Core_Gene_{gene}.pdf').format(
        gene = gene)

    exp = []
    strain_cn = []

    # Iterate the per-gene dict directly instead of rebinding the module-level
    # strain_list. Every strain seen here is a key of the dict, so the old
    # "strain missing" fallback could never run.
    for strain, exp_sub in ratios_by_strain.items():
        if strain == 'DGY1657':
            continue
        strain_cn.append(strain + '_' + str(round(cn_dict[gene][strain])))
        exp.append(exp_sub)

    # The axis range always includes 0 and is padded by 10% at both ends.
    plotted_values = [value for exp_sub in exp for value in exp_sub]
    exp_min = min([0, *plotted_values])
    exp_max = max([0, *plotted_values])

    fig = go.Figure()

    for box_name, box_values in zip(strain_cn, exp):
        fig.add_trace(go.Box(y=box_values,
                             name=box_name,
                             boxpoints='all',
                             jitter=0.8,
                             pointpos=-2,
                             quartilemethod="linear"))

    # Applied once per figure, not once per trace, so the title and range are
    # also set when the figure has no traces.
    fig.update_yaxes(range=[exp_min*0.9,exp_max*1.1])
    fig.update_layout(title=gene)

    fig.write_image(output_figure_name)
            
# Binned box plot per significant YKR gene: the non-reference strains are
# pooled into 'CNV' (copy number > 1) and 'CNN' (everything else). Each
# figure is shown with the default renderer and also written to a PDF.
for gene, ratios_by_strain in rnaseq_dict.items():
    if not (gene in tpm_mwu_sig and gene[0:3] == 'YKR'):
        continue

    print(gene)
    print(ratios_by_strain)
    print(tpm_mwu_sig[gene])

    output_figure_name = ('C:/Gresham/tiny_projects/Project_Grace/Review_GB/attenuation/Log2FC_Exp_Boxplot_Core_Gene_binned_{gene}.pdf').format(
        gene = gene)

    strain_list = list(ratios_by_strain)

    CNV_exp = []
    CNN_exp = []

    for strain in strain_list:
        if strain == 'DGY1657':
            continue
        cn = cn_dict[gene][strain]
        pooled = CNV_exp if cn > 1 else CNN_exp
        pooled.extend(ratios_by_strain[strain])

    # The axis range always includes 0 and is padded by 10% at both ends.
    exp_min = min([0, *CNV_exp, *CNN_exp])
    exp_max = max([0, *CNV_exp, *CNN_exp])

    fig = go.Figure()

    for box_name, box_values in (('CNN', CNN_exp), ('CNV', CNV_exp)):
        fig.add_trace(go.Box(y=box_values,
                             name=box_name,
                             boxpoints='all',
                             jitter=0.8,
                             pointpos=-2,
                             quartilemethod="linear"))

    fig.update_yaxes(range=[exp_min*0.9,exp_max*1.1])
    fig.update_layout(title=gene)

    fig.show()
    fig.write_image(output_figure_name)

