I'm working on small multiples and want to create a grid of subplots with Matplotlib, plotting different data in each one. pyplot.subplots() gives me a Figure and a NumPy array of Axes, but I'm stumped on how to iterate over the axes and actually get at each one.

I'm trying something akin to:

import numpy as np
import matplotlib.pyplot as plt
f, axs = plt.subplots(2,2)
for ax in np.nditer(axs, flags=['refs_ok']):
    ax.plot([1,2],[3,4])

However, ax in each iteration is not an Axes but a zero-dimensional ndarray wrapping one, so attempting to plot fails with:

AttributeError: 'numpy.ndarray' object has no attribute 'plot'

How can I loop over my axes?

Nick T

2 Answers

You can do this more simply:

for ax in axs.ravel():
    ax.plot(...)
tacaswell
  • `.ravel()` returns a view of the original array when possible, while `.flatten()` always creates a copy. See also http://stackoverflow.com/questions/28930465/what-is-the-difference-between-flatten-and-ravel-functions-in-numpy – Martin Nov 22 '16 at 10:17
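
A quick way to check the view-versus-copy behaviour of `ravel()` and `flatten()` is `np.shares_memory`. This is only a sketch on a small integer array (the array `a` and the variable names are made up for illustration). For an object array of Axes, both calls hand back the same Axes objects, so plotting works either way.

```python
import numpy as np

# Hypothetical example array; C-contiguous, so ravel() can return a view
a = np.arange(6).reshape(2, 3)

r = a.ravel()    # a view when the memory layout allows it
f = a.flatten()  # always a new copy

shares_r = np.shares_memory(a, r)  # does r reuse a's buffer?
shares_f = np.shares_memory(a, f)  # does f reuse a's buffer?
print(shares_r, shares_f)
```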

NumPy arrays have a .flat attribute, which is a 1-D iterator over the array's elements:

for ax in axs.flat:
    ax.plot(...)

Another option is reshaping the array to a single dimension:

for ax in np.reshape(axs, -1):
    ax.plot(...)
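
For a complete picture, here is a minimal sketch of the `.flat` loop feeding different data to each panel. The data (powers of `x`), the titles, and the Agg backend are assumptions picked only so the example runs on its own without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend (assumption: no display needed)
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical data: a different power of x for each panel
x = np.linspace(0, 1, 50)
powers = (1, 2, 3, 4)

fig, axs = plt.subplots(2, 2)

# axs.flat walks the 2x2 array in row-major order and yields the Axes themselves
for ax, p in zip(axs.flat, powers):
    ax.plot(x, x ** p)
    ax.set_title(f"p = {p}")

fig.tight_layout()
```

Pairing `axs.flat` with the data via `zip` keeps the loop free of index arithmetic; the loop simply stops at whichever of the two runs out first.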
Nick T