
I am trying to create a figure containing 8 subplots (4 rows and 2 columns). To do so, I wrote this code, which reads the x and y data and plots it as follows:

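import matplotlib.pyplot as plt

fig, axs = plt.subplots(4, 2, figsize=(15, 25))
y_labels = ['k0', 'k1']

for x in range(4):
    for y in range(2):
        # X_vals and y_vals hold one entry per sample; i[x] / i[y] picks that sample's value for this row / column
        axs[x, y].scatter([i[x] for i in X_vals], [i[y] for i in y_vals])
        axs[x, y].set_xlabel('Loss')
        axs[x, y].set_ylabel(y_labels[y])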

This gives me the following result:

[figure: the resulting 4 x 2 grid of scatter plots]

However, I want to add a title to each row (not to each individual plot), like this (the row titles are the yellow text):

[figure: mock-up of the same grid with a title above each row, shown in yellow]

I found this image, and some ways to do this, in another answer, but I wasn't able to adapt it to my use case and got an error. This is what I tried:

gridspec = axs[0].get_subplotspec().get_gridspec()
subfigs = [fig.add_subfigure(gs) for gs in gridspec]

for row, subfig in enumerate(subfigs):
    subfig.suptitle(f'Subplot row title {row}')

which gave me the error: `'numpy.ndarray' object has no attribute 'get_subplotspec'`
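For reference, this first error comes from the indexing rather than from subfigures: `plt.subplots(4, 2)` returns a 2-D NumPy array of Axes, so `axs[0]` is a whole row of Axes, not a single one. A minimal check, assuming the headless Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption) so the check runs without a display
import matplotlib.pyplot as plt

fig, axs = plt.subplots(4, 2)

# plt.subplots(4, 2) returns a 2-D NumPy array of Axes, one entry per cell
print(type(axs).__name__, axs.shape)  # ndarray (4, 2)

# axs[0] is the entire first row (itself an ndarray), hence the AttributeError
print(hasattr(axs[0], "get_subplotspec"))  # False

# axs[0, 0] is a single Axes, and its subplotspec knows the shared 4 x 2 grid
grid = axs[0, 0].get_subplotspec().get_gridspec()
print(grid.get_geometry())  # (4, 2)
```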

So I changed the code to:

gridspec = axs[0, 0].get_subplotspec().get_gridspec()
subfigs = [fig.add_subfigure(gs) for gs in gridspec]

for row, subfig in enumerate(subfigs):
    subfig.suptitle(f'Subplot row title {row}')

but this returned the error: `'Figure' object has no attribute 'add_subfigure'`
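A minimal sketch for checking whether the installed matplotlib provides the subfigure API before relying on it. The helper name `supports_subfigures` is hypothetical, and it simply looks for the methods on `Figure` rather than parsing the version string:

```python
import matplotlib
from matplotlib.figure import Figure


def supports_subfigures():
    """Hypothetical helper: True if Figure.subfigures and Figure.add_subfigure exist.

    Both methods first appeared in matplotlib 3.4.0; checking for them directly
    avoids having to parse matplotlib.__version__.
    """
    return hasattr(Figure, "subfigures") and hasattr(Figure, "add_subfigure")


print("matplotlib", matplotlib.__version__)
if not supports_subfigures():
    print("No subfigure support: upgrade, e.g. with pip install --upgrade matplotlib")
```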

Ravish Jha

1 Answer


The solution in the answer you linked is correct, but it is specific to the 3x3 case shown there. The following code is a more general version for other numbers of subplots. It should work as long as your data and `y_labels` match the grid: one x-series per row, one y-series per column, and one label per column.

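Note that this requires matplotlib 3.4.0 or later, the release that introduced subfigures (older versions raise the `'Figure' object has no attribute 'add_subfigure'` error you saw):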

import numpy as np
import matplotlib.pyplot as plt

# random data. Make sure these are the correct size if changing number of subplots
x_vals = np.random.rand(4, 10)
y_vals = np.random.rand(2, 10)
y_labels = ['k0', 'k1']  

# change rows/cols accordingly
rows = 4
cols = 2

fig = plt.figure(figsize=(15,25), constrained_layout=True)
fig.suptitle('Figure title')

# create rows x 1 subfigs
subfigs = fig.subfigures(nrows=rows, ncols=1)

for row, subfig in enumerate(subfigs):
    subfig.suptitle(f'Subplot row title {row}')

    # create 1 x cols subplots per subfig
    axs = subfig.subplots(nrows=1, ncols=cols)
    for col, ax in enumerate(axs):
        ax.scatter(x_vals[row], y_vals[col])
        ax.set_title("Subplot ax title")
        ax.set_xlabel('Loss')
        ax.set_ylabel(y_labels[col])

Which gives:

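[figure: output of the code above, a 4 x 2 grid of scatter plots with "Figure title" at the top and a "Subplot row title N" heading above each row]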
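As written, the loop relies on `subfigures` and `subplots` returning 1-D arrays, which stops being true when `rows` or `cols` is 1: a single SubFigure or Axes comes back instead, and `enumerate` fails on it. Below is a sketch of the same approach wrapped in a function, assuming you pass one x-series per row and one y-series per column. The function name `plot_grid_with_row_titles` and its parameters are hypothetical; passing `squeeze=False` keeps the returned arrays 2-D in every case:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption) so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_grid_with_row_titles(x_vals, y_vals, y_labels, row_titles, figsize=(15, 25)):
    """Hypothetical helper: one subfigure per row, one scatter Axes per column.

    x_vals: sequence of length rows (one x-series per row)
    y_vals: sequence of length cols (one y-series per column)
    """
    rows, cols = len(x_vals), len(y_vals)
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    # squeeze=False keeps a 2-D array even when rows == 1 or cols == 1,
    # so the loops below never receive a bare SubFigure or Axes
    subfigs = fig.subfigures(nrows=rows, ncols=1, squeeze=False)
    for row in range(rows):
        subfig = subfigs[row, 0]
        subfig.suptitle(row_titles[row])
        axs = subfig.subplots(nrows=1, ncols=cols, squeeze=False)
        for col in range(cols):
            ax = axs[0, col]
            ax.scatter(x_vals[row], y_vals[col])
            ax.set_xlabel("Loss")
            ax.set_ylabel(y_labels[col])
    return fig


fig = plot_grid_with_row_titles(
    np.random.rand(4, 10),
    np.random.rand(2, 10),
    ["k0", "k1"],
    [f"Subplot row title {r}" for r in range(4)],
)
```

The same call with a single row and a single column, for example `plot_grid_with_row_titles(np.random.rand(1, 5), np.random.rand(1, 5), ["k0"], ["only row"])`, should also work, which is the case the original loop cannot handle.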

DavidG
Got it, I was a bit confused about subfigures, and I have also updated matplotlib. Thanks. Also, for my use case I had to change the line `ax.scatter(x_vals[row], y_vals[col])` to `ax.scatter([i[row] for i in X_vals], [i[col] for i in y_vals])`. – Ravish Jha Sep 02 '21 at 11:17
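The change in the comment follows from a difference in data layout: the question's `X_vals` and `y_vals` hold one entry per sample, so `[i[row] for i in X_vals]` pulls one column out of them, while the answer's random `x_vals` and `y_vals` already hold one series per row or column. A sketch with made-up values (an assumption about the shape of the real data; `x_by_row` and `y_by_col` are hypothetical names) showing that transposing with NumPy gives the same series, so `ax.scatter(x_by_row[row], y_by_col[col])` should work equally well:

```python
import numpy as np

# Made-up data in the question's layout: one tuple per sample,
# holding that sample's 4 row values (X_vals) or 2 column values (y_vals)
X_vals = [(0.1, 0.2, 0.3, 0.4), (0.5, 0.6, 0.7, 0.8), (0.9, 1.0, 1.1, 1.2)]
y_vals = [(1, 2), (3, 4), (5, 6)]

# Transposing turns "one entry per sample" into "one series per subplot row/column",
# the layout that x_vals[row] and y_vals[col] in the answer expect
x_by_row = np.asarray(X_vals).T  # shape (4, 3)
y_by_col = np.asarray(y_vals).T  # shape (2, 3)

# Each transposed row matches the list comprehension from the comment
for row in range(4):
    print(row, x_by_row[row].tolist() == [i[row] for i in X_vals])  # True for every row
for col in range(2):
    print(col, y_by_col[col].tolist() == [i[col] for i in y_vals])  # True for every column
```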