The exercise is to create a scatter plot for height vs weight, colored by gender. Data is in a pandas DataFrame; Height, Weight, and Gender are columns.
I've got the plot and the colors. I'm trying to get the legend to work.
I'm so close, but not quite there. Can someone tell me what piece I'm missing? I'm getting a legend, but it pairs the first category with the second color, which is an unexpected combination.
Disclaimer:
I know I can use seaborn for this and it "just works". But I want to understand how to do it with matplotlib.
I also realize I could do two scatter plots on the same axes after separating the data by category. But, again, I'd like to understand how to do this without separating the data. I can plot both categories in separate colors with a single plot, so I should be able to get a legend for both with a single plot too. Matplotlib can't be that stupid.
With thanks for the pointer to the proposed duplicate, which I've already seen and tried: `plot.legend_elements()` returns `([], [])`, so that's not really helpful.
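A likely explanation, offered as an interpretation rather than anything stated in the post: `PathCollection.legend_elements()` builds its entries from the numeric array that a scatter plot maps through its colormap. When `c=` receives a list of color names, no such array is stored (`get_array()` returns `None`), so there is nothing to build handles from. This minimal sketch, using made-up points, compares the two cases.

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x, y = [1, 2, 3], [4, 5, 6]

# Explicit color names: the colors are stored directly, no data array is kept.
named = ax.scatter(x, y, c=["orange", "blue", "orange"])
# Numeric values: these are stored and mapped through the colormap.
mapped = ax.scatter(x, y, c=[0, 1, 0])

print(named.get_array())  # None

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # legend_elements warns when there is no array
    named_handles, _ = named.legend_elements()
mapped_handles, mapped_labels = mapped.legend_elements()

print(len(named_handles))   # 0
print(len(mapped_handles))  # 2, one per unique value
```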
This version:
```python
import matplotlib.pyplot as plt
import numpy as np

def dfScatter(df, xcol, ycol, catcol):
    fig, ax = plt.subplots()
    # Map each category (e.g. 'Female', 'Male') to a color
    categories = np.unique(df[catcol])
    colors = ['orange', 'blue', 'g', 'r', 'c', 'm', 'y', 'purple']
    colordict = dict(zip(categories, colors))
    df["Color"] = df[catcol].apply(lambda x: colordict[x])
    # One scatter call; per-point colors come from the Color column
    plot = ax.scatter(df[xcol], df[ycol], c=df.Color, alpha=0.7)
    ax.legend(colordict)
    return fig

fig = dfScatter(dflog, 'Weight', 'Height', 'Gender')
plt.show()
```
gives me this plot, with a legend containing only one item:
Oddly enough, it's the wrong item; blue is the color for Male.
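Why a single entry? An interpretation, not confirmed in the post: when `ax.legend()` receives one positional argument, matplotlib treats it as a list of labels and pairs it, in order, with the artists already on the axes. Iterating a dict yields its keys, and one `scatter` call creates one artist (a single `PathCollection`), so only the first key, `'Female'`, gets an entry. The marker color the legend draws is presumably taken from that collection's first point, which would explain blue appearing if the first row is male. The sketch below uses invented data to show the one-entry behavior.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

colordict = {'Female': 'orange', 'Male': 'blue'}

fig, ax = plt.subplots()
# One scatter call -> one artist on the axes, however many colors it uses.
sc = ax.scatter([1, 2, 3], [1, 2, 3], c=['blue', 'orange', 'blue'])
# The dict is iterated as labels ('Female', 'Male') and zipped with the
# axes' artists; there is only one artist, so only 'Female' survives.
leg = ax.legend(colordict)

labels = [t.get_text() for t in leg.get_texts()]
print(len(ax.collections))  # 1
print(labels)               # ['Female']
```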
It would be convenient if I could break up the dictionary somehow to convince the legend to give me all of the entries, correctly mapped. The dictionary content is `{'Female': 'orange', 'Male': 'blue'}`.
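One way to break up the dictionary is to build one proxy legend handle per key, a documented matplotlib pattern (the legend guide calls these proxy artists). This is a minimal sketch, not the post's code: the `Line2D` proxies are empty lines that only carry a marker style and a color, the sample points are invented, and `legend_from_colordict` is a hypothetical helper name.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

def legend_from_colordict(ax, colordict):
    """Hypothetical helper: one round-marker legend entry per category."""
    handles = [
        # Empty line with no linestyle: only the marker shows in the legend.
        Line2D([], [], marker='o', linestyle='', color=color, label=cat)
        for cat, color in colordict.items()
    ]
    return ax.legend(handles=handles), handles

colordict = {'Female': 'orange', 'Male': 'blue'}
fig, ax = plt.subplots()
genders = ['Female', 'Male', 'Female']
ax.scatter([60, 80, 55], [160, 180, 158],
           c=[colordict[g] for g in genders], alpha=0.7)
leg, handles = legend_from_colordict(ax, colordict)
print([t.get_text() for t in leg.get_texts()])  # ['Female', 'Male']
```

The scatter itself is still a single call; only the legend entries are built separately from the dictionary.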
A comment on another post has some very promising code:
```python
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

x = [1, 3, 4, 6, 7, 9]
y = [0, 0, 5, 8, 8, 8]
classes = ['A', 'B', 'C']
values = [0, 0, 1, 2, 2, 2]  # numeric class index per point
colours = ListedColormap(['r', 'b', 'g'])
scatter = plt.scatter(x, y, c=values, cmap=colours)
plt.legend(handles=scatter.legend_elements()[0], labels=classes)
```
It works beautifully for me, but I can't seem to make it work for my use case. I'm probably missing something obvious, but this:
```python
import matplotlib.pyplot as plt
import numpy as np

def dfScatter(df, xcol, ycol, catcol):
    categories = np.unique(df[catcol])
    colors = ['orange', 'blue', 'g', 'r', 'c', 'm', 'y', 'purple']
    colordict = dict(zip(categories, colors))
    cat_colors = df[catcol].apply(lambda x: colordict[x])
    scatter = plt.scatter(df[xcol], df[ycol], c=cat_colors, alpha=0.7)
    plt.legend(handles=scatter.legend_elements()[0], labels=categories)
    fig = plt.gcf()  # fig is not otherwise defined inside this version
    return fig

fig = dfScatter(dflog, 'Weight', 'Height', 'Gender')
plt.show()
```
tells me "No handles with labels found to put in legend.", and `scatter.legend_elements()[0]` is empty.
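Adapting the `ListedColormap` snippet to a DataFrame seems to hinge on passing numbers, not color names, to `c`. A minimal sketch under that assumption: `pd.factorize(..., sort=True)` turns each category into an integer code, and the colormap is trimmed to exactly one color per category (with all eight colors, the normalized codes would land on the wrong entries, e.g. `'purple'` for the last code), so `legend_elements()` has values to work with. The column names follow the post; the sample rows are invented and `df_scatter_codes` is a hypothetical name.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.colors import ListedColormap

def df_scatter_codes(df, xcol, ycol, catcol):
    """Hypothetical variant: color by integer category codes, not color names."""
    fig, ax = plt.subplots()
    # Integer code per row (0, 1, ...) plus the sorted unique categories.
    codes, categories = pd.factorize(df[catcol], sort=True)
    colors = ['orange', 'blue', 'g', 'r', 'c', 'm', 'y', 'purple']
    # Exactly one colormap entry per category, so code k maps to colors[k].
    cmap = ListedColormap(colors[:len(categories)])
    scatter = ax.scatter(df[xcol], df[ycol], c=codes, cmap=cmap, alpha=0.7)
    # num=None asks for one handle per unique code.
    handles, _ = scatter.legend_elements(num=None)
    ax.legend(handles=handles, labels=list(categories))
    return fig

sample = pd.DataFrame({
    'Weight': [60, 85, 55, 90],
    'Height': [160, 182, 158, 185],
    'Gender': ['Female', 'Male', 'Female', 'Male'],
})
fig = df_scatter_codes(sample, 'Weight', 'Height', 'Gender')
```

This keeps the single `scatter` call; the only change from the post's second version is that `c` receives category codes plus a matching colormap instead of a column of color names.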