I am plotting a Sankey diagram with Plotly to compare different classifications of observations. However, I am having issues when there are more than two classifications: the order of observations within a node changes between that node's inputs and its outputs.
The code I am using is the following:
def pl_sankey(df, label_color, categories, value, title='Sankey Diagram', fname=None, width=3000, height=1600, scale=2):
    """Render a Sankey diagram of flows between consecutive grouping columns.

    Each adjacent pair of columns in ``categories`` is turned into
    source -> target links whose width is the sum of ``value`` over all rows
    sharing that pair. Because links are aggregated per adjacent pair, the
    path of an individual row across three or more columns is not tracked.

    Args:
        df: DataFrame holding the grouping columns and the ``value`` column.
        label_color: Mapping of node label to node colour. Its iteration
            order defines the node order; every value found in the grouping
            columns must (after ``str()``) be one of its keys.
        categories: Ordered column names of the groupings (at least two).
        value: Name of the column holding the counts to sum.
        title: Figure title.
        fname: If truthy, the figure is also written to ``f'{fname}.pdf'``.
        width: Export width passed to the image writer.
        height: Export height passed to the image writer.
        scale: Export scale passed to the image writer.

    Returns:
        An ``IPython.display.Image`` wrapping the PNG rendering of the figure.

    Raises:
        ValueError: If fewer than two categories are given, or a value in a
            grouping column has no entry in ``label_color``.
    """
    from IPython.display import Image
    import plotly.graph_objects as go
    import pandas as pd

    if len(categories) < 2:
        raise ValueError('categories must contain at least two column names')

    # Labels and colours share one ordering: the mapping's iteration order.
    labels = list(label_color)
    colors = list(label_color.values())
    # Precomputed label -> node index, instead of a list scan per row.
    label_to_id = {label: idx for idx, label in enumerate(labels)}

    def node_id(raw):
        """Return the node index for a cell value, matched by its str() form."""
        key = str(raw)
        try:
            return label_to_id[key]
        except KeyError:
            raise ValueError(f'{key!r} is not a key of label_color') from None

    # Stack every adjacent (source, target) column pair into one long frame.
    pair_frames = []
    for src_col, dst_col in zip(categories[:-1], categories[1:]):
        pair = df[[src_col, dst_col, value]].copy()
        pair.columns = ['source', 'target', 'count']
        pair_frames.append(pair)
    st_df = pd.concat(pair_frames)
    # One link per distinct source/target pair, weighted by the summed count.
    st_df = st_df.groupby(['source', 'target']).agg({'count': 'sum'}).reset_index()

    # Translate labels to the integer node indices the Sankey trace expects.
    st_df['sourceID'] = st_df['source'].apply(node_id)
    st_df['targetID'] = st_df['target'].apply(node_id)

    # Sankey trace description.
    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color='black', width=0.5),
            label=labels,
            color=colors,
        ),
        link=dict(source=st_df['sourceID'], target=st_df['targetID'], value=st_df['count']),
    )
    layout = dict(title=title, font=dict(size=16, family='Arial'))

    fig = go.Figure(dict(data=[data], layout=layout))
    if fname:
        fig.write_image(f'{fname}.pdf', format='pdf', width=width, height=height, scale=scale)
    return Image(fig.to_image(format='png', width=width, height=height, scale=scale))
The input parameters are:
- `df`: a pandas DataFrame with groupings for each set of rows, e.g.:
# g1_l1 means group1, label1
g1 g2 g3 counts
0 g1_l1 g2_l1 g3_l1 10
1 g1_l3 g2_l2 g3_l1 1
2 g1_l1 g2_l2 g3_l2 1
3 g1_l2 g2_l2 g3_l1 40
4 g1_l2 g2_l3 g3_l2 20
5 g1_l3 g2_l1 g3_l2 10
- `label_color` is a dictionary, where keys are labels and values are colors
- `categories` are the column names of the groupings, in this case `['g1', 'g2', 'g3']`
- `value` is the column name of the counts, in this case `'counts'`
One example of execution is the following:
# pl_sankey imports pandas only inside its own body, so the example needs
# its own import for the top-level `pd` name.
import pandas as pd

# One row per combination of labels across the three groupings, with its count.
df = pd.DataFrame([
    ['g1_l1', 'g2_l1', 'g3_l1', 10],
    ['g1_l3', 'g2_l2', 'g3_l1', 1],
    ['g1_l1', 'g2_l2', 'g3_l2', 1],
    ['g1_l2', 'g2_l2', 'g3_l1', 40],
    ['g1_l2', 'g2_l3', 'g3_l2', 20],
    ['g1_l3', 'g2_l1', 'g3_l2', 10],
], columns=['g1', 'g2', 'g3', 'counts'])
# Node colour per label; the same three colours are reused in each grouping.
label_color = {
    'g1_l1': '#1f77b4', 'g1_l2': '#ff7f0e', 'g1_l3': '#279e68',
    'g2_l1': '#1f77b4', 'g2_l2': '#ff7f0e', 'g2_l3': '#279e68',
    'g3_l1': '#1f77b4', 'g3_l2': '#ff7f0e',
}
# Every column except the last ('counts') is a grouping column.
pl_sankey(df, label_color, categories=df.columns[:-1], value='counts', title='', fname=None)
However, this code guarantees row matching only between two adjacent columns. Consider for example, row 1:
g1 g2 g3 counts
1 g1_l3 g2_l2 g3_l1 1
This row should start from the green cluster (`g1_l3`) in the first column, land in the orange cluster (`g2_l2`) in the second column, and continue to the blue cluster (`g3_l1`) in the third column. However, this is not respected in the plot above: the links entering a node in the second column are not ordered the same way as the matching links leaving it.
I have attached an annotated plot showing how the observation jumps within the second column (it is second to last among the node's inputs, but last among its outputs):
I would like to follow the path of a row from the first to the last column. Is this possible, and if so, how can it be done with a Sankey diagram?