My seaborn plot is shown below. Is there a way to add the value from the flag column (which will always be a single character or an empty string) at the center (or top) of each bar? Ideally the answer would not require redoing the plot.

This answer seems to have some pointers, but I am not sure how to connect it back to the original dataframe to pull the info from the flag column.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame([
    ['C', 'G1', 'gbt',    'auc', 0.7999, "†"],
    ['C', 'G1', 'gbtv2',  'auc', 0.8199, "*"],
    ['C', 'G1', 'gbt',  'pr@2%', 0.0883, "*"],
    ['C', 'G1', 'gbt', 'pr@10%', 0.0430,  ""],
    ['C', 'G2', 'gbt',    'auc', 0.7554,  ""],
    ['C', 'G2', 'gbt',  'pr@2%', 0.0842,  ""],
    ['C', 'G2', 'gbt', 'pr@10%', 0.0572,  ""],
    ['C', 'G3', 'gbt',    'auc', 0.7442,  ""],
    ['C', 'G3', 'gbt',  'pr@2%', 0.0894,  ""],
    ['C', 'G3', 'gbt', 'pr@10%', 0.0736,  ""],
    ['E', 'G1', 'gbt',    'auc', 0.7988,  ""],
    ['E', 'G1', 'gbt',  'pr@2%', 0.0810,  ""],
    ['E', 'G1', 'gbt', 'pr@10%', 0.0354,  ""],
    ['E', 'G1', 'gbtv3','pr@10%',0.0454,  ""],
    ['E', 'G2', 'gbt',    'auc', 0.7296,  ""],
    ['E', 'G2', 'gbt',  'pr@2%', 0.1071,  ""],
    ['E', 'G2', 'gbt', 'pr@10%', 0.0528,  "†"],
    ['E', 'G3', 'gbt',    'auc', 0.6958,  ""],
    ['E', 'G3', 'gbt',  'pr@2%', 0.1007,  ""],
    ['E', 'G3', 'gbt', 'pr@10%', 0.0536,  "†"],
  ], columns=["src","grp","model","metric","val","flag"])

cat = sns.catplot(data=df, x="grp", y="val", hue="model", kind="bar", sharey=False, 
            col="metric", row="src")
plt.show()
Trenton McKinney
ironv

1 Answer

  • For each axes, and for each container within that axes, the corresponding rows of the dataframe must be selected. For example:
    • The first facet has src = C and metric = auc, and the facet contains 3 containers, corresponding to the unique values of 'model'.
  • The labels= parameter of .bar_label expects a list with one entry per tick on the x-axis, even where no bar exists at that position.
    • The list comprehension labels = [...] puts each flag at the correct index, and fills missing labels with ''.
  • Tested in python 3.11.2, pandas 2.0.0, matplotlib 3.7.1, seaborn 0.12.2
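To see the labels= behaviour in isolation, here is a minimal sketch on a plain matplotlib bar chart rather than the seaborn facets. The three bar heights and flags are taken from the first facet of the question's data (rounded), and the Agg backend is only an assumption so that the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# one bar per x-axis tick; ax.bar returns a BarContainer
bars = ax.bar(["G1", "G2", "G3"], [0.80, 0.76, 0.74])

# one label per bar, in tick order; '' leaves that bar unlabelled
flags = ["†", "", ""]
texts = ax.bar_label(bars, labels=flags, label_type="edge")

# leave some headroom so the label above the tallest bar is not clipped
ax.margins(y=0.2)
```

bar_label returns the text objects it created, so the labels can be inspected or restyled afterwards. Passing label_type='center' instead places the text in the middle of each bar.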
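import pandas as pd
import seaborn as sns
import numpy as np
from IPython.display import display  # display() is built in to Jupyter; the import is only needed elsewhere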

# plot the dataframe from the OP
g = sns.catplot(data=df, x="grp", y="val", hue="model", kind="bar", sharey=False, col="metric", row="src")

# get the unique values from the grp column, which corresponds to the x-axis tick labels
grp_unique = df.grp.unique()

# iterate through axes
for ax in g.axes.flat:
    
    # get the components of the title to filter the current data
    src, metric = [s.split(' = ')[1] for s in ax.get_title().split(' | ')]
    
    # iterate through the containers of the current axes
    for c in ax.containers:
        
        # get the hue label of the current container
        model = c.get_label()
        
        # filter the corresponding data
        data = df.loc[df.src.eq(src) & df.metric.eq(metric) & df.model.eq(model)]
        
        # if the DataFrame, data, isn't empty (i.e. there are bars for the current model)
        if not data.empty:
            
            # for each grp on the x-axis, get the corresponding bar height (value, nan, 0)
            # this is to show the corresponding data and labels - this can be removed
            gh = {grp: (h := v.get_height(), data.loc[data.grp.eq(grp), 'flag'].tolist()[0] if not np.isnan(h) else '') for v, grp in zip(c, grp_unique)}

            # custom labels from the flag column
            labels = [data.loc[data.grp.eq(grp), 'flag'].tolist()[0] if not np.isnan(v.get_height()) else '' for v, grp in zip(c, grp_unique)]

            # shows the different data being used - can be removed
            print(src, metric, model)
            display(data)
            print(gh)
            print(labels)
            print('\n')
            
            # add the labels
            ax.bar_label(c, labels=labels, label_type='edge')
    ax.margins(y=0.2)


Printed Output

C auc gbt
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
|  0 | C     | G1    | gbt     | auc      | 0.7999 | †      |
|  4 | C     | G2    | gbt     | auc      | 0.7554 |        |
|  7 | C     | G3    | gbt     | auc      | 0.7442 |        |
{'G1': (0.7999, '†'), 'G2': (0.7554, ''), 'G3': (0.7442, '')}
['†', '', '']


C auc gbtv2
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
|  1 | C     | G1    | gbtv2   | auc      | 0.8199 | *      |
{'G1': (0.8199, '*'), 'G2': (nan, ''), 'G3': (nan, '')}
['*', '', '']


C pr@2% gbt
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
|  2 | C     | G1    | gbt     | pr@2%    | 0.0883 | *      |
|  5 | C     | G2    | gbt     | pr@2%    | 0.0842 |        |
|  8 | C     | G3    | gbt     | pr@2%    | 0.0894 |        |
{'G1': (0.0883, '*'), 'G2': (0.0842, ''), 'G3': (0.0894, '')}
['*', '', '']


C pr@10% gbt
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
|  3 | C     | G1    | gbt     | pr@10%   | 0.043  |        |
|  6 | C     | G2    | gbt     | pr@10%   | 0.0572 |        |
|  9 | C     | G3    | gbt     | pr@10%   | 0.0736 |        |
{'G1': (0.043, ''), 'G2': (0.0572, ''), 'G3': (0.0736, '')}
['', '', '']


E auc gbt
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
| 10 | E     | G1    | gbt     | auc      | 0.7988 |        |
| 14 | E     | G2    | gbt     | auc      | 0.7296 |        |
| 17 | E     | G3    | gbt     | auc      | 0.6958 |        |
{'G1': (0.7988, ''), 'G2': (0.7296, ''), 'G3': (0.6958, '')}
['', '', '']


E pr@2% gbt
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
| 11 | E     | G1    | gbt     | pr@2%    | 0.081  |        |
| 15 | E     | G2    | gbt     | pr@2%    | 0.1071 |        |
| 18 | E     | G3    | gbt     | pr@2%    | 0.1007 |        |
{'G1': (0.081, ''), 'G2': (0.1071, ''), 'G3': (0.1007, '')}
['', '', '']


E pr@10% gbt
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
| 12 | E     | G1    | gbt     | pr@10%   | 0.0354 |        |
| 16 | E     | G2    | gbt     | pr@10%   | 0.0528 | †      |
| 19 | E     | G3    | gbt     | pr@10%   | 0.0536 | †      |
{'G1': (0.0354, ''), 'G2': (0.0528, '†'), 'G3': (0.0536, '†')}
['', '†', '†']


E pr@10% gbtv3
|    | src   | grp   | model   | metric   |    val | flag   |
|---:|:------|:------|:--------|:---------|-------:|:-------|
| 13 | E     | G1    | gbtv3   | pr@10%   | 0.0454 |        |
{'G1': (0.0454, ''), 'G2': (nan, ''), 'G3': (nan, '')}
['', '', '']
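
Once the printed output has been checked, the label list can also be built from a plain dictionary lookup instead of testing bar heights for nan. The following is only a sketch: flag_labels is a hypothetical helper (not part of pandas, seaborn, or matplotlib), and it assumes each x-axis group appears at most once in the filtered data; with duplicates, the last row wins. The rows are a subset of the question's dataframe.

```python
import pandas as pd

def flag_labels(data, x_order, x_col="grp", flag_col="flag"):
    """Hypothetical helper: one label per x-axis position, '' where no row exists."""
    # map each group present in the filtered data to its flag
    lookup = dict(zip(data[x_col], data[flag_col]))
    # walk the x-axis order so every tick gets an entry
    return [lookup.get(x, "") for x in x_order]

# three rows from the C / auc facet of the question's data
df = pd.DataFrame(
    [["G1", "gbt", 0.7999, "†"],
     ["G1", "gbtv2", 0.8199, "*"],
     ["G2", "gbt", 0.7554, ""]],
    columns=["grp", "model", "val", "flag"],
)

x_order = ["G1", "G2", "G3"]  # order of the ticks on the x-axis
gbt_labels = flag_labels(df[df.model.eq("gbt")], x_order)
gbtv2_labels = flag_labels(df[df.model.eq("gbtv2")], x_order)
```

Inside the loop over containers, the result could then be passed as ax.bar_label(c, labels=flag_labels(data, grp_unique)), assuming grp_unique is in the same order as the x-axis ticks.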
Trenton McKinney