
Suppose the code below is used to create my plot:

import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)

fig, axs = plt.subplots(2, 2)
axs[0, 0].plot(x, y)
axs[0, 0].set_title('Axis [0, 0]')
axs[0, 1].plot(x, y, 'tab:orange')
axs[0, 1].set_title('Axis [0, 1]')
axs[1, 0].plot(x, -y, 'tab:green')
axs[1, 0].set_title('Axis [1, 0]')
axs[1, 1].plot(x, -y, 'tab:red')
axs[1, 1].set_title('Axis [1, 1]')

for ax in axs.flat:
    ax.set(xlabel='x-label', ylabel='y-label')

# Keep labels and tick labels only on the outer edges: hide the x labels on the
# top row and the y labels on the right column.
for ax in axs.flat:
    ax.label_outer()


How do I add a common legend to each column, for example "Time domain" for the left column and "Frequency domain" for the right one?

Peter Mortensen
Amina Umar
  • This is essentially a duplicate of https://stackoverflow.com/questions/9834452/how-do-i-make-a-single-legend-for-many-subplots, but you need to keep track of your handles for each column separately and pass them to each legend. – Jody Klymak Jan 18 '23 at 02:20
  • @JodyKlymak that didn't answer my question. – Amina Umar Jan 18 '23 at 10:09
  • And my post below, did it answer your question? – gboffi Jan 20 '23 at 13:42
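
Jody's suggestion, to keep track of the handles for each column separately and pass them to each legend, can be sketched by holding on to the `Line2D` objects that `Axes.plot` returns and grouping them by column as you plot. This is a minimal sketch that reuses the question's data. The `panels` table, the per-column `handles`/`labels` dictionaries and the "curve r c" labels are illustrative names and are not taken from either post.

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)

fig, axs = plt.subplots(2, 2)

# (row, col, sign, colour) for the four panels of the question's figure
# (illustrative table; tab:blue is the default first colour).
panels = [(0, 0, 1, 'tab:blue'), (0, 1, 1, 'tab:orange'),
          (1, 0, -1, 'tab:green'), (1, 1, -1, 'tab:red')]

# Per-column bookkeeping of legend handles and their labels (hypothetical names).
handles = {0: [], 1: []}
labels = {0: [], 1: []}

for row, col, sign, color in panels:
    ax = axs[row, col]
    # Axes.plot returns a list of Line2D objects; here it holds exactly one.
    (line,) = ax.plot(x, sign * y, color=color)
    ax.set_title(f'Axis [{row}, {col}]')
    handles[col].append(line)
    labels[col].append(f'curve {row} {col}')

for ax in axs.flat:
    ax.set(xlabel='x-label', ylabel='y-label')
    ax.label_outer()

# One legend per column, drawn on the top axes of that column.
for col, title in enumerate(['Time Domain', 'Frequency Domain']):
    axs[0, col].legend(handles[col], labels[col], title=title)
```

Call `plt.show()` or `fig.savefig(...)` to render it. Because the handles are passed explicitly, the curves need no `label=` argument. You also decide exactly which lines appear in each column's legend.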

1 Answer



The code below is based on my answer to the question that Jody linked in a comment on your question. Please refer to that answer for an exhaustive explanation of the code.
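
The per-column bookkeeping in the code below is done by the one-liner `lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]`. The sketch below is pure Python and shows what that line computes. Placeholder strings stand in for the real `Line2D` handles that `get_legend_handles_labels()` returns. The placeholders and the extra `itertools` variant are illustrative only.

```python
import itertools

# Stand-ins for [ax.get_legend_handles_labels() for ax in axs[:, 0]]:
# one (handles, labels) pair per axes in the column (placeholder values).
lines_labels = [
    (['<line 0 0>'], ['curve 0 0']),  # top axes of column 0
    (['<line 1 0>'], ['curve 1 0']),  # bottom axes of column 0
]

# zip(*...) regroups the pairs into (all handle lists, all label lists) ...
handle_lists, label_lists = zip(*lines_labels)
# ... and sum(list_of_lists, []) concatenates each group into one flat list.
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]

print(handle_lists)  # (['<line 0 0>'], ['<line 1 0>'])
print(lines)         # ['<line 0 0>', '<line 1 0>']
print(labels)        # ['curve 0 0', 'curve 1 0']

# Equivalent flattening without sum(), which avoids repeated list copies.
flat_labels = list(itertools.chain.from_iterable(label_lists))
print(flat_labels)   # ['curve 0 0', 'curve 1 0']
```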

import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)

fig, axs = plt.subplots(2, 2)
axs[0, 0].plot(x, y, label='curve 0 0')
axs[0, 0].set_title('Axis [0, 0]')
axs[0, 1].plot(x, y, 'tab:orange', label='curve 0 1')
axs[0, 1].set_title('Axis [0, 1]')
axs[1, 0].plot(x, -y, 'tab:green', label='curve 1 0')
axs[1, 0].set_title('Axis [1, 0]')
axs[1, 1].plot(x, -y, 'tab:red', label='curve 1 1')
axs[1, 1].set_title('Axis [1, 1]')

for ax in axs.flat:
    ax.set(xlabel='x-label', ylabel='y-label')

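# Keep labels and tick labels only on the outer edges: hide the x labels on the
# top row and the y labels on the right column.
for ax in axs.flat:
    ax.label_outer()

# Build one legend per column: gather the labelled lines of every axes in
# that column and attach the combined legend to the column's top axes.
for col in (0, 1):
    lines_labels = [ax.get_legend_handles_labels() for ax in axs[:, col]]
    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
    axs[0, col].legend(lines, labels, title=['Time', 'Frequency'][col] + ' Domain')

plt.show()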
gboffi
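
You may prefer a legend that belongs to the column as a whole rather than to its top axes. One alternative, which is not the approach used in the answer above, is Matplotlib's subfigures (available since Matplotlib 3.4). Each column becomes a `SubFigure` with its own `legend()`. This is a minimal sketch assuming Matplotlib >= 3.4. The `colors`/`signs` tables and the `column_legends` list are illustrative names. The legend position (`loc`) will likely need tuning for a real figure.

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)

fig = plt.figure(constrained_layout=True)
# One SubFigure per column (requires Matplotlib >= 3.4).
subfigs = fig.subfigures(1, 2)

colors = [['tab:blue', 'tab:orange'], ['tab:green', 'tab:red']]  # [row][col]
signs = [1, -1]  # top row plots y, bottom row plots -y, as in the question
column_legends = []  # hypothetical: keeps the Legend objects for later tweaks

for col, (subfig, title) in enumerate(zip(subfigs, ['Time Domain', 'Frequency Domain'])):
    col_axs = subfig.subplots(2, 1)
    for row, ax in enumerate(col_axs):
        ax.plot(x, signs[row] * y, color=colors[row][col], label=f'curve {row} {col}')
        ax.set_title(f'Axis [{row}, {col}]')
        ax.set(xlabel='x-label', ylabel='y-label')
        ax.label_outer()
    # Collect every labelled line in this column and give the SubFigure one legend.
    handles, labels = [], []
    for ax in col_axs:
        h, l = ax.get_legend_handles_labels()
        handles += h
        labels += l
    column_legends.append(subfig.legend(handles, labels, title=title, loc='upper right'))
```

Here `label_outer()` works within each subfigure's own grid. As a result, both columns keep their y labels, unlike in the single 2 × 2 grid.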