import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

array = np.array([[1,5,9],[3,5,7]])

df = pd.DataFrame(data=array, index=['Positive', 'Negative'])

f, ax = plt.subplots(figsize=(8, 6))

current_palette = sns.color_palette('colorblind')

ax_pos = sns.barplot(x = np.arange(0,3,1), y = df.loc['Positive'].to_numpy(), color = current_palette[2], alpha = 0.66)
ax_neg = sns.barplot(x = np.arange(0,3,1), y = df.loc['Negative'].to_numpy(), color = current_palette[4], alpha = 0.66)

plt.xticks(np.arange(0,3,1), fontsize = 20)
plt.yticks(np.arange(0,10,1), fontsize = 20)

plt.legend((ax_pos[0], ax_neg[0]), ('Positive', 'Negative'))

plt.tight_layout()

This produces the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[32], line 15
     12 plt.xticks(np.arange(0,3,1), fontsize = 20)
     13 plt.yticks(np.arange(0,10,1), fontsize = 20)
---> 15 plt.legend((ax_pos[0], ax_neg[0]), ('Positive', 'Negative'))
     17 plt.tight_layout()

TypeError: 'Axes' object is not subscriptable


I would like to know why calling legend this way (plt.legend((ax_pos[0], ax_neg[0]), ...)) is not possible with seaborn, whereas it is with matplotlib.

In the end, I just want the legend in the upper left corner.

Trenton McKinney
chalbiophysics

2 Answers


I figured out that barplot accepts a "label" keyword argument, which is used for the legend entry:

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

array = np.array([[1,5,9],[3,5,7]])

df = pd.DataFrame(data=array, index=['Positive', 'Negative'])

f, ax = plt.subplots(figsize=(8, 6))

current_palette = sns.color_palette('colorblind')

sns.barplot(x = np.arange(0,3,1), y = df.loc['Positive'].to_numpy(), color = current_palette[2], alpha = 0.66, label = "Positive")
sns.barplot(x = np.arange(0,3,1), y = df.loc['Negative'].to_numpy(), color = current_palette[4], alpha = 0.66, label = "Negative")

plt.xticks(np.arange(0,3,1), fontsize = 20)
plt.yticks(np.arange(0,10,1), fontsize = 20)

plt.legend(frameon = False)

plt.tight_layout()
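For context on the original error: in matplotlib, ax.bar returns a BarContainer, which can be indexed, whereas sns.barplot returns the matplotlib Axes it drew on, which cannot. The following is only a minimal sketch of one way to still build the legend from explicit handles and place it in the upper left corner, as the question asks. It assumes that each barplot call adds at least one BarContainer to ax.containers, so that the first and last containers belong to the first and second call respectively; the variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

x = np.arange(3)
palette = sns.color_palette('colorblind')

fig, ax = plt.subplots(figsize=(8, 6))

# sns.barplot returns the Axes it drew on, so ax_pos[0] would raise the TypeError
ax_pos = sns.barplot(x=x, y=np.array([1, 5, 9]), color=palette[2], alpha=0.66, ax=ax)
ax_neg = sns.barplot(x=x, y=np.array([3, 5, 7]), color=palette[4], alpha=0.66, ax=ax)
same_axes = ax_pos is ax and ax_neg is ax

# the bar artists are still reachable through the containers stored on the Axes
# (assumption: first container comes from the first call, last from the second)
handles = [ax.containers[0], ax.containers[-1]]
legend = ax.legend(handles, ['Positive', 'Negative'], loc='upper left', frameon=False)

fig.tight_layout()
```

If the label approach above is used instead, passing loc='upper left' to plt.legend should give the same placement without collecting handles by hand.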


Trenton McKinney
chalbiophysics
  • The bar plots in the OP and in the other answer are misleading. The bars are layered in z-order, not stacked, which is non-standard and likely to be misinterpreted, as shown in this plot.
  • The correct way to plot a wide dataframe, with labeled index values, is with pandas.DataFrame.plot.
    • The values in the index will be the xticklabels, and the column names will be the legend labels.
    • Depending on how the data should be presented, use .T to transpose the index and columns.
  • In general, stacked bars are not the best way to present data because, beyond two groups, it's difficult to accurately compare the lengths.
    • Grouped bars are better because the bar lengths are easier to compare.
  • As this answer shows, it's better to correctly use the plotting API, for either seaborn or pandas, to manage the legend.
    • pandas uses matplotlib as the default plotting backend, and seaborn is a high-level API for matplotlib.
  • In this case, df already has labels to use for the legend.
    • If the dataframe does not have appropriate labels, it is better to manage, transform, and clean the dataframe, and let the plotting API deal with legends, labels, etc.
      • Use .rename to rename columns and index values.
      • Use .map to add a new column to be used for hue in seaborn.
  • See Move seaborn plot legend to a different position and How to put the legend outside the plot for a thorough exploration of the options for moving legends.
  • Tested in python 3.11.3, pandas 2.0.2, matplotlib 3.7.1, seaborn 0.12.2
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# sample wide dataframe
df = pd.DataFrame(data=np.array([[1, 5, 9],[3, 5, 7]]),
                  index=['Positive', 'Negative'])

# convert to long form
dfm = df.melt(ignore_index=False).reset_index(names='Cat')
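The .rename and .map suggestions from the list above could look roughly like the following sketch. The column labels ('low', 'mid', 'high'), the 'Sign' column, and the mapping values are hypothetical and only serve as an illustration. rename_axis followed by reset_index is used here in place of reset_index(names=...) so that the sketch does not depend on a recent pandas version.

```python
import numpy as np
import pandas as pd

# same sample wide dataframe as above
df = pd.DataFrame(data=np.array([[1, 5, 9], [3, 5, 7]]),
                  index=['Positive', 'Negative'])

# .rename: give the integer columns descriptive names (hypothetical labels)
df_named = df.rename(columns={0: 'low', 1: 'mid', 2: 'high'})

# long form: one row per (Cat, variable) pair
dfm_named = df_named.melt(ignore_index=False).rename_axis('Cat').reset_index()

# .map: derive a new column that could be passed as hue (hypothetical mapping)
dfm_named['Sign'] = dfm_named['Cat'].map({'Positive': '+', 'Negative': '-'})

print(dfm_named)
```

With the labels stored in the dataframe, the plotting call can pick them up for tick labels and legend entries, so no legend handles need to be assembled manually.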

Stacked bars with pandas.DataFrame.plot

  • This requires one line of code to implement.
# transpose and plot the dataframe
ax = df.T.plot(kind='bar', stacked=True, rot=0, figsize=(8, 6))

# cosmetics
ax.spines[['top', 'right']].set_visible(False)
ax.legend(title='Cat', bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)


  • With stacked=False, the bars are grouped, which is the better option.
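The grouped variant is not spelled out in code here, so the following sketch shows what it presumably looks like: the same call as the stacked version with only stacked=False changed. The Agg backend line is an addition so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# same sample wide dataframe as above
df = pd.DataFrame(data=np.array([[1, 5, 9], [3, 5, 7]]),
                  index=['Positive', 'Negative'])

# transpose so the original columns become the x-axis groups, then plot grouped bars
ax = df.T.plot(kind='bar', stacked=False, rot=0, figsize=(8, 6))

# cosmetics, as in the stacked example
ax.spines[['top', 'right']].set_visible(False)
ax.legend(title='Cat', bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)
```

Each index value ('Positive', 'Negative') gets its own bar within every x-axis group, which makes the bar heights directly comparable.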


Stacked bars with seaborn.histplot

  • seaborn works best with data in a long form, so dfm is used.
    • sns.barplot does not do stacked bars, so sns.histplot must be used.
  • histplot is for showing the distribution of continuous data, and since the column passed to the x-axis is numeric, the axis is presented as continuous, not as discrete categories. Passing discrete=True centers each bar on its integer value.
fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(data=dfm, x='variable', weights='value', hue='Cat', multiple='stack', discrete=True, ax=ax)

# cosmetics
ax.spines[['top', 'right']].set_visible(False)
sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)


Grouped bars with seaborn.barplot

fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(data=dfm, x='variable', y='value', hue='Cat', ax=ax)

# cosmetics
ax.spines[['top', 'right']].set_visible(False)
sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)


Trenton McKinney