
Is there a simple, clean way to iterate over the array of axes returned by subplots, like this:

import matplotlib.pyplot as plt

nrow = ncol = 2
a = []
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
for i, row in enumerate(axs):
    for j, ax in enumerate(row):
        a.append(ax)

for i, ax in enumerate(a):
    ax.set_ylabel(str(i))

that also works when nrow or ncol == 1? (The nested loop above only works when both are greater than 1.)

I tried a list comprehension like:

[element for tupl in tupleOfTuples for element in tupl]

but that fails if nrow or ncol == 1.
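A quick way to see why both attempts break is to check what plt.subplots returns for a few grid sizes. This sketch assumes Matplotlib and NumPy are installed; describe_subplots is just a hypothetical helper for this illustration, and the Agg backend is only selected so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def describe_subplots(nrows, ncols):
    """Hypothetical helper: return (type name, shape or None) of the axes plt.subplots returns."""
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
    plt.close(fig)  # free the figure; only the return type matters here
    shape = axs.shape if isinstance(axs, np.ndarray) else None
    return type(axs).__name__, shape


for grid in [(2, 2), (1, 2), (2, 1), (1, 1)]:
    print(grid, describe_subplots(*grid))
```

With the default squeeze=True, a 1×n or n×1 grid gives a 1D array of Axes, so the inner loop ends up trying to iterate over a single Axes, and a 1×1 grid gives a bare Axes that cannot be iterated at all.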

Georgy
greole
  • The solution which also works with a single axis is [this one](https://stackoverflow.com/a/33649912/774575). – mins Dec 15 '22 at 15:53

6 Answers


The axs return value is a numpy array, which can be reshaped without copying the data (reshape returns a view when it can). The following gives you a 1D array that you can iterate over cleanly:

import matplotlib.pyplot as plt

nrow = 1
ncol = 2
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)

for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))

This doesn't work when nrow and ncol are both 1, since the return value is then a single Axes rather than an array. You can wrap it in a one-element array for consistency, though it feels a bit like a kludge:

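import matplotlib.pyplot as plt
import numpy as np

nrow = 1
ncol = 1
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
axs = np.array(axs)  # wrap the single Axes in a 0-d object array

for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))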

See the numpy.reshape docs: passing -1 as the new shape makes reshape infer that dimension from the array's size, which here yields a 1D array of all the axes.
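A small sketch to check the two claims above, assuming NumPy and Matplotlib are installed (the Agg backend is only there so it runs headless): reshape(-1) on a 2×3 grid gives a 1D view of the same Axes, and wrapping a single Axes in np.array lets the same loop work for a 1×1 grid.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

# 2x3 grid: axs is a 2D object array of Axes
fig, axs = plt.subplots(nrows=2, ncols=3)
flat = axs.reshape(-1)  # -1: let NumPy infer the length (here 6)

print(flat.shape)                   # (6,)
print(np.shares_memory(flat, axs))  # True: a view, no copy

for i, ax in enumerate(flat):
    ax.set_ylabel(str(i))

# 1x1 grid: wrap the bare Axes so reshape(-1) still works
fig1, ax1 = plt.subplots(nrows=1, ncols=1)
single = np.array(ax1).reshape(-1)
print(single.shape)                 # (1,)
```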

Max Ghenis
Bonlenfum
  • For `nrow=ncol=1` you can use `squeeze=0`. `plt.subplots(nrows=nrow, ncols=ncol, squeeze=0)` always returns a 2-dimensional array for the axes, even if both are one. – cronos Apr 06 '16 at 11:26
  • This suggestion will not work because axs is not a numpy array in the case nrow=ncol=1. squeeze=0 works! – Julek Apr 19 '23 at 16:31

The figure returned by plt.subplots keeps a list of all its axes in fig.axes. To iterate over all the subplots in a figure you can use:

import matplotlib.pyplot as plt

nrow = 2
ncol = 2
fig, axs = plt.subplots(nrow, ncol)
for i, ax in enumerate(fig.axes):
    ax.set_ylabel(str(i))

This also works for nrow == ncol == 1.
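To illustrate what fig.axes contains, here is a minimal sketch (headless Agg backend assumed). One thing worth knowing: fig.axes lists every Axes in the figure, so extra Axes such as a twinx() twin or a colorbar axes also show up in the loop.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)
# fig.axes is a plain list in creation order (row by row for subplots)
print(len(fig.axes))  # 4
print(all(a is b for a, b in zip(fig.axes, axs.flat)))  # True

# Caveat: any Axes added later is included too, e.g. a twin y-axis
axs[0, 0].twinx()
print(len(fig.axes))  # 5
```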

Ø. Jensen

I am not sure when it was added, but there is now a squeeze keyword argument. Passing squeeze=False makes sure the result is always a 2D numpy array, and turning that into a 1D array is easy:

import matplotlib.pyplot as plt

fig, ax2d = plt.subplots(2, 2, squeeze=False)
axli = ax2d.flatten()  # 1D array of all Axes, row by row

This works for any number of subplots and needs no special case for a single Axes, so it is a little simpler than the accepted answer (perhaps squeeze didn't exist yet back then).
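The squeeze=False approach can be wrapped in a small helper so the calling code never special-cases the grid shape. This is a sketch, not a Matplotlib function: flat_axes is a hypothetical name, and it uses ravel() instead of flatten() (either works; ravel() avoids a copy when it can).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt


def flat_axes(nrows, ncols, **kwargs):
    """Hypothetical helper: create a grid and return (fig, 1D array of Axes).

    squeeze=False keeps the axes array 2D even for a 1x1 grid,
    so ravel() always yields a 1D array of length nrows * ncols.
    """
    fig, ax2d = plt.subplots(nrows, ncols, squeeze=False, **kwargs)
    return fig, ax2d.ravel()


for nrows, ncols in [(1, 1), (1, 3), (3, 1), (2, 2)]:
    fig, axes = flat_axes(nrows, ncols)
    for i, ax in enumerate(axes):
        ax.set_title(f"{nrows}x{ncols} #{i}")
    print((nrows, ncols), len(axes))
    plt.close(fig)
```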

Mark

The axes returned by plt.subplots form a NumPy array, so you can use its flat attribute to iterate over them as if they were 1D.

Why don't you try the following code?

import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    # do something with each Axes instance 'ax', for example:
    ax.grid(True)
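To make clear where flat comes from: it is NumPy's ndarray.flat attribute, a numpy.flatiter that walks the array in row-major order. A minimal sketch, assuming a 2×3 grid and the headless Agg backend. Note that for a 1×1 grid plt.subplots returns a bare Axes rather than an array, so that case needs squeeze=False.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)
print(type(axes.flat).__name__)  # flatiter

# .flat walks the 2D array in row-major (C) order: [0,0], [0,1], ..., [1,2]
for i, ax in enumerate(axes.flat):
    ax.set_title(f"panel {i}")

print(axes[1, 0].get_title())  # panel 3
```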
Sukjun Kim

TL;DR: axes.flat is the most Pythonic way of iterating through the axes.

As others have pointed out, the axes returned by plt.subplots() are a numpy array of Axes objects, so there are plenty of built-in numpy methods for flattening the array. Of those options, axes.flat is the least verbose. Furthermore, axes.flatten() returns a copy of the array, whereas axes.flat returns an iterator over the original array, so axes.flat avoids allocating a new array (although with only a handful of Axes the difference is negligible).

Stealing @Sukjun-Kim's example:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    # do something with each Axes instance 'ax', for example:
    ax.grid(True)

Sources: the axes.flat (numpy.ndarray.flat) docs and the Matplotlib tutorial.
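A short sketch comparing the three options on a 2×3 grid (Agg backend assumed for headless runs). np.shares_memory shows that flatten() allocates a new array while ravel() returns a view here; flat creates no new array at all.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 3)

copied = axes.flatten()  # always a new array (a copy of the 6 references)
viewed = axes.ravel()    # a view here, since the array is contiguous
flat_it = axes.flat      # iterator over the original array, no new array

print(np.shares_memory(copied, axes))  # False
print(np.shares_memory(viewed, axes))  # True

# All three visit the same Axes objects in the same order
print(all(a is b is c for a, b, c in zip(copied, viewed, flat_it)))  # True
```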

Devin Cody

Here is a good practice. For example, if we need a 4×4 grid of subplots, we can set them up like this:

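import matplotlib.pyplot as plt
import numpy as np

rows = 4
cols = 4
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(20, 16), squeeze=0, sharex=True, sharey=True)
axes = np.array(axes)  # redundant with squeeze=0 (already a 2D array), but harmless

for i, ax in enumerate(axes.reshape(-1)):
    ax.set_ylabel(f'Subplot: {i}')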

The result is a clean 4×4 grid with shared x and y axes, where each subplot's y-axis label shows its index.
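A trimmed-down sketch of the same setup, assuming the Agg backend for headless runs: since squeeze=0 (equivalently squeeze=False) already returns a 2D ndarray, the np.array() call can be dropped and axes.flat used for the loop.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

rows, cols = 4, 4
# squeeze=False already returns a 2D ndarray of Axes, so no np.array() wrapping
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(20, 16),
                         squeeze=False, sharex=True, sharey=True)

for i, ax in enumerate(axes.flat):
    ax.set_ylabel(f"Subplot: {i}")

print(axes.shape)               # (4, 4)
print(axes[3, 3].get_ylabel())  # Subplot: 15
```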

Reza K Ghazi