
I would like to plot two horizontal bar charts that share the same y axis. For example, the following question shows how to achieve this in R:

Two horizontal bar charts with shared axis in ggplot2 (similar to population pyramid)

How can I create a similar plot with Python?

The plot from the question above looks like this:

[image: the plot from the linked question, two horizontal bar charts sharing a y axis]

Here is the list of states used in the graph above (the y axis):

["AK", "TX", "CA", "MT", "NM", "AZ", "NV", "CO", "OR", "WY", 
 "MI", "MN", "UT", "ID", "KS", "NE", "SD", "WA", "ND", "OK"]

Here is the list of the numbers of sales staff for each state:

[20,30,40,10,15,35,18,25,22,7,12,22,3,4,5,8,14,28,24,32]

The sales figures can be random.

Phuong Dao

2 Answers


Generally speaking, if the two variables you're displaying are in different units or have different ranges, you'll want to use two subplots with shared y-axes for this. This is similar to what the answer by @regdoug does, but it's best to explicitly share the y-axis to ensure that your data stays aligned (e.g. try zooming/panning with this example).

For example:

import matplotlib.pyplot as plt

y = range(20)
x1 = range(20)
x2 = range(0, 200, 10)

fig, axes = plt.subplots(ncols=2, sharey=True)
axes[0].barh(y, x1, align='center', color='gray')
axes[1].barh(y, x2, align='center', color='gray')
axes[0].invert_xaxis()
plt.show()

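
To see what explicitly sharing the y-axis buys you, here is a minimal sketch that changes the y limits on one panel and reads them back from the other. The names ax_left, ax_right and shared_limits are just illustrative, and the call to set_ylim stands in for an interactive zoom or pan:

```python
import matplotlib.pyplot as plt

y = range(20)

# sharey=True links the y limits of the two axes
fig, (ax_left, ax_right) = plt.subplots(ncols=2, sharey=True)
ax_left.barh(y, range(20), align='center', color='gray')
ax_right.barh(y, range(0, 200, 10), align='center', color='gray')
ax_left.invert_xaxis()

# Change the y limits on the left panel only (standing in for a zoom/pan).
# Because the y-axis is shared, the right panel follows along.
ax_left.set_ylim(4.5, 9.5)
shared_limits = ax_right.get_ylim()
```

After this runs, shared_limits holds the same limits that were set on the left panel, so the bars stay aligned row by row. The x limits are not shared, which is what lets the two panels use different ranges. Call plt.show() to display the figure.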

If you want to more precisely reproduce the example shown in the question you linked to (I'm leaving off the gray background and white grids, but those are easy to add, if you prefer them):

import numpy as np
import matplotlib.pyplot as plt

# Data
states = ["AK", "TX", "CA", "MT", "NM", "AZ", "NV", "CO", "OR", "WY", "MI",
          "MN", "UT", "ID", "KS", "NE", "SD", "WA", "ND", "OK"]
staff = np.array([20, 30, 40, 10, 15, 35, 18, 25, 22, 7, 12, 22, 3, 4, 5, 8,
                  14, 28, 24, 32])
sales = staff * (20 + 10 * np.random.random(staff.size))

# Sort by number of sales staff
idx = staff.argsort()
states, staff, sales = [np.take(x, idx) for x in [states, staff, sales]]

y = np.arange(sales.size)

fig, axes = plt.subplots(ncols=2, sharey=True)
axes[0].barh(y, staff, align='center', color='gray', zorder=10)
axes[0].set(title='Number of sales staff')
axes[1].barh(y, sales, align='center', color='gray', zorder=10)
axes[1].set(title='Sales (x $1000)')

axes[0].invert_xaxis()
axes[0].set(yticks=y, yticklabels=states)
axes[0].yaxis.tick_right()

for ax in axes.flat:
    ax.margins(0.03)
    ax.grid(True)

fig.tight_layout()
fig.subplots_adjust(wspace=0.09)
plt.show()

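
The gray background and white grid mentioned above could be added along these lines. This is only a rough sketch of the ggplot2 look, with made-up bar values; the shade '0.9' is my own guess rather than a match for the original plot:

```python
import matplotlib.pyplot as plt

y = range(5)

fig, axes = plt.subplots(ncols=2, sharey=True)
# zorder=10 keeps the bars drawn on top of the grid lines
axes[0].barh(y, [3, 1, 4, 1, 5], align='center', color='gray', zorder=10)
axes[1].barh(y, [20, 70, 10, 80, 20], align='center', color='gray', zorder=10)
axes[0].invert_xaxis()

for ax in axes.flat:
    ax.set_facecolor('0.9')                      # light gray panel background
    ax.grid(True, color='white', linestyle='-')  # solid white grid lines
```

Call plt.show() to display the result.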

One caveat: I haven't actually aligned the y-tick labels correctly. It is possible to do this, but it's more of a pain than you might expect. Therefore, if you really want y-tick labels that are always perfectly centered in the middle of the figure, it's easiest to draw them a different way. Instead of axes[0].set(yticks=y, yticklabels=states), you'd do something like:

axes[0].set(yticks=y, yticklabels=[])
for yloc, state in zip(y, states):
    axes[0].annotate(state, (0.5, yloc), xycoords=('figure fraction', 'data'),
                     ha='center', va='center')
Joe Kington

Using some information I found on the matplotlib mailing list, I adapted one of the matplotlib horizontal bar chart examples to make a pyramid plot.

The pyramid_plot function listed below will plot horizontal bars side-by-side.

def pyramid_plot(ylabels, data_left, xlabel_left, data_right, xlabel_right, fig=None, **kwargs):
    # Needs numpy as np and matplotlib.pyplot as plt to be imported by the caller
    if fig is None:
        fig = plt.figure()

    y_pos = np.arange(len(ylabels))

    # Left panel: blank tick labels, x axis reversed so the bars grow leftward
    ax_left = fig.add_subplot(121)
    ax_left.barh(y_pos, data_left, **kwargs)
    ax_left.set_yticks(y_pos)
    ax_left.set_yticklabels([])
    ax_left.invert_xaxis()
    ax_left.set_xlabel(xlabel_left)

    # Right panel: ordinary bars, with the labels on its y axis
    ax_right = fig.add_subplot(122)
    ax_right.barh(y_pos, data_right, **kwargs)
    ax_right.set_yticks(y_pos)
    ax_right.set_yticklabels(ylabels)
    ax_right.set_xlabel(xlabel_right)

    return fig

The pyramid_plot function is used as follows:

import numpy as np
import matplotlib.pyplot as plt

plt.rcdefaults()

# Example data
people = ('Tom', 'Dick', 'Harry', 'Slim', 'Jim')
performance = 3 + 10 * np.random.rand(len(people))
salary = np.linspace(30, 60, len(people))

# Plot the data
pyrfig = plt.figure(1)
pyrfig = pyramid_plot(people, salary, 'Salary (thousands)', performance, 'Performance', pyrfig, align='center', alpha=0.4)
pyrfig.suptitle('Pyramid Plot')
pyrfig.set_figwidth(1.5 * pyrfig.get_figheight())
plt.show()

References:

http://matplotlib.org/examples/lines_bars_and_markers/barh_demo.html

https://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg11606.html

regdoug