import plotly.graph_objects as go
import plotly.io as pio
import pandas as pd

# Cell types, in the left-to-right order in which they appear in the diagram.
cells = ['H1','endoderm']
T = len(cells)
ctmt = ['A','B']  # compartment labels; kept for reference, not used below
chros = ['chr'+str(x) for x in range(1, 23)]

# Build one source/target frame per (chromosome, consecutive cell pair) and
# concatenate them once at the end.  Growing a DataFrame with pd.concat inside
# the loop re-copies everything accumulated so far on every iteration, and
# starting from an empty pd.DataFrame() relies on concat's deprecated handling
# of empty entries.
frames = []
for c in chros:
    cmpt = pd.read_csv(
        'annotated/{}_compartment_annotation.txt'.format(c),
        sep="\t",
        index_col=False,
        usecols=cells,
    )
    # Keep only rows that carry a label for every cell type.
    cmpt = cmpt.dropna()
    for i, (one, two) in enumerate(zip(cells, cells[1:])):
        # Node numbering: cell type i owns nodes 2*i ('A') and 2*i + 1 ('B'),
        # so the next cell type owns 2*i + 2 and 2*i + 3.
        map_idx = i * 2
        map1 = {'A': map_idx, 'B': map_idx + 1}
        map2 = {'A': map_idx + 2, 'B': map_idx + 3}
        frames.append(
            pd.DataFrame({
                'source': cmpt[one].map(map1),
                'target': cmpt[two].map(map2),
            })
        )

if frames:
    df = pd.concat(frames)
else:
    # Fewer than two cell types: no transitions, but keep the expected columns.
    df = pd.DataFrame(columns=['source', 'target'])

# Diagnostic: number of rows leaving each source node.
print(df.groupby('source').size().reset_index(name='counts'))

# Collapse the per-row transitions into one row per (source, target) link.
# .size() is the direct way to count rows per group and matches the call
# above; value_counts() on a groupby whose keys are *all* of the columns has
# nothing left to count over and its behaviour depends on the pandas version.
df = df.groupby(['source', 'target']).size().reset_index(name='counts')
df = df.sort_values(['source', 'target'])
print(df)

# Node attributes: every cell type contributes an 'A' node followed by a 'B'
# node, so node 2*i is 'A' and node 2*i + 1 is 'B' for the i-th cell type.
label = ["A", "B"]*T
colors = ['#33bbee','#ee7733'] *T

# Link endpoints and weights.  The node indices are cast to int because the
# mapped columns become float if any unmapped label produced a NaN upstream.
source = df.source.astype(int).values
target = df.target.astype(int).values
value = df.counts.values

# Link colours, indexed by transition type: A->A, A->B, B->A, B->B.
# Even node indices are 'A' and odd ones are 'B', so the transition type of a
# link is fully determined by the parity of its endpoints.  Deriving the
# colour per link (instead of repeating the palette T-1 times) keeps colours
# correct, and the array the same length as `value`, even when some
# transition never occurs in the data.
link_palette = ['#99ddff','#ee99aa','#6699cc','#ee8866']
colors_link = [
    link_palette[(s % 2) * 2 + (t % 2)]
    for s, t in zip(source, target)
]

# Assemble the Sankey trace from the link and node dictionaries.
link = dict(source = source, target = target, value = value,color=colors_link)
node = dict(label = label,color=colors) #, pad=50, thickness=5)
data = go.Sankey(link = link, node=node)

# Render the figure and export it as a PDF.
fig = go.Figure(data)
fig.write_image("sankey_h1_endoderm.pdf")
