
I have a DataFrame like this that I'm trying to describe with a Sankey diagram:

import pandas as pd

pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})
    animal  sex     status          count
0   dog     male    wild            8
1   cat     female  domesticated    10
2   cat     female  domesticated    11
3   dog     male    wild            14
4   cat     male    domesticated    6

I'm trying to follow the steps in the documentation but I can't make it work - I can't understand what branches where. Here's the example code:

import plotly.graph_objects as go

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = ["A1", "A2", "B1", "B2", "C1", "C2"],
      color = "blue"
    ),
    link = dict(
      source = [0, 1, 0, 2, 3, 3], 
      target = [2, 3, 3, 4, 4, 5],
      value = [8, 4, 2, 8, 4, 2]
  ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

Here's what I'm trying to achieve: (image not included)

Nicolas Gervais

2 Answers


You can create a Sankey diagram with Plotly as follows:

import pandas as pd
import plotly.graph_objects as go

label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
# cat: 0, dog: 1, domesticated: 2, female: 3, male: 4, wild: 5
source = [0, 0, 1, 3, 4, 4]
target = [3, 4, 4, 2, 2, 5]
count = [21, 6, 22, 21, 6, 22]

fig = go.Figure(data=[go.Sankey(
    node = {"label": label_list},
    link = {"source": source, "target": target, "value": count}
    )])
fig.show()

How it works: The lists source, target and count all have length 6, and the Sankey diagram has 6 arrows. The elements of source and target are indexes into label_list. The first element of source is 0, which means "cat". The first element of target is 3, which means "female". The first element of count is 21. Therefore, the first arrow of the diagram goes from cat to female and has size 21. Likewise, the second elements of source, target and count define the second arrow, and so on.
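To make the index mapping concrete, here is a small sketch that turns the three lists back into readable arrows. It uses no Plotly at all; `describe_links` is a hypothetical helper name introduced only for this illustration.

```python
# Same lists as in the answer above.
label_list = ['cat', 'dog', 'domesticated', 'female', 'male', 'wild']
source = [0, 0, 1, 3, 4, 4]
target = [3, 4, 4, 2, 2, 5]
count = [21, 6, 22, 21, 6, 22]

def describe_links(labels, source, target, value):
    """Return one readable string per arrow (hypothetical helper, not part of Plotly)."""
    return [
        f"{labels[s]} -> {labels[t]}: {v}"
        for s, t, v in zip(source, target, value)
    ]

links = describe_links(label_list, source, target, count)
for line in links:
    print(line)
# cat -> female: 21
# cat -> male: 6
# dog -> male: 22
# female -> domesticated: 21
# male -> domesticated: 6
# male -> wild: 22
```

Printing the links this way is a quick sanity check before passing the lists to go.Sankey: each line should correspond to one arrow you expect to see in the diagram.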


Possibly you want to create a bigger Sankey diagram, as in this example. Defining the source, target and count lists manually then becomes very tedious. So here's some code that creates these lists from a DataFrame in your format.

import pandas as pd
import numpy as np

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

categories = ['animal', 'sex', 'status']

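# Collect one (source, target, count) frame per pair of adjacent category columns.
frames = []
for i in range(len(categories)-1):
    tempDf = df[[categories[i],categories[i+1],'count']].copy()
    tempDf.columns = ['source','target','count']
    frames.append(tempDf)
# Stack the frames and sum the counts of identical source/target pairs.
newDf = pd.concat(frames)
newDf = newDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()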

label_list = list(np.unique(df[categories].values))
source = newDf['source'].apply(lambda x: label_list.index(x))
target = newDf['target'].apply(lambda x: label_list.index(x))
count = newDf['count']
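As a cross-check, the same aggregation can be sketched without pandas. This is not the code above but a plain-Python equivalent under the same assumptions (each row lists its categories in order, followed by a count); `build_links` is a hypothetical helper name. For the question's data it reproduces the manually written source, target and count lists.

```python
from collections import defaultdict

# Rows of the question's DataFrame as plain tuples: (animal, sex, status, count).
rows = [
    ('dog', 'male', 'wild', 8),
    ('cat', 'female', 'domesticated', 10),
    ('cat', 'female', 'domesticated', 11),
    ('dog', 'male', 'wild', 14),
    ('cat', 'male', 'domesticated', 6),
]

def build_links(rows):
    """Sum counts over each pair of adjacent category columns (hypothetical helper)."""
    totals = defaultdict(int)
    for *cats, n in rows:
        # Pair each category with the one in the next column.
        for src, tgt in zip(cats, cats[1:]):
            totals[(src, tgt)] += n
    # Node labels: sorted unique values across all category columns.
    labels = sorted({c for *cats, _ in rows for c in cats})
    index = {name: i for i, name in enumerate(labels)}
    pairs = sorted(totals)
    source = [index[s] for s, _ in pairs]
    target = [index[t] for _, t in pairs]
    value = [totals[p] for p in pairs]
    return labels, source, target, value

labels, source, target, value = build_links(rows)
print(labels)
print(source, target, value)
```

Note that in both versions a value that appears in two different columns ends up as a single node, because labels are deduplicated across all columns. If that is not what you want, one option is to prefix each value with its column name before building the lists.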
Pascalco
  • That was a great answer. I asked another similar [question](https://stackoverflow.com/questions/70335771/how-do-i-make-a-sankey-diagram-with-plotly-with-one-layer-that-goes-only-one-lev), feel free to answer if you can. – Nicolas Gervais Dec 13 '21 at 13:53

I find parallel categories (either px.parallel_categories or go.Parcats) easier to manipulate than go.Sankey, with very similar results.

This example would be:

import pandas as pd
import plotly.graph_objects as go

df = pd.DataFrame({
    'animal': ['dog', 'cat', 'cat', 'dog', 'cat'],
    'sex': ['male', 'female', 'female', 'male', 'male'],
    'status': ['wild', 'domesticated', 'domesticated', 'wild', 'domesticated'],
    'count': [8, 10, 11, 14, 6]
})

fig = go.Figure(go.Parcats(
    dimensions=[
        {'label': 'animal', 'values': df['animal']},
        {'label': 'sex', 'values': df['sex']},
        {'label': 'status', 'values': df['status']},
    ],
    counts=df['count'],
))
fig.show()


Or, if your df contains one row per individual (that is, before aggregating into a count column), this could even be:

import plotly.express as px
px.parallel_categories(df, dimensions=['animal', 'sex', 'status'])
Conchylicultor