110

Looking at the matplotlib documentation, it seems the standard way to add an AxesSubplot to a Figure is to use Figure.add_subplot:

from matplotlib import pyplot

fig = pyplot.figure()
ax = fig.add_subplot(1,1,1)
ax.hist( some params .... )

I would like to be able to create AxesSubplot-like objects independently of the figure, so that I can use them in different figures. Something like:

fig = pyplot.figure()
histoA = some_axes_subplot_maker.hist( some params ..... )
histoB = some_axes_subplot_maker.hist( some other params ..... )
# make one figure with both plots
fig.add_subaxes(histoA, 211)
fig.add_subaxes(histoB, 212)
fig2 = pyplot.figure()
# make a figure with the first plot only
fig2.add_subaxes(histoA, 111)

Is this possible in matplotlib and if so, how can I do this?

Update: I have not managed to decouple the creation of Axes and Figures, but following the examples in the answers below, I can easily re-use previously created axes in new or old Figure instances. This can be illustrated with a simple function:

from matplotlib import pyplot as plt

def plot_axes(ax, fig=None, geometry=(1,1,1)):
    if fig is None:
        fig = plt.figure()
    if ax.get_geometry() != geometry:
        ax.change_geometry(*geometry)
    # list.append returns None, so do not rebind ax to its result
    fig.axes.append(ax)
    return fig
Trenton McKinney
juanchopanza

5 Answers

57

Typically, you just pass the axes instance to a function.

For example:

import matplotlib.pyplot as plt
import numpy as np

def main():
    x = np.linspace(0, 6 * np.pi, 100)

    fig1, (ax1, ax2) = plt.subplots(nrows=2)
    plot(x, np.sin(x), ax1)
    plot(x, np.random.random(100), ax2)

    fig2 = plt.figure()
    plot(x, np.cos(x))

    plt.show()

def plot(x, y, ax=None):
    if ax is None:
        ax = plt.gca()
    line, = ax.plot(x, y, 'go')
    ax.set_ylabel('Yabba dabba do!')
    return line

if __name__ == '__main__':
    main()

To respond to your question, you could always do something like this:

def subplot(data, fig=None, index=111):
    if fig is None:
        fig = plt.figure()
    ax = fig.add_subplot(index)
    ax.plot(data)

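Also, in older matplotlib versions you could simply add an axes instance to another figure (newer versions deliberately no longer allow axes to be shared between figures):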

import matplotlib.pyplot as plt

fig1, ax = plt.subplots()
ax.plot(range(10))

fig2 = plt.figure()
fig2.axes.append(ax)

plt.show()

Resizing it to match other subplot "shapes" is also possible, but it's going to quickly become more trouble than it's worth. The approach of just passing around a figure or axes instance (or list of instances) is much simpler for complex cases, in my experience...
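One way to get close to the "create the plot once, place it in several figures" idea from the question, without sharing any Axes, is to store the plotting call itself and replay it on whichever Axes you like. The following is only a minimal sketch: `draw_hist`, `histo_a` and `histo_b` are hypothetical names, the data is random, and the Agg backend is chosen purely so the snippet runs without a display.

```python
import functools

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np


def draw_hist(ax, data, bins=10, title=""):
    """Draw a histogram of `data` onto an existing Axes and return that Axes."""
    ax.hist(data, bins=bins)
    ax.set_title(title)
    return ax


rng = np.random.default_rng(0)

# Each "recipe" holds the data and styling, but no Axes and no Figure.
histo_a = functools.partial(draw_hist, data=rng.normal(size=200), bins=20, title="A")
histo_b = functools.partial(draw_hist, data=rng.uniform(size=200), bins=20, title="B")

# One figure with both plots.
fig1, (top, bottom) = plt.subplots(nrows=2)
histo_a(top)
histo_b(bottom)

# A second figure with the first plot only.
fig2, only = plt.subplots()
histo_a(only)
```

Each call redraws the histogram from the stored data, so every figure ends up with its own independent artists and the layout of each figure stays a separate decision.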

Joe Kington
  • 1
    +1 This is useful, but it seems to me that the axes are still coupled to the figures, and/or to some state in pyplot. I cannot really decouple the axis creation from the figure making and plotting following your example. – juanchopanza Jun 10 '11 at 17:17
  • Axes are fundamentally linked to a specific figure in matplotlib. There's no way around this. However, you can still completely "decouple the axis creation from the figure making and plotting" by just passing axes and figure objects around. I'm not quite sure I follow what you're wanting to do... – Joe Kington Jun 10 '11 at 17:41
  • Well, actually, I guess they're not as fundamentally linked as I thought. You can just add the same axes to a different figure. (Just do `fig2.axes.append(ax1)`) Resizing it to match different subplot shapes is also possible. This is probably going to wind up being more trouble than it's worth, though... – Joe Kington Jun 10 '11 at 17:52
  • I think I can achieve what I want appending to Figure.axes as in your example, plus re-sizing the AxesSubPlot with change_geometry. – juanchopanza Jun 12 '11 at 07:41
  • 8
    adding axes instance to another figure (last example) doesn't work for me in Enthought 7.3-2 (matplotlib 1.1.0). – aaren Dec 07 '12 at 12:53
  • 18
    @aaren - It's not working because the way the axes stack for a figure works has been changed in newer versions of matplotlib. Axes deliberately aren't supposed to be shared between different figures now. As a workaround, you could do this `fig2._axstack.add(fig2._make_key(a), a)`, but it's hackish and likely to change in the future. It seems to work properly, but it may break some things. – Joe Kington Dec 07 '12 at 13:21
  • yeah, I wouldn't actually do it that way anyway. passing the axes to a function that plots on it makes more sense as it decouples figure layout and axes creation (layout) from the actual plotting (content) – aaren Dec 07 '12 at 15:31
  • 4
    10 years later and I still wonder why I insist on using `matplotlib` – Eduardo Pignatelli Apr 06 '21 at 09:06
  • Spending the whole morning to get here and it does not work (fig2 is blank). – Barzi2001 Oct 13 '22 at 10:50
38

The following shows how to "move" an axes from one figure to another. This is the intended functionality of @JoeKington's last example, which no longer works in newer matplotlib versions because an axes cannot live in several figures at once.

You would first need to remove the axes from the first figure, then append it to the next figure and give it some position to live in.

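import matplotlib.pyplot as plt

fig1, ax = plt.subplots()
ax.plot(range(10))
# detach the axes from its original figure
ax.remove()

fig2 = plt.figure()
# attach the axes to the new figure
ax.figure = fig2
fig2.axes.append(ax)
fig2.add_axes(ax)

# borrow the position of a temporary full-size subplot, then discard it
dummy = fig2.add_subplot(111)
ax.set_position(dummy.get_position())
dummy.remove()
plt.close(fig1)

plt.show()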
ImportanceOfBeingErnest
  • 1
    small addition: when `plt.show()` is replaced by `fig2.savefig('out.png', dpi=300)` the positioning is messed up due to the `dpi` keyword. This can be avoided by setting the final `dpi` when `ax` is initialized: `fig1, ax = plt.subplots(dpi=300)` – Matthias123 Nov 07 '17 at 09:25
  • In my Python shell, it doesn't look like this line does anything: fig2.axes.append(ax) – Spirko Jan 07 '18 at 16:47
  • `ax.remove()` results in `NotImplementedError: cannot remove artist` (Python 3.7.0, matplotlib 2.2.2) – gerrit Aug 14 '18 at 10:52
  • @gerrit Did you change anything compared to the snippet above? It works fine for me. – ImportanceOfBeingErnest Aug 14 '18 at 11:29
  • 1
    @ImportanceOfBeingErnest Yes; I obtained my figure from a `pickle`. I apologise for omitting this important detail. I ended up setting 9 AxesSubplot to `set_visible(False)` and changing the position of the one I wanted to show only. – gerrit Aug 14 '18 at 11:48
  • 1
    @gerrit Maybe you need [this answer](https://stackoverflow.com/a/48915069/4124317)? – ImportanceOfBeingErnest Aug 14 '18 at 12:04
  • @ImportanceOfBeingErnest Maybe. I tried [this answer](https://stackoverflow.com/a/6310298/974555), but the call to `add_line` resulted in `RuntimeError: Can not put single artist in more than one figure`. Meanwhile I've found a different workaround, though. – gerrit Aug 14 '18 at 12:08
  • @Matthias123 I got the same problem and used your solution to fix the `.png` output, but it doesn't work if I change the extension to `.eps`. Any ideas how to make the eps look the same as the png and plt.show() output? Thanks a lot! – irene May 25 '19 at 06:16
  • 2
    @irene Note that this solution only moves the axes to a new figure, it does not set any of the transforms. So such issues are expected. Since it is discouraged to move artists between figures, better don't use this if you need reliable output. – ImportanceOfBeingErnest May 25 '19 at 12:03
  • based on this and on the update on the question I tried the following code: https://gist.github.com/homerobse/d8b4cbaff3125f0404c77dc90c21abe0 The axes went to the new figure, but they did not resize appropriately to the size of the new figure. Also, when trying to `fig.savefig('newfig-witholdaxes.svg', format='svg')` it was cut in different places than the one displayed with `plt.show()` – Homero Esmeraldo Jun 04 '20 at 21:58
  • How would this work if the second figure was given by `fig2, ax=plt.subplots(ncols=2); ax=ax.flatten()`? – Sos Aug 05 '20 at 07:56
  • Remark, `ax.remove()` can also be `fig.delaxes(ax)`. – user202729 Jun 26 '22 at 02:36
4

For line plots, you can deal with the Line2D objects themselves:

import numpy as np
import pylab

fig1 = pylab.figure()
ax1 = fig1.add_subplot(111)
lines = ax1.plot(np.random.randn(10))

fig2 = pylab.figure()
ax2 = fig2.add_subplot(111)
ax2.add_line(lines[0])
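If re-using the same Line2D artist is rejected by your matplotlib version, a possible alternative is to copy the line's data and a few of its properties into a new line on the second axes. This is a sketch rather than a full artist copy: only color, line style and label are carried over here, the data is made up, and the Agg backend is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig1, ax1 = plt.subplots()
(line,) = ax1.plot(np.arange(10), np.arange(10) ** 2, "r--", label="squares")

fig2, ax2 = plt.subplots()

# Read the data back from the existing line and plot it again,
# which creates a brand-new Line2D owned by ax2.
x, y = line.get_data()
(new_line,) = ax2.plot(
    x,
    y,
    color=line.get_color(),
    linestyle=line.get_linestyle(),
    label=line.get_label(),
)
```

The original line stays in `fig1` untouched, so both figures can be shown or saved independently.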
Steve Tjoa
  • +1 Good example. It seems I cannot decouple the axes creation from figure creation, but I can grab the axes instance and pass it to new figures. – juanchopanza Jun 12 '11 at 07:40
  • 1
    Please note that this approach does not work anymore. See Joe Kington's comment from Dec 7 2012, on his answer above. – joelostblom Oct 04 '16 at 12:06
  • 3
    `ax2.add_line(lines[0])` results in `RuntimeError: Can not put single artist in more than one figure` (Python 3.7.0, matplotlib 2.2.2). – gerrit Aug 14 '18 at 10:55
1

TL;DR, based partly on Joe's nice answer.

Opt.1: fig.add_subplot()

import matplotlib.pyplot as plt
import numpy as np

def fcn_return_plot():
    # draws on whichever axes is current in the pyplot state machine
    return plt.plot(np.random.random((10,)))

n = 4
fig = plt.figure(figsize=(n*3,2))
#fig, ax = plt.subplots(1, n,  sharey=True, figsize=(n*3,2)) # also works
for index in range(n):
    fig.add_subplot(1, n, index + 1)
    fcn_return_plot()
    plt.title(f"plot: {index}", fontsize=20)

Opt.2: pass ax[index] to a function that returns ax[index].plot()

def fcn_return_plot_input_ax(ax=None):
    if ax is None:
        ax = plt.gca()
    return ax.plot(np.random.random((10,)))

n = 4
fig, ax = plt.subplots(1, n, sharey=True, figsize=(n*3,2))
for index in range(n):
    fcn_return_plot_input_ax(ax[index])
    ax[index].set_title(f"plot: {index}", fontsize=20)

Outputs, respectively: [two figures, one per option; images not included]

Note: plt.title() in Opt.1 becomes ax[index].set_title() in Opt.2. You can find more "Matplotlib Gotchas" in VanderPlas's book.
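The difference between the two calls in the note can be seen in a short sketch: plt.title() acts on whichever axes is current, while set_title() acts on the axes you name. The titles below are made up, and the Agg backend is an assumption so that the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2)

# pyplot call: goes to the current axes, here the last one created.
plt.title("set via pyplot")

# Axes method: goes to exactly the axes it is called on.
axes[0].set_title("set via Axes")
```

This is why Opt.2 is easier to reason about once a figure holds several subplots: nothing depends on which axes happens to be current.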

Xopi García
0

Going deeper down the rabbit hole: extending my previous answer, one could return the whole ax, not just the result of ax.plot(). For example:

Suppose a dataframe holds 100 tests of 20 types (identified here by id):

import numpy as np
import pandas as pd

dfA = pd.DataFrame(np.random.random((100,3)), columns = ['y1', 'y2', 'y3'])
dfB = pd.DataFrame(np.repeat(list(range(20)),5), columns = ['id'])
dfC = dfA.join(dfB)

And the plot function (this is the key of this whole answer):

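import matplotlib.pyplot as plt

def plot_feature_each_id(df, feature, id_range=(), ax=None, legend_bool=False):
    # fall back to the current axes when none is passed in
    if ax is None:
        ax = plt.gca()
    feature = df[feature]
    # an empty id_range means "plot every id"
    if not len(id_range):
        id_range = set(df['id'])
    legend_arr = []
    for k in id_range:
        mask = (df['id'] == k)
        ax.plot(feature[mask])
        legend_arr.append(f"id: {k}")
    if legend_bool:
        ax.legend(legend_arr)
    return ax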

We can achieve:

feature_arr = dfC.drop(columns='id').columns
id_range = np.random.randint(len(set(dfC.id)), size=(10,))
n = len(feature_arr)
fig, ax = plt.subplots(1, n, figsize=(n*6,4))
for i, k in enumerate(feature_arr):
    plot_feature_each_id(dfC, k, np.sort(id_range), ax[i], legend_bool=(i+1==n))
    ax[i].set_title(k, fontsize=20)
    ax[i].set_xlabel("test nr. (id)", fontsize=20)


Xopi García