
I am using plt.subplots to create a figure where every column corresponds to a different set of parameters (or a different tool/dataset in a more realistic situation), while the rows correspond to different characteristics shown for the same parameters/tool/dataset. That is, the legends for all plots in a row are the same, but they differ from row to row.

Questions:

  1. I would like to have one legend for every row (or possibly for every column, if I decide to transpose the layout).
  2. I would like to place the legend outside of the subplots (on the right).

My pedestrian solution is to plot the legend only for the last subplot in every row. I could probably also build on it to address the second question. However, I wonder if there is a more elegant way to do this, e.g., using fig.legend(), which creates a single legend for the whole figure (see the thread).
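One possible direction, offered as a sketch rather than a definitive answer: fig.legend() can be called more than once, and each call accepts explicit handles and labels plus a bbox_to_anchor in figure coordinates. Taking the handles from the first subplot of a row, and the vertical centre of that row from get_position(), gives one legend per row in a right-hand margin reserved with tight_layout(rect=...). The helper name add_row_legends, the 0.85/0.86 margin values, the Agg backend line and the output file name are assumptions for illustration; the plotted data are a compact version of the example below.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it and call plt.show() interactively
import numpy as np
import matplotlib.pyplot as plt


def add_row_legends(fig, axs, x_anchor=0.86):
    """Hypothetical helper: add one figure-level legend per row of `axs`,
    vertically centred on that row and placed to the right of the subplots.

    Assumes every subplot in a row carries the same labelled lines, so the
    handles of the row's first subplot are representative of the whole row.
    """
    legends = []
    for row in axs:
        handles, labels = row[0].get_legend_handles_labels()
        # Row extent in figure coordinates; only meaningful once the layout is fixed.
        bbox = row[-1].get_position()
        y_center = 0.5 * (bbox.y0 + bbox.y1)
        legends.append(fig.legend(handles, labels, loc="center left",
                                  bbox_to_anchor=(x_anchor, y_center)))
    return legends


xx = np.linspace(0., 2. * np.pi)
fig, axs = plt.subplots(2, 4, figsize=(12, 6))
for n, ax in enumerate(axs[0], start=1):
    ax.plot(xx, np.cos(n * xx), label="cos")
    ax.plot(xx, np.sin(n * xx), label="sin")
    ax.plot(xx, xx**2 / n, label="x^2")
for n, ax in enumerate(axs[1], start=1):
    ax.plot(xx**2 / n, np.cos(n * xx), label="cos vs. x^2")
    ax.plot(xx**2 / n, np.sin(n * xx), label="sin vs. x^2")

# Fix the layout first, keeping the right 15 % of the figure free,
# then place the legends into that margin.
fig.tight_layout(rect=(0, 0, 0.85, 1))
row_legends = add_row_legends(fig, axs)
fig.savefig("row_legends.png")
```

The design choice here is to attach the legends to the figure rather than to an individual axes, so no subplot is singled out; the price is that the layout has to be settled before the row positions are read.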

Here is an example that I cooked up to facilitate the discussion:

import numpy as np
import matplotlib.pyplot as plt

xx = np.linspace(0., 2.*np.pi)
ff1 = [np.cos(n*xx) for n in np.arange(1, 5)] 
ff2 = [np.sin(n*xx) for n in np.arange(1, 5)]
ff3 = [xx**2 / n for n in np.arange(1, 5)]

fig, axs = plt.subplots(2, 4, figsize=(12, 6))
for ax, f1, f2, f3 in zip(axs[0], ff1, ff2, ff3):
    ax.plot(xx, f1, label="cos")
    ax.plot(xx, f2, label="sin")
    ax.plot(xx, f3, label="x^2")
    ax.set_xlabel("x")
    ax.set_ylabel("f1, f2, f3")
    # ax.legend()
axs[0,-1].legend()

for ax, f1, f2, f3 in zip(axs[1], ff1, ff2, ff3):
    ax.plot(f3, f1, label="cos vs. x^2")
    ax.plot(f3, f2, label="sin vs. x^2")
    ax.set_xlabel("f3")
    ax.set_ylabel("f1, f2")
    # ax.legend()   
axs[1,-1].legend()

plt.tight_layout()
plt.show()
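For comparison, the "pedestrian" last-subplot approach from the question can be extended to the second requirement by anchoring each row's legend just outside its last subplot with bbox_to_anchor. This is a minimal sketch under the same assumptions as the question's example; the Agg backend line and the output file name are illustrative additions.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it and call plt.show() interactively
import numpy as np
import matplotlib.pyplot as plt

xx = np.linspace(0., 2. * np.pi)
fig, axs = plt.subplots(2, 4, figsize=(12, 6))
for n, ax in enumerate(axs[0], start=1):
    ax.plot(xx, np.cos(n * xx), label="cos")
    ax.plot(xx, np.sin(n * xx), label="sin")
    ax.plot(xx, xx**2 / n, label="x^2")
for n, ax in enumerate(axs[1], start=1):
    ax.plot(xx**2 / n, np.cos(n * xx), label="cos vs. x^2")
    ax.plot(xx**2 / n, np.sin(n * xx), label="sin vs. x^2")

# In axes coordinates x=1 is the right edge of the subplot and y=0.5 its middle;
# loc="center left" puts the legend's left-centre point at (1.02, 0.5).
for row in axs:
    row[-1].legend(loc="center left", bbox_to_anchor=(1.02, 0.5))

fig.tight_layout()
# bbox_inches="tight" keeps the out-of-axes legends from being clipped in the file.
fig.savefig("last_axis_legends.png", bbox_inches="tight")
```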


Roger Vadim
