
Based on the question and answer here (Line plus shaded region for error band in matplotlib's legend, and the similar Combined legend entry for plot and fill_between), I was able to create a legend entry that combines a line and a patch element.

In my use case I need to plot several of these. When I do, I only get a legend entry for the last line+patch combination.

import numpy as np
import matplotlib.pyplot as plt

def plot(x, y, ax, col, group, **kwargs):
    hline, = ax.plot(x, y, 'k--', color=col)
    hpatch = ax.fill_between(x, y+10, y-10, color=col, alpha=0.5)
    ax.legend([(hline, hpatch)], [f"group {group}: Mean + interval"])

fig, ax = plt.subplots()
x = np.linspace(1, 100, 100)
plot(x, x, ax, "C0", 1)
plot(x, x+30, ax, "C1", 2)
plot(x, x+60, ax, "C2", 3)
plt.show()


Note the presence of only the final (group 3) entry in the legend.

Is there a way to get all line/patch groups included in the legend so that (in this case) there are 3 items in the legend?

Bonus points if this can be handled entirely within the plot function, avoiding having to pass out handles from the plot function.

This question is not asking about multiple separate legends.

Ben Vincent

3 Answers


Matplotlib's automatic legend only includes artists that were given a label via the label= keyword. ax.get_legend_handles_labels() returns all the elements that have such a label, and recombining these handles and labels gives the desired legend. Note that each call to ax.legend replaces the previous legend on the Axes, which is why the question's code ends up showing only the last group.

The test code replaces ax.plot(x, y, 'k--', ...) with ax.plot(x, y, '--', ...). The k in the format string would color the line black, which conflicts with the color= keyword; Matplotlib warns that the color is redundantly defined and uses the keyword value.

import matplotlib.pyplot as plt
import numpy as np

def plot(x, y, ax, col, group, **kwargs):
    # Give both artists a label so get_legend_handles_labels() picks them up
    ax.plot(x, y, '--', color=col, label=f"group {group}: Mean")
    ax.fill_between(x, y + 10, y - 10, color=col, alpha=0.5, label='interval')
    # All labelled artists so far; assumes only these line/band pairs carry labels
    handles, labels = ax.get_legend_handles_labels()
    # Pair consecutive handles and labels into combined legend entries
    ax.legend(handles=[(h1, h2) for h1, h2 in zip(handles[::2], handles[1::2])],
              labels=[l1 + " + " + l2 for l1, l2 in zip(labels[::2], labels[1::2])])

fig, ax = plt.subplots()
x = np.linspace(1, 100, 100)
plot(x, x, ax, "C0", 1)
plot(x, x + 30, ax, "C1", 2)
plot(x, x + 60, ax, "C2", 3)

plt.tight_layout()
plt.show()

[Image: plot with the extended legend]
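The behaviours this answer relies on (only labelled artists are collected, and each ax.legend call replaces the previous legend), plus the format-string conflict mentioned above, can be checked in isolation. This is a minimal sketch, assuming a reasonably recent Matplotlib 3.x and the headless Agg backend; the labels and data are made up for illustration:

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import same_color

fig, ax = plt.subplots()
x = np.linspace(1, 100, 100)

# 1. Only artists given a label= are returned by get_legend_handles_labels().
ax.plot(x, x, "--", color="C0", label="mean")
ax.plot(x, x + 30, "--", color="C1")  # no label: left out of the legend
ax.fill_between(x, x + 10, x - 10, color="C0", alpha=0.5, label="interval")
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['mean', 'interval']

# 2. Each ax.legend() call replaces the legend already on the Axes.
first = ax.legend(handles[:1], labels[:1])
second = ax.legend(handles, labels)
print(ax.get_legend() is first)   # False
print(ax.get_legend() is second)  # True

# 3. 'k--' asks for black while color= asks for C2; Matplotlib warns
#    and lets the keyword argument win.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    line, = ax.plot(x, x + 60, "k--", color="C2")
print(any("redundantly defined" in str(w.message) for w in caught))  # True
print(same_color(line.get_color(), "C2"))  # True
```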

JohanC
  • This is a good solution. Only minor quibble is that if there are other elements plotted on the figure then the indexing in the handles and labels breaks. – Ben Vincent Dec 27 '22 at 16:01
  • Normally, you can access the existing legend entries via `ax.legend_.legendHandles`, but that doesn't seem to work with the tuples you're using. – JohanC Dec 27 '22 at 16:47
  • There's a question here about that https://discourse.matplotlib.org/t/retrieving-handles-from-legend-returns-line-and-not-tuple/23422/1 – Ben Vincent Dec 29 '22 at 20:12
  • If your figure is more complicated, then you need to track the handles and labels manually by creating the list as you create the artists. – Jody Klymak Dec 29 '22 at 20:56

You could return the combined (line, patch) handle and its label from the plot function, and use those to create the legend once at the end.

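import matplotlib.pyplot as plt
import numpy as np

def plot(x, y, ax, col, group, **kwargs):
    hline, = ax.plot(x, y, '--', color=col)
    hpatch = ax.fill_between(x, y+10, y-10, color=col, alpha=0.5)
    # Return the combined (line, patch) handle together with its label
    return (hline, hpatch), f"group {group}: Mean + interval"

fig, ax = plt.subplots()
x = np.linspace(1, 100, 100)

l1 = plot(x, x, ax, "C0", 1)
l2 = plot(x, x+30, ax, "C1", 2)
l3 = plot(x, x+60, ax, "C2", 3)

# zip(*...) splits the (handle, label) pairs into handles and labels
ax.legend(*zip(*[l1, l2, l3]))
plt.show()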
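The one-liner ax.legend(*zip(*[l1, l2, l3])) transposes the list of (handle, label) pairs. A plain-Python sketch of just that step, with strings standing in for the real artists purely for illustration:

```python
# Each call to plot() returns a (handle, label) pair; strings stand in
# for the Line2D and fill_between artists here.
l1 = (("line1", "patch1"), "group 1: Mean + interval")
l2 = (("line2", "patch2"), "group 2: Mean + interval")
l3 = (("line3", "patch3"), "group 3: Mean + interval")

# zip(*pairs) turns three (handle, label) pairs into one tuple of handles
# and one tuple of labels; the outer * in ax.legend(*zip(...)) then passes
# them as two positional arguments, i.e. ax.legend(handles, labels).
handles, labels = zip(*[l1, l2, l3])
print(handles)  # (('line1', 'patch1'), ('line2', 'patch2'), ('line3', 'patch3'))
print(labels)   # ('group 1: Mean + interval', 'group 2: Mean + interval', 'group 3: Mean + interval')
```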
Rutger Kassies

I tried the following, expecting it to work:

import numpy as np
import matplotlib.pyplot as plt

def plot(x, y, ax, col, group, **kwargs):
    legend = ax.get_legend()  # None until the first legend exists
    if legend is not None:
        # Reuse the entries of the existing legend
        # (legendHandles is called legend_handles from Matplotlib 3.7 on)
        handles = legend.legendHandles
        labels = [txt.get_text() for txt in legend.get_texts()]
    else:
        handles = []
        labels = []
    hline, = ax.plot(x, y, ls='--', color=col)
    hpatch = ax.fill_between(x, y+10, y-10, color=col, alpha=0.5)

    handles += [(hline, hpatch)]
    labels += [f"group {group}: Mean + interval"]

    ax.legend(handles, labels)


fig, ax = plt.subplots(figsize=(9, 6))
x = np.linspace(1, 100, 100)
plot(x, x, ax, "C0", 1)
plot(x, x + 30, ax, "C1", 2)
plot(x, x + 60, ax, "C2", 3)

However, I get this

Looking at the legend's handles, only the lines are there:

ax.get_legend().legendHandles
[<matplotlib.lines.Line2D at 0x7f066b54b370>,
 <matplotlib.lines.Line2D at 0x7f066b54b940>,
 <matplotlib.lines.Line2D at 0x7f066b54bdc0>]

I was trying to find out whether this is a bug or intended behavior.
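A likely explanation, based on my reading of Matplotlib's legend code rather than its documentation (so treat it as an assumption): for a tuple handle, the default HandlerTuple draws one new artist per element inside the legend box, and the legend keeps only the first of those artists in its list of handles. The stored handle is therefore a fresh Line2D, not the original (hline, hpatch) tuple, which would also explain why building the tuple from proxy artists gives the same result. A minimal check, assuming the attribute is named legend_handles in newer versions and legendHandles in older ones:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
x = np.linspace(1, 100, 100)
hline, = ax.plot(x, x, "--", color="C0")
hpatch = ax.fill_between(x, x + 10, x - 10, color="C0", alpha=0.5)
legend = ax.legend([(hline, hpatch)], ["group 1: Mean + interval"])

# Newer Matplotlib exposes legend_handles; older versions use legendHandles.
stored = getattr(legend, "legend_handles", None) or legend.legendHandles

print(len(stored), type(stored[0]).__name__)  # 1 Line2D
print(stored[0] is hline)  # False: a new artist drawn inside the legend
```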

I also tried building the tuple from proxy artists instead:

from matplotlib.lines import Line2D
from matplotlib.patches import Patch

def plot(x, y, ax, col, group, **kwargs):
    legend = ax.get_legend()  # None until the first legend exists
    if legend is not None:
        handles = legend.legendHandles
        labels = [txt.get_text() for txt in legend.get_texts()]
    else:
        handles = []
        labels = []
    hline, = ax.plot(x, y, ls='--', color=col)
    hpatch = ax.fill_between(x, y+10, y-10, color=col, alpha=0.5)
    label = f"group {group}: Mean + interval"
    # Proxy artists that mimic the line and the shaded band
    handle = (Line2D([], [], color=col, label=label), Patch(color=col, alpha=0.5, label=label))
    handles += [handle]
    labels += [label]

    ax.legend(handles, labels)

but the result is still the same.

Tomas Capretto