
I'm learning to use matplotlib by studying examples, and a lot of examples seem to include a line like the following before creating a single plot...

fig, ax = plt.subplots()

Here are some examples...

I see this function used a lot, even though the example is only attempting to create a single chart. Is there some other advantage? The official demo for subplots() also uses f, ax = subplots when creating a single chart, and it only ever references ax after that. This is the code they use.

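import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2 * np.pi, 400)  # sample data (assumed; the demo defines its own x, y)
y = np.sin(x ** 2)
# Just a figure and one subplot
f, ax = plt.subplots()
ax.plot(x, y)
ax.set_title('Simple plot')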
Trenton McKinney
neelshiv
    well, it is short, concise and you get a reference to both figure and axis in a single step. It's pythonic, because it's beautiful :) – cel Dec 08 '15 at 17:51

6 Answers


plt.subplots() is a function that returns a tuple containing a figure and axes object(s). Thus when using fig, ax = plt.subplots() you unpack this tuple into the variables fig and ax. Having fig is useful if you want to change figure-level attributes or save the figure as an image file later (e.g. with fig.savefig('yourfilename.png')). You certainly don't have to use the returned figure object, but many people do use it later, so it's common to see. Also, every Axes object (the object that has the plotting methods) has a parent figure object anyway (available as ax.figure), so creating both in one call with:

fig, ax = plt.subplots()

is more concise than this:

fig = plt.figure()
ax = fig.add_subplot(111)
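A quick way to see that the two spellings are equivalent is to check what each returns. The sketch below is only an illustration: it selects the non-interactive Agg backend so it runs without a display, and it saves into an in-memory buffer instead of a real file name. It confirms that each Axes keeps a reference to its parent Figure and that the returned fig is the object you call savefig on.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

# One call: a new Figure plus a single Axes on it
fig1, ax1 = plt.subplots()

# Two calls: create the Figure, then add the first (only) cell of a 1x1 grid
fig2 = plt.figure()
ax2 = fig2.add_subplot(111)

# Either way, each Axes knows its parent Figure
print(ax1.figure is fig1, ax2.figure is fig2)  # True True

# Figure-level operations go through the Figure object, e.g. saving;
# a BytesIO buffer stands in for a file name here
buf = io.BytesIO()
fig1.savefig(buf, format="png")
```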
jonchar
    Very good explanation. Here is the doc on it - http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.subplots – bretcj7 Feb 08 '17 at 22:19
    Why do we always use 111 as a parameter in subplot? – Priyansh Sep 15 '17 at 04:01
    @Priyansh because it's inherited from Matlab syntax – pcko1 Apr 03 '19 at 11:09
    @Priyansh Not always. If you want 4 graphs, you could have `ax11, ax12, ax21, ax22` by using `fig.add_subplot(221)` (or 222, 223, 224 respectively). – Guimoute Feb 11 '20 at 20:34
  • Additional question then: Is it correct to do this: ```fig = plt.figure(figsize=[20,15])``` and then ```axes = fig.subplots(2,2, sharex=False, sharey=False)``` – miguello Dec 21 '21 at 00:20
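To illustrate the comments above: the three-digit argument is shorthand for (nrows, ncols, index), so 111 is the only cell of a 1x1 grid and 221 through 224 are the four cells of a 2x2 grid, numbered left to right, top to bottom. On the last question, Figure.subplots does exist and returns the Axes arranged in the grid's shape. A minimal sketch (Agg backend assumed so it runs headless); note that get_subplotspec().get_geometry() reports 0-based start/stop indices, unlike the 1-based index you pass in.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt

fig = plt.figure(figsize=[20, 15])

# 221 is shorthand for (2, 2, 1): a 2x2 grid, first cell (1-based index)
ax11 = fig.add_subplot(221)
ax22 = fig.add_subplot(2, 2, 4)  # explicit three-argument form, last cell

# get_geometry() returns (nrows, ncols, start, stop) with 0-based indices
geom11 = ax11.get_subplotspec().get_geometry()
geom22 = ax22.get_subplotspec().get_geometry()

# Figure.subplots creates a whole grid on an existing figure
fig2 = plt.figure(figsize=[20, 15])
axes = fig2.subplots(2, 2, sharex=False, sharey=False)
print(axes.shape)  # (2, 2)
```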

Just a supplement here.

The natural follow-up question: what if I want more subplots in the figure?

As mentioned in the docs, we can use fig, axes = plt.subplots(nrows=2, ncols=2) to create a 2x2 grid of subplots in one figure object.

Since we know that fig, ax = plt.subplots() unpacks a tuple, let's first try fig, ax1, ax2, ax3, ax4 = plt.subplots(nrows=2, ncols=2).

ValueError: not enough values to unpack (expected 5, got 2)

It raises an error, but no worries: the message tells us that plt.subplots() returns a tuple with exactly two elements. The first is the figure object, and the second is a NumPy array holding the subplot (Axes) objects, arranged in the shape of the grid.
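The failed unpacking can be reproduced directly. This sketch (Agg backend assumed so it runs headless) inspects the tuple returned for a 2x2 grid and captures the error raised when five names are used on the left-hand side.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt
import numpy as np

result = plt.subplots(nrows=2, ncols=2)
print(type(result).__name__, len(result))  # tuple 2

fig, axes = result
print(isinstance(axes, np.ndarray), axes.shape)  # True (2, 2)

# Five targets, but the tuple only holds two values
try:
    fig, ax1, ax2, ax3, ax4 = plt.subplots(nrows=2, ncols=2)
except ValueError as err:
    message = str(err)
print(message)  # not enough values to unpack (expected 5, got 2)
```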

So let's try this again:

fig, [[ax1, ax2], [ax3, ax4]] = plt.subplots(nrows=2, ncols=2)

and check the type:

type(fig) #<class 'matplotlib.figure.Figure'>
type(ax1) #<class 'matplotlib.axes._subplots.AxesSubplot'> (recent Matplotlib versions report the plain Axes class here)

Of course, if you use parameters as (nrows=1, ncols=4), then the format should be:

fig, [ax1, ax2, ax3, ax4] = plt.subplots(nrows=1, ncols=4)

So just remember to make the unpacking pattern match the shape of the subplot grid you created.
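The shape rule can be checked quickly. One special case worth knowing: with the default squeeze=True, a 1x1 grid returns a single Axes rather than an array, and a single row or a single column returns a 1-D array. A small sketch, again assuming the Agg backend so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# 2x2 grid: nested unpacking mirrors the 2-D array of Axes
fig, [[ax1, ax2], [ax3, ax4]] = plt.subplots(nrows=2, ncols=2)

# 1x4 grid: the array is 1-D, so a flat list of names is needed
fig, [bx1, bx2, bx3, bx4] = plt.subplots(nrows=1, ncols=4)

# 1x1 grid: squeeze=True (the default) returns a bare Axes, not an array
fig, ax = plt.subplots()
print(isinstance(ax, Axes))  # True

# Record the array shape for a few grid sizes
shapes = {}
for nrows, ncols in [(2, 2), (1, 4), (4, 1)]:
    _, axs = plt.subplots(nrows=nrows, ncols=ncols)
    shapes[(nrows, ncols)] = axs.shape
print(shapes)
```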

Hope this helps.

Duskash
    don't forget to add plt.tight_layout() if your subplots have titles – gota Mar 12 '18 at 22:21
    What if you have a lot of subplots? It's easier to do it this way: `fig, axes = plt.subplots(nrows=10, ncols=3)` and `axes = axes.flatten()`. Now you can refer to each subplot by its index: `axes[0]`, `axes[1]`, ... – Guillaume May 25 '18 at 02:06
  • What if I want one of those subplots to span multiple cols or rows? is it doable with the subplots command? – gota Sep 13 '18 at 13:18
  • @Guillaume Not sure why you'd want to flatten the `axes` list. By default it is returned in the shape specified, so that with `nrows=10, ncols=3`, `axes` has 10 rows and 3 columns and `axes[row][col]` does what you'd expect. – BallpointBen Mar 23 '20 at 17:32
    @BallpointBen I'm not sure that works if `nrows=1`, as then `axes` is returned flat with length equal to `ncols` – Ben Apr 13 '20 at 10:34
    @BallpointBen Just realised you can fix this by doing: `fig, axes = plt.subplots(nrows=1, ncols=3, squeeze=False)` – Ben Apr 13 '20 at 10:39
    @BallpointBen What if you use a script to run through the consecutive subplots? You don't need to do some `if col > row: col -= row; row += 1` because if you flatten it, you simply walk through.. – BUFU Jun 04 '20 at 14:04
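The two approaches from these comments can be sketched side by side. flatten() lets one loop walk every subplot regardless of grid shape, while squeeze=False keeps the array 2-D even for a single row, so axes[row][col] indexing always works. The plotted data and the figsize are made up for illustration, and the Agg backend is assumed so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt

# Many subplots: flatten the 2-D array and walk it with a single index
fig, axes = plt.subplots(nrows=10, ncols=3, figsize=(9, 20))
axes = axes.flatten()
for i, ax in enumerate(axes):
    ax.plot([0, 1], [0, i])  # illustrative data only
    ax.set_title(f"subplot {i}")
fig.tight_layout()  # keeps titles from overlapping adjacent plots

# Single row: squeeze=False keeps a 2-D array, so axes[row][col] still works
fig2, axes2 = plt.subplots(nrows=1, ncols=3, squeeze=False)
print(axes2.shape)  # (1, 3)
axes2[0][2].set_title("row 0, col 2")
```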

As a supplement to the question and above answers there is also an important difference between plt.subplots() and plt.subplot(), notice the missing 's' at the end.

One can use plt.subplots() to make all the subplots at once; it returns the figure and the Axes object(s) of the subplots as a tuple. (Here "axes" means the plotting areas themselves, not the plural of "axis".) A figure can be understood as a canvas on which you paint your sketch.

# create a figure with a grid of 2 rows and 1 column of subplots
fig, ax = plt.subplots(2,1)

Whereas you can use plt.subplot() if you want to add the subplots one at a time. It returns only the Axes of one subplot.

fig = plt.figure() # create the canvas for plotting
ax1 = plt.subplot(2,1,1) 
# (2,1,1) gives the number of rows, the number of columns, and the subplot index, respectively
ax2 = plt.subplot(2,1,2)
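A small check of the difference described above: plt.subplot() returns one Axes per call and places it on the current figure, whereas plt.subplots() returns the Figure together with all of its Axes. Both routes end up with the same 2x1 layout here. The Agg backend is assumed so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt

# plt.subplots: one call, Figure plus an array of 2 Axes
fig_a, axs_a = plt.subplots(2, 1)

# plt.subplot: one Axes per call, placed on the current figure
fig_b = plt.figure()
ax1 = plt.subplot(2, 1, 1)  # 2 rows, 1 column, subplot index 1
ax2 = plt.subplot(2, 1, 2)  # 2 rows, 1 column, subplot index 2

print(len(fig_a.axes), len(fig_b.axes))  # 2 2
print(ax1.figure is fig_b)  # True
```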

However, plt.subplots() is preferred because it gives you easier options to customize the whole figure directly:

# for example, sharing x-axis, y-axis for all subplots can be specified at once
fig, ax = plt.subplots(2,2, sharex=True, sharey=True)

Whereas with plt.subplot(), sharing has to be specified individually for each Axes, which can become cumbersome.
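The contrast can be sketched as follows. With plt.subplots() a single sharex=True/sharey=True covers the whole grid; with plt.subplot() each new Axes has to be told which existing Axes to share with, here via the sharex keyword, which plt.subplot passes on to Figure.add_subplot. Setting limits on one Axes then shows up on its partners. The Agg backend is assumed so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt

# One call shares x and y across all four subplots
fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
ax[0, 0].set_xlim(0, 10)
ax[0, 1].set_ylim(-1, 1)
# the other Axes in the grid now report the same x and y limits

# With plt.subplot, sharing is declared per Axes
fig2 = plt.figure()
top = plt.subplot(2, 1, 1)
bottom = plt.subplot(2, 1, 2, sharex=top)  # share x with the first Axes
top.set_xlim(-5, 5)
# bottom now reports the same x limits as top
```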

Light_B

In addition to the answers above, you can check the type of the returned object: type(plt.subplots()) is tuple, whereas type(plt.subplot()) is matplotlib.axes._subplots.AxesSubplot (a single Axes), which you can't unpack.
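The type difference can be checked directly; unpacking a single Axes fails because it is not iterable. In recent Matplotlib versions the class name printed for a subplot is simply Axes rather than AxesSubplot, so the sketch below tests with isinstance instead of comparing names. The Agg backend is assumed so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

print(type(plt.subplots()))  # <class 'tuple'>

plt.figure()
single = plt.subplot()  # no arguments: same as subplot(1, 1, 1)
print(isinstance(single, Axes))  # True

try:
    fig, ax = single  # an Axes is not a sequence, so this cannot work
except TypeError as err:
    unpack_error = err
print(unpack_error)
```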

John T
    Welcome to Stack Overflow! This is really a comment, not an answer. With a bit more rep, [you will be able to post comments](//stackoverflow.com/privileges/comment). Thanks! – Miroslav Glamuzina Mar 31 '19 at 05:24

Using plt.subplots() is popular because it gives you an Axes object and allows you to use the Axes interface to define plots.

The alternative would be to use the global state interface, i.e. plt.plot, plt.xlabel, etc.:

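import matplotlib.pyplot as plt

x = [1, 2, 3]  # sample data, assumed for illustration
y = [1, 4, 9]

# global state version - modifies the "current" figure and axes
plt.plot(x, y)
plt.xlabel("x")

# axes version - modifies an explicit Axes object
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel("x")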

So why do we prefer Axes?

  • It is refactorable: you can move some of the code into a function that takes an Axes object and does not rely on global state
  • It makes it easier to transition to a situation with multiple subplots
  • One consistent, familiar interface instead of switching between two
  • It is the only way to access the full depth of matplotlib's features
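As an illustration of the first two points, plotting code can live in a function that receives the Axes to draw on, so the same function serves a single plot or every cell of a grid. plot_squares below is a hypothetical helper written for this sketch, not a matplotlib function, and the Agg backend is assumed so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt


def plot_squares(ax, n, label):
    """Hypothetical helper: draws y = x**2 on whichever Axes it is given."""
    xs = list(range(n))
    ax.plot(xs, [x * x for x in xs], label=label)
    ax.set_xlabel("x")
    ax.legend()
    return ax


# Single plot
fig, ax = plt.subplots()
plot_squares(ax, 5, "n=5")

# Same function reused across a grid, without touching global state
fig2, axes = plt.subplots(1, 3)
for i, sub_ax in enumerate(axes):
    plot_squares(sub_ax, 3 + i, f"n={3 + i}")
```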

The global state version was created that way to be easy to use interactively, and to be a familiar interface for Matlab users, but in larger programs and scripts the points outlined here favour using the Axes interface.

There is a matplotlib blog post exploring this topic in more depth: Pyplot vs Object Oriented Interface

It is relatively easy to deal with both worlds. We can for example always ask for the current axes: ax = plt.gca() ("get current axes").
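Going from pyplot calls back to explicit objects works the same way: plt.gca() and plt.gcf() return the Axes and Figure that the global-state functions have been drawing on, so both styles can be mixed in one script. A minimal sketch, assuming the Agg backend so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: runs without a display
import matplotlib.pyplot as plt

plt.figure()
plt.plot([1, 2, 3], [1, 4, 9])  # pyplot style: draws on the current Axes
plt.xlabel("x")

ax = plt.gca()   # the Axes that plt.plot just used
fig = plt.gcf()  # its parent Figure
ax.set_title("set through the Axes object")

print(ax.get_xlabel(), len(ax.lines))  # x 1
print(ax.figure is fig)  # True
```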

creanion
    For a long time I've wondered why the interface was so confusing (e.g. `plt.xlabel` vs `ax.set_xlabel`) but now it makes sense - these are 2 separate interfaces! – Alex Shroyer Feb 14 '22 at 13:35

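fig.tight_layout()

This call is very convenient when the x-tick labels extend beyond the plot window, or when the automatic positioning of the chart in the window does not work well: it adjusts the subplot parameters so that the tick labels and the whole chart fit inside the figure. Calling it as a method requires the Figure object, which fig, ax = plt.subplots() gives you directly (plt.tight_layout() is the pyplot equivalent that acts on the current figure):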

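import matplotlib.pyplot as plt
import pandas as pd

# hypothetical data standing in for the answer's myData DataFrame
myData = pd.DataFrame({"value": [3, 1, 4, 1, 5]},
                      index=["first", "second", "third", "fourth", "fifth"])

fig, ax = plt.subplots(figsize=(10,8))
myData.plot(ax=ax)
plt.xticks(fontsize=10, rotation=45)

fig.tight_layout()
plt.show()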
JeeyCi