
I have a plot with two y-axes, using twinx(). I also give labels to the lines and want to show them with legend(), but I only manage to get the labels of one axis in the legend:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
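rc('mathtext', default='regular')

# placeholder data so the snippet runs on its own
time = np.arange(10)
temp = np.random.random(10)*30
Swdown = np.random.random(10)*100-10
Rn = np.random.random(10)*100-10
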
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(time, Swdown, '-', label = 'Swdown')
ax.plot(time, Rn, '-', label = 'Rn')
ax2 = ax.twinx()
ax2.plot(time, temp, '-r', label = 'temp')
ax.legend(loc=0)
ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()

So I only get the labels of the first axis in the legend, and not the label 'temp' of the second axis. How could I add this third label to the legend?

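The behaviour comes from how legends are collected: ax.legend() only gathers the labeled artists drawn on ax, and ax2 = ax.twinx() is a separate Axes with its own list of artists. A minimal sketch showing this, assuming placeholder arrays in place of the question's data and the non-interactive Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# placeholder data standing in for the question's time, Swdown, Rn, temp
time = np.arange(10)
Swdown = np.linspace(0, 90, 10)
Rn = np.linspace(-10, 60, 10)
temp = np.linspace(5, 30, 10)

fig, ax = plt.subplots()
ax.plot(time, Swdown, '-', label='Swdown')
ax.plot(time, Rn, '-', label='Rn')
ax2 = ax.twinx()
ax2.plot(time, temp, '-r', label='temp')

# each Axes only reports the labeled artists drawn on it
_, labels_ax = ax.get_legend_handles_labels()
_, labels_ax2 = ax2.get_legend_handles_labels()
print(labels_ax)   # ['Swdown', 'Rn']
print(labels_ax2)  # ['temp']
```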

Trenton McKinney
joris
  • [*Don't do this anywhere remotely close to production code*] When my only aim is to generate a beautiful plot with the appropriate legend ASAP, I use an ugly hack of plotting an empty array on `ax` with the style I use on `ax2`: in your case, `ax.plot([], [], '-r', label = 'temp')`. It's much faster and simpler than doing it properly... – Neinstein Apr 19 '18 at 12:36
  • Also see https://stackoverflow.com/a/57484812/3642162 for pandas and twinx – ijuneja Jun 30 '21 at 10:20
  • The legend will be merged properly if you comment out the line `ax.legend(loc=0)`. A simple and natural alternative that preserves the default merged legend without having to tweak is to replace that line with `fig.legend(loc=0)` instead. As explained in the answer by @ImportanceOfBeingErnest below, the legend with multiple axes belongs to the figure `fig`, rather than to the left axis `ax`. In retrospect, it should be obvious that `ax.legend()` will mess things up. (I don't have your data to check your particular case, but this is what I've observed on other data) – PatrickT Jan 14 '22 at 11:49
  • In case you use subplots, check out the answer by Suuuehgi below, it's the most elegant to me. – DanT Jul 08 '22 at 11:13

11 Answers


You can easily add a second legend by adding the line:

ax2.legend(loc=0)

You'll get this:

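If both legends are created with loc=0 ('best'), each picks its position independently, so they can end up overlapping (Roalt's comment below makes the same point). A minimal sketch with explicit, different corners, using placeholder data and the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

time = np.arange(10)  # placeholder data
fig, ax = plt.subplots()
ax.plot(time, time * 10, '-', label='Swdown')
ax.plot(time, time * 5, '-', label='Rn')
ax2 = ax.twinx()
ax2.plot(time, time * 3, '-r', label='temp')

# one legend per Axes; distinct corners keep them from overlapping
leg1 = ax.legend(loc='upper left')
leg2 = ax2.legend(loc='upper right')

print([t.get_text() for t in leg1.get_texts()])  # ['Swdown', 'Rn']
print([t.get_text() for t in leg2.get_texts()])  # ['temp']
```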

But if you want all labels on one legend then you should do something like this:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('mathtext', default='regular')

time = np.arange(10)
temp = np.random.random(10)*30
Swdown = np.random.random(10)*100-10
Rn = np.random.random(10)*100-10

fig = plt.figure()
ax = fig.add_subplot(111)

lns1 = ax.plot(time, Swdown, '-', label = 'Swdown')
lns2 = ax.plot(time, Rn, '-', label = 'Rn')
ax2 = ax.twinx()
lns3 = ax2.plot(time, temp, '-r', label = 'temp')

# added these three lines
lns = lns1+lns2+lns3
labs = [l.get_label() for l in lns]
ax.legend(lns, labs, loc=0)

ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()

Which will give you this:


Paul
  • This fails with `errorbar` plots. For a solution that correctly handles them, see below: http://stackoverflow.com/a/10129461/1319447 – Davide Nov 17 '15 at 15:02
  • To prevent two overlapping legends, as in my case where I specified two .legend(loc=0) calls, you should specify two different values for the legend location (both other than 0). See: http://matplotlib.org/api/legend_api.html – Roalt Jan 04 '16 at 14:12
  • I had some trouble adding a single line to some subplot with multiple lines `ax1`. In this case use `lns1=ax1.lines` and then append `lns2` to this list. – Little Bobby Tables Jul 12 '17 at 12:01
  • The different values used by `loc` are explained [here](https://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.legend) – Dror Jul 18 '17 at 09:46
  • See the answer below for a more automatic way (with matplotlib >= 2.1): https://stackoverflow.com/a/47370214/653364 – joris Nov 20 '17 at 15:42
  • How do you line up the tick marks for both y axis? In the above example, the '0' on the left yticklabels should be in line with '5' on the right yticklabels, etc. How to adjust it? `matplotlib` is so painful for beginners. :( – StayFoolish Dec 28 '17 at 01:52
  • Could you tell me why there are two `\,` in the `ax.ylabel` value? I tried removing them and the result looks the same to me. – StayFoolish Dec 28 '17 at 02:20
  • Doing that works better: `lines, labels = ax1.get_legend_handles_labels()`, `lines2, labels2 = ax2.get_legend_handles_labels()`, `lines3, labels3 = ax3.get_legend_handles_labels()`, `ax3.legend(lines + lines2 + lines3, labels + labels2 + labels3)`. Cf. https://stackoverflow.com/a/10129461/7115301 – belka Feb 14 '19 at 09:27
  • There's a good answer in https://stackoverflow.com/questions/14344063/single-legend-for-multiple-axes/41030990 – Roel Verhoeven Mar 27 '19 at 10:10
  • This method works if your plots are similar. In my case I had one line plot and one scatter, and @zgana's answer worked for me – Sabzaliev Shukur Oct 27 '22 at 06:42
  • Doesn't work for scatter either – FreelanceConsultant Apr 01 '23 at 18:48
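The comments about errorbar and scatter come down to return types: ax.plot returns a list of Line2D, so lns1+lns2 is list concatenation, whereas ax.scatter returns a single PathCollection and ax.errorbar an ErrorbarContainer, which do not concatenate with + in the same way. A minimal sketch that builds a flat list of handles instead (placeholder data, Agg backend):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(10)  # placeholder data
fig, ax = plt.subplots()
ax2 = ax.twinx()

eb = ax.errorbar(x, x * 2.0, yerr=1.0, label='errorbar')  # ErrorbarContainer
sc = ax2.scatter(x, x ** 2, color='r', label='scatter')   # PathCollection
lns = ax.plot(x, x * 3.0, label='line')                   # list of Line2D

# eb + sc would raise TypeError, so collect the handles in one flat list
handles = [eb, sc, *lns]
labels = [h.get_label() for h in handles]
ax.legend(handles, labels, loc=0)
print(labels)  # ['errorbar', 'scatter', 'line']
```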

I'm not sure if this functionality is new, but you can also use the get_legend_handles_labels() method rather than keeping track of lines and labels yourself:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('mathtext', default='regular')

pi = np.pi

# fake data
time = np.linspace(0, 25, 50)
temp = 50 / np.sqrt(2 * pi * 3**2) \
        * np.exp(-((time - 13)**2 / (3**2))**2) + 15
Swdown = 400 / np.sqrt(2 * pi * 3**2) * np.exp(-((time - 13)**2 / (3**2))**2)
Rn = Swdown - 10

fig = plt.figure()
ax = fig.add_subplot(111)

ax.plot(time, Swdown, '-', label = 'Swdown')
ax.plot(time, Rn, '-', label = 'Rn')
ax2 = ax.twinx()
ax2.plot(time, temp, '-r', label = 'temp')

# ask matplotlib for the plotted objects and their labels
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc=0)

ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()

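The same get_legend_handles_labels() pattern extends to any number of twin axes (belka's comment on the previous answer does it for three). A minimal sketch wrapping it in a small helper; merged_legend is a hypothetical name, not a matplotlib function. The legend goes on the last-created twin because later axes are typically drawn on top, so lines from the other axes are less likely to cover it:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def merged_legend(target, axes_list, **kwargs):
    """Draw one legend on `target` with the handles of every Axes in axes_list.

    Hypothetical helper; it only wraps Axes.get_legend_handles_labels().
    """
    handles, labels = [], []
    for a in axes_list:
        h, l = a.get_legend_handles_labels()
        handles.extend(h)
        labels.extend(l)
    return target.legend(handles, labels, **kwargs)


t = np.linspace(0, 25, 50)  # placeholder data
fig, ax = plt.subplots()
ax2 = ax.twinx()
ax3 = ax.twinx()
ax.plot(t, np.sin(t), label='Swdown')
ax2.plot(t, np.cos(t), '-r', label='temp')
ax3.plot(t, t, '-g', label='humidity')  # hypothetical third quantity

leg = merged_legend(ax3, [ax, ax2, ax3], loc='upper left')
print([txt.get_text() for txt in leg.get_texts()])  # ['Swdown', 'temp', 'humidity']
```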

Trenton McKinney
zgana

From matplotlib version 2.1 onwards, you may use a figure legend. Instead of ax.legend(), which produces a legend with the handles from the axes ax, one can create a figure legend:

fig.legend(loc="upper right")

which will gather all handles from all subplots in the figure. Since it is a figure legend, it will be placed at the corner of the figure, and the loc argument is relative to the figure.

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0,10)
y = np.linspace(0,10)
z = np.sin(x/3)**2*98

fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x,y, '-', label = 'Quantity 1')

ax2 = ax.twinx()
ax2.plot(x,z, '-r', label = 'Quantity 2')
fig.legend(loc="upper right")

ax.set_xlabel("x [units]")
ax.set_ylabel(r"Quantity 1")
ax2.set_ylabel(r"Quantity 2")

plt.show()


To place the legend back inside the axes, one would supply a bbox_to_anchor and a bbox_transform. The latter is the transform of the axes the legend should reside in (here ax.transAxes). The former gives the coordinates, in those axes coordinates, of the legend corner selected by loc.

fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax.transAxes)


ImportanceOfBeingErnest
  • So, has version 2.1 already been released? In Anaconda 3, I tried `conda upgrade matplotlib` but no newer version was found; I'm still using v2.0.2 – StayFoolish Dec 28 '17 at 02:29
  • This does not seem to work when you have many subplots. It adds a single legend for all subplots. One typically needs one legend for each subplot, containing series in both primary and secondary axes in each legend. – sancho.s ReinstateMonicaCellio Dec 19 '19 at 02:18
  • @sancho Correct, that's what is written in the third sentence of this answer, "...which will gather all handles from all subplots in the figure.". – ImportanceOfBeingErnest Dec 19 '19 at 11:28
  • While the gathered legend works for twin y-axes, it seems that when using seaborn the legends from each of the two axes still show up on the plot. See the code described [here](https://stackoverflow.com/a/54782540/8508004), along with the addition in the comment by Mihai Cherlau, to hide those. For the example above you'd use `ax.legend_.remove();ax2.legend_.remove()`. – Wayne Feb 11 '20 at 03:31
  • Excellent answer. In my case I needed `fig.legend(loc="upper right"); ax.get_legend().remove()`, otherwise two legends were printed. Go figure. :-) – PatrickT Jan 14 '22 at 13:32
  • This is the best answer, and appreciated. But even *more* pythonic would be if the api worked as expected: `plt.legend()` just working as usual for all labels, and then let people manipulate labels for each axis individually using (something like) `ax.legend()`, in the rare case they would want to do that. – eric Apr 05 '22 at 13:30
  • This works with one subplot (and in this case this is a better solution than the other). With multiple subplots there is a single legend, which might not be what you want. – mins Jan 31 '23 at 11:14
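As the comments above point out, fig.legend() gathers the labeled artists of every Axes in the figure, so with several subplots it yields one combined legend. A minimal sketch for one merged legend per subplot instead, attaching each legend to that subplot's twin (placeholder data, Agg backend):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0, 10, 100)  # placeholder data
fig, axes = plt.subplots(nrows=2)

for i, ax in enumerate(axes):
    ax.plot(t, np.sin(t + i), label=f'left {i}')
    ax2 = ax.twinx()
    ax2.plot(t, 100 * np.cos(t + i), '-r', label=f'right {i}')
    # merge only this subplot's primary and twin handles
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax2.legend(h1 + h2, l1 + l2, loc='upper right')

print(len(fig.legends))  # 0: no figure-level legend was created
```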

You can easily get what you want by adding this line to ax:

ax.plot([], [], '-r', label = 'temp')

or

ax.plot(np.nan, '-r', label = 'temp')

This plots nothing but adds a label to the legend of ax.

I think this is a much easier way. It's not necessary to track lines automatically when you have only a few lines in the second axes, as fixing by hand like above would be quite easy. Anyway, it depends on what you need.

The whole code is as below:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('mathtext', default='regular')

time = np.arange(22.)
temp = 20*np.random.rand(22)
Swdown = 10*np.random.randn(22)+40
Rn = 40*np.random.rand(22)

fig = plt.figure()
ax = fig.add_subplot(111)
ax2 = ax.twinx()

#---------- look at below -----------

ax.plot(time, Swdown, '-', label = 'Swdown')
ax.plot(time, Rn, '-', label = 'Rn')

ax2.plot(time, temp, '-r')  # the real line, drawn on ax2
ax.plot(np.nan, '-r', label = 'temp')  # proxy on ax, only for the legend

ax.legend(loc=0)

#---------------done-----------------

ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()

The plot is as below:



Update: add a better version:

ax.plot(np.nan, '-r', label = 'temp')

This draws nothing, whereas plot(0, 0) may change the axis range.


One extra example, for scatter:

ax.scatter([], [], s=100, label = 'temp')  # proxy in ax, only for the legend
ax2.scatter(time, temp, s=10)  # the real scatter in ax2

ax.legend(loc=1, framealpha=1)
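A variant of the same proxy idea, sketched here as an alternative rather than this answer's code: matplotlib also accepts artists that were never added to any Axes as legend handles (the "proxy artist" pattern from the matplotlib legend guide). A Line2D built this way never touches the axes' data limits. Placeholder data and the Agg backend are assumptions:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

time = np.arange(22.0)  # placeholder data
rng = np.random.default_rng(0)
Swdown = 10 * rng.standard_normal(22) + 40
temp = 20 * rng.random(22)

fig, ax = plt.subplots()
ax2 = ax.twinx()
ax.plot(time, Swdown, '-', label='Swdown')
ax2.plot(time, temp, '-r')  # the real line, without a label

# proxy artist: a Line2D that is never added to any Axes,
# so it cannot affect data limits or autoscaling
proxy = Line2D([], [], color='r', linestyle='-', label='temp')

handles, _ = ax.get_legend_handles_labels()
leg = ax.legend(handles=handles + [proxy], loc=0)
print([t.get_text() for t in leg.get_texts()])  # ['Swdown', 'temp']
```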
Syrtis Major

Preparation

import numpy as np
from matplotlib import pyplot as plt

fig, ax1 = plt.subplots( figsize=(15,6) )

Y1, Y2 = np.random.random((2,100))

ax2 = ax1.twinx()

Content

I'm surprised it has not shown up so far, but the simplest way is to either collect the handles manually into one of the axes objects (which lie on top of each other)

l1 = ax1.plot( range(len(Y1)), Y1, label='Label 1' )
l2 = ax2.plot( range(len(Y2)), Y2, label='Label 2', color='orange' )

ax1.legend( handles=l1+l2 )


or have them collected automatically into the surrounding figure with fig.legend() and fiddle around with the bbox_to_anchor parameter:

ax1.plot( range(len(Y1)), Y1, label='Label 1' )
ax2.plot( range(len(Y2)), Y2, label='Label 2', color='orange' )

fig.legend( bbox_to_anchor=(.97, .97) )


Finalization

fig.tight_layout()
fig.savefig('stackoverflow.png', bbox_inches='tight')
Suuuehgi
  • Thank you for a really clear answer that works! I additionally found that, if you are using `scatter` rather than `plot`, you need to do `handles=[l1, l2]` in the legend call, because they are single `PathCollection` objects rather than lists. – Xorgon Jun 06 '22 at 14:19

A quick hack that may suit your needs:

Take off the frame of the box and manually position the two legends next to each other. Something like this:

ax1.legend(loc = (.75,.1), frameon = False)
ax2.legend( loc = (.75, .05), frameon = False)

Here the loc tuple gives the position of the legend's lower-left corner as fractions of the axes width (left to right) and height (bottom to top).
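A self-contained version of the two stacked, frameless legends, sketched with placeholder data and the Agg backend; the coordinates are the answer's own values:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(10)  # placeholder data
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(x, x, label='Swdown')
ax2.plot(x, x ** 2, '-r', label='temp')

# a 2-tuple loc is the legend's lower-left corner in axes coordinates:
# (0, 0) is the bottom-left of the axes, (1, 1) the top-right
leg1 = ax1.legend(loc=(0.75, 0.10), frameon=False)
leg2 = ax2.legend(loc=(0.75, 0.05), frameon=False)
```

Because each legend still belongs to its own axes, the two boxes only line up if the chosen y values leave roughly one text line of space between them.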

user2105997

I found the following official matplotlib example that uses host_subplot to display multiple y-axes with all the different labels in one legend. No workaround is necessary. It is the best solution I have found so far. http://matplotlib.org/examples/axes_grid/demo_parasite_axes2.html

from mpl_toolkits.axes_grid1 import host_subplot
import mpl_toolkits.axisartist as AA
import matplotlib.pyplot as plt

host = host_subplot(111, axes_class=AA.Axes)
plt.subplots_adjust(right=0.75)

par1 = host.twinx()
par2 = host.twinx()

offset = 60
new_fixed_axis = par2.get_grid_helper().new_fixed_axis
par2.axis["right"] = new_fixed_axis(loc="right",
                                    axes=par2,
                                    offset=(offset, 0))

par2.axis["right"].toggle(all=True)

host.set_xlim(0, 2)
host.set_ylim(0, 2)

host.set_xlabel("Distance")
host.set_ylabel("Density")
par1.set_ylabel("Temperature")
par2.set_ylabel("Velocity")

p1, = host.plot([0, 1, 2], [0, 1, 2], label="Density")
p2, = par1.plot([0, 1, 2], [0, 3, 2], label="Temperature")
p3, = par2.plot([0, 1, 2], [50, 30, 15], label="Velocity")

par1.set_ylim(0, 4)
par2.set_ylim(1, 65)

host.legend()

plt.draw()
plt.show()
gerrit
  • Welcome to Stack Overflow! Please quote the most relevant part of the link, in case the target site is unreachable or goes permanently offline. See [How do I write a good answer](http://stackoverflow.com/help/how-to-answer). Focus on more current questions in the future, this one is nearly 4 years old. – ByteHamster Mar 05 '15 at 15:32
  • Indeed a good find but I wish you would have taken what you learned from the example, applied it to the OP's MWE, and included an image. – aeroNotAuto Nov 08 '17 at 21:32

If you are using Seaborn you can do this:

g = sns.barplot(...)                  # your bar plot arguments
g2 = sns.lineplot(..., ax=g.twinx())  # your line plot arguments, drawn on a twin y-axis
h1, l1 = g.get_legend_handles_labels()
h2, l2 = g2.get_legend_handles_labels()
# merge the two legends into one
g.legend(h1 + h2, l1 + l2, title_fontsize='10')
# remove the second legend
g2.get_legend().remove()
pooria

The solutions proposed so far have one or two drawbacks:

  • Handles need to be collected individually when plotting, e.g. lns1 = ax.plot(time, Swdown, '-', label = 'Swdown'). There is a risk of forgetting handles when updating the code.

  • The legend is drawn for the whole figure, not per subplot, which is likely a no-go if you have multiple subplots.

This new solution takes advantage of Axes.get_legend_handles_labels() to collect existing handles and labels for the main axis and for the twin axis.

Collecting handles and labels automatically

This numpy operation scans all axes that share the same subplot area as ax, including ax itself, and returns the merged handles and labels:

hl = np.hstack([axis.get_legend_handles_labels()
                for axis in ax.figure.axes
                if axis.bbox.bounds == ax.bbox.bounds])

It can be used to feed legend() arguments this way:

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(1, 200)
signals = [np.exp(-t/20) * np.cos(t*k) for k in (1, 2)]

fig, axes = plt.subplots(nrows=2, figsize=(10, 3), layout='constrained')
axes = axes.flatten()

for i, (ax, signal) in enumerate(zip(axes, signals)):
    # Plot as usual, no change to the code
    ax.plot(t, signal, label=f'plotted on axes[{i}]', c='C0', lw=9, alpha=0.3)
    ax2 = ax.twinx()
    ax2.plot(t, signal, label=f'plotted on axes[{i}].twinx()', c='C1')

    # The only specificity of the code is when plotting the legend
    h, l = np.hstack([axis.get_legend_handles_labels()
                      for axis in ax.figure.axes
                      if axis.bbox.bounds == ax.bbox.bounds]).tolist()
    ax2.legend(handles=h, labels=l, loc='upper right')

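The same bbox-based filter can also be written with plain list concatenation, which avoids stacking artists and strings into a NumPy object array. A minimal sketch of that alternative (not this answer's exact code), reusing its example signals; layout='constrained' is dropped here and the Agg backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

t = np.arange(1, 200)
signals = [np.exp(-t / 20) * np.cos(t * k) for k in (1, 2)]

fig, axes = plt.subplots(nrows=2, figsize=(10, 3))

for i, (ax, signal) in enumerate(zip(axes, signals)):
    ax.plot(t, signal, label=f'plotted on axes[{i}]', c='C0', lw=9, alpha=0.3)
    ax2 = ax.twinx()
    ax2.plot(t, signal, label=f'plotted on axes[{i}].twinx()', c='C1')

    # keep every Axes occupying the same area as ax (ax itself and its twin)
    h, l = [], []
    for other in ax.figure.axes:
        if other.bbox.bounds == ax.bbox.bounds:
            oh, ol = other.get_legend_handles_labels()
            h += oh
            l += ol
    ax2.legend(handles=h, labels=l, loc='upper right')
```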

mins

As provided in the example from matplotlib.org, a clean way to implement a single legend from multiple axes is with plot handles:

import matplotlib.pyplot as plt


fig, ax = plt.subplots()
fig.subplots_adjust(right=0.75)

twin1 = ax.twinx()
twin2 = ax.twinx()

# Offset the right spine of twin2.  The ticks and label have already been
# placed on the right by twinx above.
twin2.spines.right.set_position(("axes", 1.2))

p1, = ax.plot([0, 1, 2], [0, 1, 2], "b-", label="Density")
p2, = twin1.plot([0, 1, 2], [0, 3, 2], "r-", label="Temperature")
p3, = twin2.plot([0, 1, 2], [50, 30, 15], "g-", label="Velocity")

ax.set_xlim(0, 2)
ax.set_ylim(0, 2)
twin1.set_ylim(0, 4)
twin2.set_ylim(1, 65)

ax.set_xlabel("Distance")
ax.set_ylabel("Density")
twin1.set_ylabel("Temperature")
twin2.set_ylabel("Velocity")

ax.yaxis.label.set_color(p1.get_color())
twin1.yaxis.label.set_color(p2.get_color())
twin2.yaxis.label.set_color(p3.get_color())

tkw = dict(size=4, width=1.5)
ax.tick_params(axis='y', colors=p1.get_color(), **tkw)
twin1.tick_params(axis='y', colors=p2.get_color(), **tkw)
twin2.tick_params(axis='y', colors=p3.get_color(), **tkw)
ax.tick_params(axis='x', **tkw)

ax.legend(handles=[p1, p2, p3])

plt.show()

Here is another way to do this:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
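rc('mathtext', default='regular')

# placeholder data so the snippet runs on its own
time = np.arange(10)
temp = np.random.random(10)*30
Swdown = np.random.random(10)*100-10
Rn = np.random.random(10)*100-10
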
fig = plt.figure()
ax = fig.add_subplot(111)
pl_1, = ax.plot(time, Swdown, '-')
label_1 = 'Swdown'
pl_2, = ax.plot(time, Rn, '-')
label_2 = 'Rn'

ax2 = ax.twinx()
pl_3, = ax2.plot(time, temp, '-r')
label_3 = 'temp'

ax.legend([pl_1, pl_2, pl_3], [label_1, label_2, label_3], loc=0)

ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()


MathPass