
I want to do something similar to the following in pyplot, where I can label both the individual bars and the groups. I was able to label the bars but don't know how to label the groups. Thanks!

[Image: the desired chart, with labels for the individual bars and for the groups]

Asked by paul-shuvo, edited by JohanC
  • Would it make sense to have ["Apples", "Pears", ...] on the x-axis and the years as a legend, rather than putting the years on the x-axis as well? – adam Jul 29 '20 at 09:08
  • This question https://stackoverflow.com/questions/59349185/non-linear-second-axis-in-matplotlib might provide the way forward. – Ed Smith Jul 29 '20 at 09:55
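The secondary-axis idea from the linked question can be sketched in plain matplotlib. A secondary x-axis at the bottom of the axes carries one tick per group, and a leading newline pushes its labels below the per-bar year labels. A second secondary axis with long, label-less ticks draws separators between the groups. This is a minimal sketch that assumes matplotlib 3.5 or newer (for the `labels=` argument of `set_xticks`); the data and the function name `grouped_bars_with_secondary_axis` are made up for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical example data: amounts has shape (n_fruits, n_years)
fruits = ["apples", "bananas", "coconuts"]
years = [2018, 2019]
amounts = np.array([[3.0, 4.2], [2.5, 3.1], [4.8, 4.0]])


def grouped_bars_with_secondary_axis(fruits, years, amounts, width=0.8):
    """Grouped bars: year labels under each bar, fruit labels under each group."""
    n_years = len(years)
    bar_w = width / n_years
    group_x = np.arange(len(fruits))  # one slot per fruit, centred on 0, 1, 2, ...
    offsets = (np.arange(n_years) - (n_years - 1) / 2) * bar_w

    fig, ax = plt.subplots(figsize=(8, 4))
    for j, year in enumerate(years):
        ax.bar(group_x + offsets[j], amounts[:, j], width=bar_w)

    # Inner level: one tick per bar (ordered fruit by fruit), labelled with its year
    bar_x = (group_x[:, None] + offsets[None, :]).ravel()
    ax.set_xticks(bar_x, labels=[str(y) for _ in fruits for y in years])

    # Outer level: secondary x-axis at the bottom (location=0), one tick per group;
    # the leading newline moves each label one line below the year labels
    sec = ax.secondary_xaxis(location=0)
    sec.set_xticks(group_x, labels=["\n" + f for f in fruits])
    sec.tick_params("x", length=0)

    # Separators halfway between groups: long ticks without labels
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks(group_x[:-1] + 0.5, labels=[])
    sec2.tick_params("x", length=40, width=1.5)
    return fig, ax, sec


if __name__ == "__main__":
    grouped_bars_with_secondary_axis(fruits, years, amounts)
    plt.tight_layout()
    plt.show()
```

Compared with placing `text()` objects by hand, this keeps the group labels attached to the axis machinery, so they follow the x-limits automatically.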

1 Answer


With hue='Year', a default seaborn barplot shows the years in a legend:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

labels = ['apples', 'bananas', 'coconuts', 'dates', 'elderberries', 'figs', 'grapes']
years = [2017, 2018, 2019]
# one row per (fruit, year) combination: each fruit repeated once per year, years cycled
df = pd.DataFrame({'Fruit': np.repeat(labels, len(years)),
                   'Year': np.tile(years, len(labels)),
                   'Amount': np.random.uniform(1.5, 5, len(labels)*len(years))})
ax = sns.barplot(x='Fruit', y='Amount', hue='Year', data=df)
plt.show()

[Image: default plot]
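A note on how the example data is built: pairing `np.tile(labels, len(years))` with `np.tile(years, len(labels))` gives every (fruit, year) combination exactly once only when the two lengths share no common factor, as 7 and 3 do here. Pairing `np.repeat(labels, len(years))` with `np.tile(years, len(labels))` covers all combinations for any lengths. The small numpy check below counts the distinct pairs produced each way; the helper `coverage` and the placeholder fruit names are hypothetical, for illustration only.

```python
import numpy as np


def coverage(n_fruits, years):
    """Count distinct (fruit, year) pairs for tile+tile vs. repeat+tile."""
    labels = [f"fruit{i}" for i in range(n_fruits)]
    year_col = np.tile(years, n_fruits)  # 2017, 2018, 2019, 2017, ...
    # tile + tile: row i pairs fruit (i mod n_fruits) with year (i mod n_years)
    tiled = set(zip(np.tile(labels, len(years)).tolist(), year_col.tolist()))
    # repeat + tile: each fruit appears once next to every year
    repeated = set(zip(np.repeat(labels, len(years)).tolist(), year_col.tolist()))
    return len(tiled), len(repeated)


years = [2017, 2018, 2019]
print(coverage(7, years))  # (21, 21): 7 and 3 are coprime, both cover all pairs
print(coverage(6, years))  # (6, 18): tile + tile misses most combinations
```

In pandas, `pd.MultiIndex.from_product([labels, years], names=['Fruit', 'Year']).to_frame(index=False)` is another way to get one row per combination.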

Calling text() with a blended transform (x in data coordinates, y in axes coordinates) places the new group labels at the desired positions below the axis. Vertical separator lines can be added via axvline, which already uses such a blended transform by default; clip_on=False is needed so these lines can extend outside the main axes area. The year labels become the x tick labels, placed at the center of each bar. The legend and the default xlabel can then be removed.

from matplotlib import pyplot as plt
from matplotlib import transforms
import numpy as np
import pandas as pd
import seaborn as sns

labels = ['apples', 'bananas', 'coconuts', 'dates', 'elderberries', 'figs', 'grapes']
years = [2017, 2018, 2019]
# one row per (fruit, year) combination
df = pd.DataFrame({'Fruit': np.repeat(labels, len(years)),
                   'Year': np.tile(years, len(labels)),
                   'Amount': np.random.uniform(1.5, 5, len(labels)*len(years))})
fig, ax = plt.subplots(figsize=(12, 4))
ax = sns.barplot(x='Fruit', y='Amount', hue='Year', palette='Reds', data=df, ax=ax)
# center of every bar, sorted left to right; ax.containers groups the bars
# created by ax.bar, so other patches on the axes are not picked up
year_pos = np.sort([bar.get_x() + bar.get_width() / 2
                    for container in ax.containers for bar in container])
ax.set_xticks(year_pos)
ax.set_xticklabels(np.tile(years, len(labels)), rotation=30)  # year under each bar
ax.get_legend().remove()
ax.set_xlabel('')  # remove default xlabel
# center of each group of len(years) consecutive bars
fruit_pos = year_pos.reshape(-1, len(years)).mean(axis=1)
# x in data coordinates, y in axes coordinates
trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)

# group labels below the year tick labels
for pos, label in zip(fruit_pos, labels):
    ax.text(pos, -0.25, label, transform=trans, ha='center', va='bottom', color='steelblue', fontsize=14)
# dotted separators halfway between neighboring groups, extending below the axes
for pos in (fruit_pos[:-1] + fruit_pos[1:]) / 2:
    ax.axvline(pos, 0, -0.25, color='steelblue', ls=':', clip_on=False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.show()

[Image: updated plot]
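Since the question asks about plain pyplot, here is a minimal sketch of the same layout without seaborn. It computes the bar offsets within each group by hand, then applies the same blended-transform group labels and axvline separators. The function name `grouped_bar_labels`, its `label_y` and `width` parameters, and the sample data are illustrative assumptions, not part of the answer above.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import transforms


def grouped_bar_labels(labels, years, amounts, label_y=-0.25, width=0.8):
    """Grouped bars with a year tick under each bar and a label under each group.

    amounts has shape (len(labels), len(years)); label_y is in axes coordinates.
    """
    n_years = len(years)
    bar_w = width / n_years
    group_x = np.arange(len(labels))  # group centers at 0, 1, 2, ...
    offsets = (np.arange(n_years) - (n_years - 1) / 2) * bar_w

    fig, ax = plt.subplots(figsize=(12, 4))
    cmap = plt.get_cmap("Reds")
    for j in range(n_years):
        ax.bar(group_x + offsets[j], amounts[:, j], width=bar_w,
               color=cmap((j + 1) / (n_years + 1)))

    # Year ticks at every bar center, ordered group by group
    bar_x = (group_x[:, None] + offsets[None, :]).ravel()
    ax.set_xticks(bar_x)
    ax.set_xticklabels([str(y) for _ in labels for y in years], rotation=30)

    # x in data coordinates, y in axes coordinates
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    for x, label in zip(group_x, labels):
        ax.text(x, label_y, label, transform=trans, ha="center", va="bottom",
                color="steelblue", fontsize=14)

    # Dotted separators halfway between groups; axvline takes data x and
    # axes y, and clip_on=False lets the line extend below the axes
    for x in group_x[:-1] + 0.5:
        ax.axvline(x, 0, label_y, color="steelblue", ls=":", clip_on=False)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    fig.tight_layout()
    return fig, ax


if __name__ == "__main__":
    fruit_names = ["apples", "bananas", "coconuts", "dates"]
    year_list = [2017, 2018, 2019]
    rng = np.random.default_rng(0)
    data = rng.uniform(1.5, 5, size=(len(fruit_names), len(year_list)))
    grouped_bar_labels(fruit_names, year_list, data)
    plt.show()
```

Because the group centers are simply 0, 1, 2, ..., there is no need to recover them from the bar patches afterwards, which is the main difference from the seaborn version.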

Answered by JohanC