import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import sys
import matplotlib as mpl
# Embed fonts as Type 42 (TrueType) in both PDF and PostScript output instead
# of matplotlib's default Type 3.
# NOTE(review): presumably so text stays editable in vector editors - confirm.
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Load the tab-separated enrichment table.  index_col=False stops pandas from
# using the first column as the row index.
df = pd.read_csv('pan_de_analysis.txt', sep="\t", index_col=False)

# Derived statistics.  NOTE: a p-value of exactly 0 becomes +inf here.
df['-log10(pval)'] = -np.log10(df['pval'].to_numpy())
df['log2(fold enrichment)'] = np.log2(df['(q/M)/(p/N)'])

# `order` below spells pairs with spaces, so replace underscores to match.
# regex=False makes the literal replacement explicit and independent of the
# pandas version's default.
df['pair'] = df['pair'].str.replace('_', ' ', regex=False)

# Short display labels for the categories found in the 'sampled(M)' column.
name_map = {
    'within 100kb of sig db': 'A',
    'within tad with sig db': 'B',
    'within sig dc': 'C',
    'AB': 'A&B',
    'AC': 'A&C',
    'BC': 'B&C',
    'genes overlapping with regions within 100kb of sig db': 'genesA',
    'genes overlapping with regions within sig dc': 'genesC',
    'genes overlapping with regions within tad with sig db': 'genesB',
    'genes overlapping with regions AB': 'genes overlapping with A&B',
    'genes overlapping with regions AC': 'genes overlapping with A&C',
    'genes overlapping with regions BC': 'genes overlapping with B&C',
}

# Labels missing from name_map turn into NaN when mapped.  Report them rather
# than letting the NaN break the boolean filtering below.
raw_labels = df['sampled(M)']
unknown_labels = raw_labels[~raw_labels.isin(list(name_map))].dropna().unique()
if len(unknown_labels):
    print("warning: dropping rows with unrecognised 'sampled(M)' labels: "
          + ", ".join(map(str, unknown_labels)), file=sys.stderr)
df['sampled(M)'] = raw_labels.map(name_map)

# Keep only the gene-level rows that are not a combination ('&') of regions.
# na=False treats unmapped (NaN) labels as non-matching, so they are dropped
# instead of raising when the mask is inverted / applied.
mapped_labels = df['sampled(M)']
keep_mask = (
    mapped_labels.str.contains('genes', regex=False, na=False)
    & ~mapped_labels.str.contains('&', regex=False, na=False)
)
df = df.loc[keep_mask]

# Column order for the heatmap (left to right).
order = [
    'H1 vs endoderm',
    'ES vs NPC',
    'ES vs CN',
    'NPC vs CN',
    'day0 vs day2',
    'day0 vs day5',
    'day0 vs day7',
    'day0 vs day15',
    'day0 vs day80',
    'day2 vs day5',
    'day2 vs day7',
    'day2 vs day15',
    'day2 vs day80',
    'day5 vs day7',
    'day5 vs day15',
    'day5 vs day80',
    'day7 vs day15',
    'day7 vs day80',
    'day15 vs day80',
]

# Wide, short canvas; the heatmap has one row per remaining 'sampled(M)' label.
figure = plt.figure(figsize=(10, 2))

# --- Disabled: fold-enrichment heatmap (kept for reference, never executed) ---
# fold_enrichment = df.pivot(index="sampled(M)",columns="pair",values="(q/M)/(p/N)")
# sns.heatmap(fold_enrichment,cmap="RdBu_r",annot=True,vmin=0.5,vmax=1.5,cbar_kws={'label': 'fold enrichment ratio'},center=1,fmt=".2f")
# #sns.scatterplot(data=df,x="pair",y="sampled(M)",palette="RdBu_r",hue="(q/M)/(p/N)",size="(q/M)/(p/N)",sizes=(20,200),hue_norm=(0.5,2),size_norm=(0.5,2))
# plt.xticks(rotation=45,ha='right')
# plt.tight_layout()
# plt.savefig('pan_de_analysis_fold_enrichment.pdf')
# plt.clf()
# -----------------------------------------------------------------------------

# Reshape to labels (rows) x pairs (columns) and arrange the columns as listed
# in `order`; a pair missing from the data raises KeyError here.
neglogpval = df.pivot(
    index="sampled(M)",
    columns="pair",
    values="-log10(pval)",
).loc[:, order]

# Diverging colormap centred on -log10(pval) = 5, colour scale capped at 200;
# each cell is annotated with its value to one decimal place.
heatmap_options = {
    'cmap': "RdBu_r",
    'annot': True,
    'fmt': ".1f",
    'cbar_kws': {'label': "-log10(pval)"},
    'center': 5,
    'vmax': 200,
}
sns.heatmap(neglogpval, **heatmap_options)
#sns.scatterplot(data=df,x="pair",y="sampled(M)",palette="Greens",hue='-log10(pval)',size='-log10(pval)',sizes=(20,200))

# Slant the pair labels so they do not overlap, then write the figure out.
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('pan_de_analysis_genes_neg_log_pval.pdf')
