I would like to have a horizontal stacked bar chart with hierarchical labels on the y axis. I searched a bit and found the following nice example and code.
But it is for a vertical stacked bar chart. I want to apply it to a horizontal bar chart, so I simply changed the plot call to kind='barh', but this doesn't work.
I managed to delete the default y labels by changing every x to y in the last few lines. But changing x to y inside the defined functions didn't give me what I want: the hierarchical labels are still on the x axis.
Can anyone help? Thanks.
P.S.: to make things less messy, I posted the original code I found from the 2nd answer to this question rather than the one I tried to modify
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby
def test_table():
    """Build the small demo inventory table used by the plotting example.

    Returns:
        A ``pandas.DataFrame`` with seven rows and the columns ``Room``,
        ``Shelf``, ``Staple``, ``Quantity`` and ``Ordered``.  ``Ordered`` is
        filled with random integers drawn from ``[0, 10)``, so its values
        differ between calls.
    """
    data_table = pd.DataFrame({
        'Room': ['Room A'] * 4 + ['Room B'] * 3,
        'Shelf': ['Shelf 1'] * 2 + ['Shelf 2'] * 2 + ['Shelf 1'] * 2 + ['Shelf 2'],
        'Staple': ['Milk', 'Water', 'Sugar', 'Honey', 'Wheat', 'Corn', 'Chicken'],
        'Quantity': [10, 20, 5, 6, 4, 7, 2],
        'Ordered': np.random.randint(0, 10, 7),
    })
    # The frame must be returned: callers chain ``.groupby(...)`` directly on
    # the result, which fails with AttributeError if this function yields None.
    return data_table
def add_line(ax, xpos, ypos):
    """Draw a short black vertical separator on ``ax``.

    The segment sits at horizontal position ``xpos`` and runs from
    ``ypos + 0.1`` down to ``ypos``.  Both coordinates are interpreted through
    ``ax.transAxes``, and clipping is turned off for the line.
    """
    x_coords = [xpos, xpos]
    y_coords = [ypos + .1, ypos]
    separator = plt.Line2D(x_coords, y_coords,
                           transform=ax.transAxes, color='black')
    # Clipping is disabled on the artist before it is attached to the axes.
    separator.set_clip_on(False)
    ax.add_line(separator)
def label_len(my_index, level):
    """Return run lengths of consecutive equal labels at one index level.

    Reads ``my_index.get_level_values(level)`` and collapses each run of
    adjacent equal labels into a ``(label, count)`` tuple, preserving order.
    Labels that recur non-adjacently produce separate entries.
    """
    runs = []
    for key, members in groupby(my_index.get_level_values(level)):
        runs.append((key, len(list(members))))
    return runs
def label_group_bar_table(ax, df):
    """Write grouped index labels of ``df`` below ``ax``, one row per level.

    Index levels are processed from the last to the first.  Each level gets its
    own row of text, starting at axes-y ``-0.1`` and moving down by ``0.1`` per
    level.  Within a row, every run of equal labels is centred over the span of
    entries it covers, and separator lines mark the run boundaries.
    """
    row_y = -.1
    # Width of one index entry as a fraction of the axes width.
    scale = 1. / df.index.size
    for level in reversed(range(df.index.nlevels)):
        start = 0
        for label, span in label_len(df.index, level):
            centre_x = (start + .5 * span) * scale
            ax.text(centre_x, row_y, label, ha='center',
                    transform=ax.transAxes)
            # Separator at the left edge of this run.
            add_line(ax, start * scale, row_y)
            start += span
        # Closing separator at the right edge of the final run.
        add_line(ax, start * scale, row_y)
        row_y -= .1
# Aggregate the demo table so the index becomes a Room/Shelf/Staple MultiIndex.
group_keys = ['Room', 'Shelf', 'Staple']
df = test_table().groupby(group_keys).sum()

fig = plt.figure()
ax = fig.add_subplot(111)
df.plot(kind='bar', stacked=True, ax=fig.gca())

# Blank out the default tick labels and axis label; the grouped labels drawn
# below replace them.
labels = ['' for _ in ax.get_xticklabels()]
ax.set_xticklabels(labels)
ax.set_xlabel('')

label_group_bar_table(ax, df)

# Reserve bottom margin: one tenth of the figure height per index level.
fig.subplots_adjust(bottom=.1 * df.index.nlevels)
plt.show()