
I am trying to make a figure with six individual plots, arranged in two rows of three. Each row should have its own color bar corresponding to the images shown in that row's three plots. Laid out visually, the figure should look like this:

image_type1 | image_type1 | image_type1 | colorbar_for_type1_images

image_type2 | image_type2 | image_type2 | colorbar_for_type2_images

The vertical lines in the representation above are just to separate the different components of the figure. I don't actually need vertical lines in my figure.

An example of what I'm trying to do is shown below, along with my unsuccessful attempts to plot a color bar next to the third image in each row.

I've been able to do this successfully in the past, with code similar to what appears below, when I was using my own color map for a series of plotted lines rather than for images as I'm trying to do here.

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cbook import get_sample_data

#Make 6 plotting areas of the same dimensions
figuresizex = 9.0
figuresizey = 6.1
lowerx = .07
lowery = .09
upperx = .92
uppery = .97
xspace = .05
yspace = .11
xwidth = (upperx-lowerx-2*xspace)/3.
ywidth = (uppery-lowery-yspace)/2.

fig = plt.figure(figsize=(figuresizex,figuresizey))
ax1 = fig.add_axes([lowerx,lowery+ywidth+yspace,xwidth,ywidth])
ax2 = fig.add_axes([lowerx+xwidth+xspace,lowery+ywidth+yspace,xwidth,ywidth])
ax3 = fig.add_axes([lowerx+2*xwidth+2*xspace,lowery+ywidth+yspace,xwidth,ywidth])
ax4 = fig.add_axes([lowerx,lowery,xwidth,ywidth])
ax5 = fig.add_axes([lowerx+xwidth+xspace,lowery,xwidth,ywidth])
ax6 = fig.add_axes([lowerx+2*xwidth+2*xspace,lowery,xwidth,ywidth])
axlist = [ax1,ax2,ax3,ax4,ax5,ax6]

#Start plotting images
image = np.identity(5)

for i in range(0,3):
    vmin, vmax = image.min(),image.max()
    axuse = axlist[i]
    im = axuse.imshow(image, vmin=vmin, vmax=vmax)
    if i == 3:
        cbar = axuse.colorbar(im)
        cbar = plt.colorbar(im)

image_2 = np.arange(16).reshape((4,4))

for i in range(0,3):
    vmin, vmax = image_2.min(),image_2.max()
    axuse = axlist[i+3]
    axuse.imshow(image_2,vmin=vmin, vmax=vmax)
    if i == 3:
        cbar = axuse.colorbar()
        cbar = plt.colorbar()

plt.show()
```
NeutronStar

1 Answer


I'd suggest taking the approach outlined in this question.

In addition to making it straightforward to add the colorbar, without depending on reaching your third image (which should be `i==2`, since `range(0,3)` only yields 0, 1 and 2), using ImageGrid removes the need to explicitly (painfully?) define all 6 axes, and it is more flexible if the number of images changes.
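To make the `i==2` point concrete, here is a minimal sketch of the question's own `add_axes` approach for a single row, with the off-by-one fixed. It assumes a thin extra axes (`cax`, placed with made-up coordinates) reserved for the colorbar and passed to `fig.colorbar`. Note that `colorbar` is a `Figure`/`pyplot` method rather than an `Axes` method, so `axuse.colorbar(im)` from the question would fail even once the branch runs.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
from matplotlib import pyplot as plt

# Hypothetical layout numbers for one row, in figure-fraction units
lowerx, lowery, xwidth, ywidth, xspace = 0.07, 0.2, 0.25, 0.6, 0.03
cbar_width = 0.02

image = np.identity(5)
fig = plt.figure(figsize=(9.0, 3.0))
axes = [fig.add_axes([lowerx + i * (xwidth + xspace), lowery, xwidth, ywidth])
        for i in range(3)]
# Thin axes just to the right of the last image, reserved for the colorbar
cax = fig.add_axes([lowerx + 3 * (xwidth + xspace), lowery, cbar_width, ywidth])

for i in range(0, 3):          # i takes the values 0, 1, 2 and never 3
    im = axes[i].imshow(image, vmin=image.min(), vmax=image.max())
    if i == 2:                 # last image in the row
        # draw the colorbar into the reserved axes via the Figure method
        cbar = fig.colorbar(im, cax=cax)
```

With this approach you have to budget the colorbar's width by hand, which is the bookkeeping ImageGrid takes care of for you.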

Update: I've added a third row to show that the same color scale can be applied to all images in a row by passing the same `vmin` and `vmax` values to every `imshow` call in that row.

```python
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

figuresizex = 9.0
figuresizey = 6.1

# generate images
image1 = np.identity(5)
image2 = np.arange(16).reshape((4,4))

fig = plt.figure(figsize=(figuresizex,figuresizey))

# create your grid objects: one 1x3 ImageGrid per row, each with a single colorbar
top_row = ImageGrid(fig, 311, nrows_ncols = (1,3), axes_pad = .25,
                    cbar_location = "right", cbar_mode="single")
middle_row = ImageGrid(fig, 312, nrows_ncols = (1,3), axes_pad = .25,
                       cbar_location = "right", cbar_mode="single")
bottom_row = ImageGrid(fig, 313, nrows_ncols = (1,3), axes_pad = .25,
                       cbar_location = "right", cbar_mode="single")

# plot the images
for i in range(3):
    vmin, vmax = image1.min(),image1.max()
    ax = top_row[i]
    im1 = ax.imshow(image1, vmin=vmin, vmax=vmax)

for i in range(3):
    vmin, vmax = image2.min(),image2.max()
    ax = middle_row[i]
    im2 = ax.imshow(image2, vmin=vmin, vmax=vmax)

# Update showing how to use identical scale across all 3 images
# make some slightly different images and get their bounds
image2s = [image2,image2 + 5,image2 - 5]

# inelegant way to get the absolute upper and lower bounds from the three images
# (start from -inf/+inf so the result is right even if all values share one sign)
i_max, i_min = -np.inf, np.inf
for im in image2s:
    if im.max() > i_max:
        i_max = im.max()
    if im.min() < i_min:
        i_min = im.min()
# plot these as you would the others, but use identical vmin and vmax for all three plots
for i,im in enumerate(image2s):
    ax = bottom_row[i]
    im2_scaled = ax.imshow(im, vmin = i_min, vmax = i_max)

# add your colorbars
cbar1 = top_row.cbar_axes[0].colorbar(im1)
middle_row.cbar_axes[0].colorbar(im2)
bottom_row.cbar_axes[0].colorbar(im2_scaled)

# example of labeling colorbar1
# (older Matplotlib versions with the legacy axes_grid1 colorbar used set_label_text)
cbar1.set_label("label")

# readjust figure margins after adding the colorbars;
# left and right are unequal because the colorbar labels
# don't appear to be factored into the adjustment
plt.subplots_adjust(left=0.075, right=0.9)

plt.show()
```
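A slightly tidier way to get the shared bounds than the loop in the code above is a pair of generator expressions, with one `matplotlib.colors.Normalize` instance handed to every `imshow` call in the row. This is a sketch for a single row only; the standalone `row` grid and its `111` position are assumptions made for the example.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import ImageGrid

image2 = np.arange(16).reshape((4, 4))
image2s = [image2, image2 + 5, image2 - 5]

# Shared bounds across all three images in the row
i_min = min(im.min() for im in image2s)
i_max = max(im.max() for im in image2s)
norm = Normalize(vmin=i_min, vmax=i_max)

fig = plt.figure(figsize=(9.0, 2.5))
row = ImageGrid(fig, 111, nrows_ncols=(1, 3), axes_pad=0.25,
                cbar_location="right", cbar_mode="single")
for ax, im in zip(row.axes_all, image2s):
    mappable = ax.imshow(im, norm=norm)   # every image gets the same norm object

# one colorbar for the whole row, valid for all three images
cbar = row.cbar_axes[0].colorbar(mappable)
```

Because all three images share one `Normalize`, the single colorbar is correct for every image in the row, not just the last one plotted.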
  • Superb answer! This produces exactly what I was looking for. Thanks for pointing out my `i==2` mistake as well, I've been editing some Fortran code lately and got mixed up. Question: if, for the three images in a given row, the data do not span quite the same range in values, to make sure the color scheme is consistent across all three images could I find the minimum `.min()` and maximum `.max()` values across all three images and use those values in all three calls to the `ImageGrid()` function? – NeutronStar Jan 11 '16 at 23:10
  • Good question. Yes you can, just make sure you pass the min and max to the `vmin` and `vmax` parameters of `imshow`. I've edited my answer with a third ImageGrid row to reflect this. – Patrick O'Connor Jan 11 '16 at 23:38
  • ImageGrid() appears to center the axes inside the figure. This means that, when the color bar is added, the axes + color bar combo is no longer centered inside the figure. Any way to address this? It would be nice to have the whole thing centered. Also, how do I give the color bars labels? – NeutronStar Jan 12 '16 at 19:13
  • Huh, that's a weird one. Sometimes matplotlib seems like black magic. Not sure what the "correct" way is to do it, but you can [manually readjust the figure borders](http://stackoverflow.com/questions/4042192/reduce-left-and-right-margins-in-matplotlib-plot) to recenter everything using `plt.subplots_adjust` (I've edited the answer using this) or `plt.tight_layout`, which throws an error in my answer but still appears to work (?). Glad I dug into this, I've been having issues with margins in some maps I've been batch generating and this should fix it. – Patrick O'Connor Jan 13 '16 at 06:00
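For the two follow-up questions (colorbar labels and centering), one alternative is to swap ImageGrid for `plt.subplots` with `constrained_layout=True` (Matplotlib 2.2 or later) and attach one colorbar per row with `fig.colorbar(..., ax=row_axes)`. This is a sketch of a different technique from the answer above, not a fix to it; the labels `"type 1"` and `"type 2"` are placeholders. Constrained layout positions the axes and colorbars together, which should keep the whole block balanced without hand-tuned `subplots_adjust` margins.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
from matplotlib import pyplot as plt

image1 = np.identity(5)
image2 = np.arange(16).reshape((4, 4))

# constrained_layout lays out the image axes and the colorbars together
fig, axes = plt.subplots(2, 3, figsize=(9.0, 6.1), constrained_layout=True)

cbars = []
for row_axes, image, label in zip(axes, (image1, image2), ("type 1", "type 2")):
    for ax in row_axes:
        im = ax.imshow(image, vmin=image.min(), vmax=image.max())
    # one colorbar per row; it takes its space from all three axes in that row
    cbar = fig.colorbar(im, ax=row_axes)
    cbar.set_label(label)   # label drawn alongside the vertical colorbar
    cbars.append(cbar)

fig.canvas.draw()   # run the layout now (plt.show() or savefig would also do it)
```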