
Versions: seaborn 0.11.0 and matplotlib 3.3.3

Run the following code and the legend shows x in blue and y in orange. Uncomment the last line before plt.show() and the labels are reversed, even though I specified hue_order in sns.kdeplot().

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# random data for testing
x = np.random.randn(100)
y = np.random.randn(100)
df = pd.DataFrame({"x": x, "y": y})
# long format
df = df.melt()
# plot
ax = sns.kdeplot(x="value", hue="variable",
                 hue_order=["x", "y"], data=df)
print(ax.get_legend_handles_labels())
print(ax.legend_)
print(type(ax))
# ax.legend(["x", "y"])
plt.show()

Notice also that the Axes itself has no legend handles or labels; the print statements output:

([], [])
Legend
<class 'matplotlib.axes._subplots.AxesSubplot'>
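When the Axes has no labelled artists (as the ([], []) output shows), ax.legend(["x", "y"]) with labels only pairs those labels with the Axes' lines in the order the lines were added. seaborn 0.11 appears to draw the hue levels in reverse order, so "x" ends up on the orange curve. The pure-matplotlib sketch below mimics that by drawing "y" before "x" by hand (the colours and data are made up for illustration) and contrasts it with passing handles explicitly:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in an interactive session
import matplotlib.pyplot as plt

fig, (ax_implicit, ax_explicit) = plt.subplots(2)

for ax in (ax_implicit, ax_explicit):
    # Draw "y" first and "x" second, mimicking a reversed drawing order.
    # After the loop, line_x / line_y refer to the lines on ax_explicit.
    line_y, = ax.plot([0, 1], [1, 0], color="tab:orange")
    line_x, = ax.plot([0, 1], [0, 1], color="tab:blue")

# Labels only: matplotlib pairs them with the Axes' lines in drawing order,
# so the label "x" is attached to the orange line.
leg_implicit = ax_implicit.legend(["x", "y"])

# Handles and labels together: each label goes with the artist passed beside it.
leg_explicit = ax_explicit.legend([line_x, line_y], ["x", "y"])
```

Passing handles and labels together removes any dependence on drawing order, which is why the fix is to reuse the handles seaborn already built rather than supplying labels alone.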

I could not figure out how to access the legend produced by seaborn so that I can modify it as I want.

I have a figure with 10 axes and want to plot with sns.kdeplot() on each of them, but with a single legend for the whole figure. I have to use a helper function that sets legend=True for kdeplot(). This is a hard constraint.

What I thought would work was to remove the legend from every subplot and add one manually to the figure. I expected this to work because I specified hue_order in kdeplot(), but it did not, as the example above shows.

I also tried deleting every legend except the one on a single subplot and positioning that one relative to the figure, but I couldn't find a way to do so.

Ahmed Elashry

1 Answer


You can create a single Figure legend from the handles and labels of one Axes legend and then remove the Axes legends:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# random data for testing
np.random.seed(42)
x = np.random.randn(100)
y = np.random.randn(100)
df = pd.DataFrame({"x": x, "y": y})
# long format
df = df.melt()
# plot
fig, (ax1, ax2) = plt.subplots(2)
sns.kdeplot(x="value", hue="variable",
            hue_order=["x", "y"], data=df, ax=ax1)
sns.kdeplot(x="value", hue="variable",
            hue_order=["x", "y"], data=df.assign(value=df.value*2), ax=ax2)

# build one figure legend from the handles and labels of ax1's legend
fig.legend(ax1.get_legend().get_lines(),
           [t.get_text() for t in ax1.get_legend().get_texts()],
           title='variable')
# then remove the per-Axes legends
ax1.get_legend().remove()
ax2.get_legend().remove()
plt.show()


Stef