
I am trying to add 8 heatmaps (as subplots) to a single figure, but I can't seem to manage it. Could you please help me?

import matplotlib.pyplot as plt
import numpy as np

# in order to modify the size
fig = plt.figure(figsize=(12,8))
# adding multiple Axes objects  
fig, ax_lst = plt.subplots(2, 4)  # a figure with a 2x4 grid of Axes

letter = "ABCDEFGH"

for character in letter:
    x = np.random.randn(4096)
    y = np.random.randn(4096)
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=(64,64))
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

    # Plot heatmap
    plt.clf()
    plt.title('Pythonspot.com heatmap example')
    plt.ylabel('y')
    plt.xlabel('x')
    plt.imshow(heatmap, extent=extent)
    plt.colorbar()
    plt.show()

Thank you!

sachikox

2 Answers


Similarly to this answer, you could do the following:

import matplotlib.pyplot as plt
import numpy as np

letter = "ABCDEFGH"

n_cols = 2
fig, axes = plt.subplots(nrows=int(np.ceil(len(letter)/n_cols)), 
                         ncols=n_cols, 
                         figsize=(15,15))

for _, ax in zip(letter, axes.flatten()):
    x = np.random.randn(4096)
    y = np.random.randn(4096)
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=(64,64))
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

    # Plot heatmap
    ax.set_title('Pythonspot.com heatmap example')
    ax.set_ylabel('y')
    ax.set_xlabel('x')
    ax.imshow(heatmap, extent=extent)
plt.tight_layout()  
plt.show()
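
A sketch along the same lines that also keeps two things from the question's code: a per-panel title (here the letter) and a colorbar for each heatmap. Assumptions not taken from the answer above: `plot_heatmaps` is a hypothetical helper name, `layout="constrained"` needs Matplotlib 3.5 or newer, and the data comes from a seeded `np.random.default_rng` so runs are repeatable. It also transposes the histogram and passes `origin="lower"`, because `np.histogram2d` bins `x` along the first (row) axis of its result, which `imshow` would otherwise draw vertically.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it for interactive windows
import matplotlib.pyplot as plt
import numpy as np


def plot_heatmaps(letters, n_cols=4, bins=64, seed=0):
    """Hypothetical helper: one 2-D histogram heatmap per letter on a grid of Axes."""
    rng = np.random.default_rng(seed)
    n_rows = -(-len(letters) // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8),
                             layout="constrained", squeeze=False)
    flat = axes.ravel()
    for character, ax in zip(letters, flat):
        x = rng.standard_normal(4096)
        y = rng.standard_normal(4096)
        heatmap, xedges, yedges = np.histogram2d(x, y, bins=(bins, bins))
        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
        # histogram2d puts x along rows: transpose and draw row 0 at the bottom
        im = ax.imshow(heatmap.T, extent=extent, origin="lower")
        ax.set_title(character)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        fig.colorbar(im, ax=ax)  # one colorbar per heatmap, next to its own Axes
    # hide unused grid cells when len(letters) is not a multiple of n_cols
    for ax in flat[len(letters):]:
        ax.set_visible(False)
    return fig, axes


fig, axes = plot_heatmaps("ABCDEFGH")
```

Save the result with `fig.savefig(...)`, or remove the `matplotlib.use("Agg")` line and call `plt.show()` to display it. Passing `ax=ax` to `fig.colorbar` tells Matplotlib which subplot each colorbar should take its space from.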


yatu

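First, you were creating two separate figures: one with plt.figure() and another with plt.subplots() (the latter creates both a figure and an array of Axes). As a result, the figsize you passed to plt.figure() never applied to the figure that holds your subplots.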
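
A minimal check of this, assuming a headless `Agg` backend so it runs without a display: the first call produces an empty 12x8 figure, the second a separate default-size figure holding the eight Axes. Passing `figsize` to `plt.subplots()` itself leaves a single figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

plt.close("all")

# What the question does: two calls, two figures.
fig_sized = plt.figure(figsize=(12, 8))  # figure 1: 12x8 inches, no Axes
fig, ax_lst = plt.subplots(2, 4)         # figure 2: default size, 8 Axes

print(plt.get_fignums())    # [1, 2]
print(len(fig_sized.axes))  # 0
print(len(fig.axes))        # 8

# The fix: let subplots() create the only figure, at the wanted size.
plt.close("all")
fig, ax_lst = plt.subplots(2, 4, figsize=(12, 8))
print(plt.get_fignums())    # [1]
```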

Then, you need to iterate over your axes and plot on each one of them, instead of clearing the figure on every loop iteration (which is what plt.clf() was doing).
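
To see why `plt.clf()` undoes the grid, here is a small sketch (again assuming the `Agg` backend, with an all-zeros array standing in for a heatmap): clearing removes every Axes that `plt.subplots()` created, and the next `plt.imshow()` call simply makes one new Axes that fills the figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")
fig, ax_lst = plt.subplots(2, 4)
print(len(fig.axes))  # 8

plt.clf()             # clears the current figure: all 8 Axes are removed
print(len(fig.axes))  # 0

plt.imshow(np.zeros((4, 4)))  # pyplot creates a fresh, full-figure Axes to draw on
print(len(fig.axes))  # 1
```

Inside the question's loop this happens eight times, so only the last heatmap survives.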

You can still use the plt.XXXX() functions, but they act only on the "current" axes, so you have to change the current axes at each iteration with plt.sca(). Alternatively, and usually more cleanly, use the Axes methods such as Axes.set_XXXX(), as in the other answer by @yatu. See here for more information.

import matplotlib.pyplot as plt
import numpy as np

fig, ax_lst = plt.subplots(2, 4, figsize=(12,8))  # a figure with a 2x4 grid of Axes

letters = "ABCDEFGH"
for character, ax in zip(letters, ax_lst.flat):
    x = np.random.randn(4096)
    y = np.random.randn(4096)
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=(64,64))
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

    # Plot heatmap
    plt.sca(ax) # make the ax object the "current axes"
    plt.title(character)
    plt.ylabel('y')
    plt.xlabel('x')
    plt.imshow(heatmap, extent=extent)
    plt.colorbar()
plt.show()
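
The two styles described above, side by side in a small sketch (headless `Agg` backend assumed; the titles are placeholders): `plt.sca()` redirects the `plt.*` functions to a chosen Axes, while the Axes methods never depend on which Axes is current.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

plt.close("all")
fig, ax_lst = plt.subplots(2, 4, figsize=(12, 8))

# pyplot style: make an Axes current, then use the plt.* functions.
plt.sca(ax_lst[1, 2])  # ax_lst[1, 2] becomes the "current axes"
plt.title("G")         # lands on ax_lst[1, 2]

# Object-oriented style: call the method on the Axes directly, no global state.
ax_lst[0, 0].set_title("A")

print(plt.gca() is ax_lst[1, 2])  # True
print(ax_lst[1, 2].get_title())   # G
```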

Diziet Asahi