
I have made a simple scatterplot using matplotlib that shows data from two numerical variables (varA and varB). The point colors come from a third, categorical string variable (col) that contains 10 unique colors, each corresponding to one of 10 unique names in another string variable. All of these are columns of the same Pandas DataFrame with 100+ rows. Is there an easy way to create a legend for this scatterplot that shows the unique colored dots and their corresponding category names? Or should I somehow group the data and plot each category in a subplot to do this? This is what I have so far:

import matplotlib.pyplot as plt
from matplotlib import colors as mcolors

varA = df['A']
varB = df['B'] 
col = df['Color']

plt.scatter(varA,varB, c=col, alpha=0.8)
plt.legend()

plt.show()
galmeriol
Mary
  • See [matplotlib-scatterplot-with-legend](https://stackoverflow.com/questions/42056713/matplotlib-scatterplot-with-legend), [setting-a-legend-matching-the-colours-in-pyplot-scatter](https://stackoverflow.com/questions/44164111/setting-a-legend-matching-the-colours-in-pyplot-scatter) – ImportanceOfBeingErnest May 08 '18 at 18:26

3 Answers


I had to chime in, because I could not accept that I needed a for-loop to accomplish this. It just seems really annoying and unpythonic, especially when I'm not using Pandas. After some searching, I found the answer: the object returned by ax.scatter is a matplotlib.collections.PathCollection, and since matplotlib 3.1 it has a legend_elements() method that creates legend handles for the distinct values passed as c. No extra import of the collections module is needed. See the implementation below:

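# imports
import matplotlib.pyplot as plt
import numpy as np

# create random data and numerical labels (0-3)
x = np.random.rand(10, 2)
y = np.random.randint(4, size=10)

# create list of categories, indexed by numerical label
labels = ['type1', 'type2', 'type3', 'type4']

# plot
fig, ax = plt.subplots()
scatter = ax.scatter(x[:, 0], x[:, 1], c=y)
# one handle per unique value of y, in sorted order; the generated labels are
# discarded and the matching custom names are used instead (a label value can
# be missing from random y, so pick names by value rather than passing all four)
handles, _ = scatter.legend_elements(prop="colors", alpha=0.6)
ax.legend(handles, [labels[i] for i in np.unique(y)], loc="upper right")
plt.show()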

[Figure: scatterplot legend with custom labels]

Sources:

https://matplotlib.org/stable/gallery/lines_bars_and_markers/scatter_with_legend.html

https://matplotlib.org/stable/api/collections_api.html#matplotlib.collections.PathCollection.legend_elements
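Applied to the question's setup, here is a minimal sketch, assuming a DataFrame with numeric columns A and B, a Color column of matplotlib color names, and a Names column with one name per color (the Names column follows Mary's comment on another answer; the data below is made up). The string names are turned into integer codes with pd.factorize, and the question's own colors are supplied through a ListedColormap so that each code keeps its color. num=None is passed because, as far as I can tell from the matplotlib source, the default num="auto" only gives one entry per value for up to nine unique values, and the question has ten.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import ListedColormap

# Made-up stand-in for the question's DataFrame: 10 names, one color each.
rng = np.random.default_rng(0)
tab_colors = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple",
              "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan"]
idx = rng.integers(0, 10, size=120)
df = pd.DataFrame({
    "A": rng.random(120),
    "B": rng.random(120),
    "Color": [tab_colors[i] for i in idx],
    "Names": [f"name{i}" for i in idx],
})

# Integer code per row; names[k] is the string belonging to code k.
codes, names = pd.factorize(df["Names"])
# Color for each name, in the same order as `names` (assumes Names <-> Color is 1:1).
palette = (df.drop_duplicates(subset="Names")
             .set_index("Names")
             .loc[names, "Color"]
             .tolist())

fig, ax = plt.subplots()
# vmin/vmax put each integer code in the middle of its own colormap bin,
# so code k is drawn with palette[k].
scatter = ax.scatter(df["A"], df["B"], c=codes, cmap=ListedColormap(palette),
                     vmin=-0.5, vmax=len(palette) - 0.5, alpha=0.8)

# num=None: one handle per unique code (sorted 0, 1, ...), matching `names`.
handles, _ = scatter.legend_elements(prop="colors", num=None)
ax.legend(handles, list(names), title="Names")
plt.show()
```

If Names and Color are not strictly one-to-one, drop_duplicates keeps only the first color seen for each name, so the legend would silently hide the mismatch.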

James

Assuming Color is the column that holds all the colors (and also serves as the labels), you can simply do the following:

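import matplotlib.pyplot as plt

colors = list(df['Color'].unique())
for i in range(len(colors)):
    # rows belonging to the i-th color
    data = df.loc[df['Color'] == colors[i]]
    # with data=..., the strings 'A', 'B' and 'Color' are looked up as columns of `data`
    plt.scatter('A', 'B', data=data, color='Color', label=colors[i])
plt.legend()
plt.show()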
harvpan
  • Thank you, this works great! And inspired by your answer I added a line to change the color names into the corresponding variable names: names = list(df['Names'].unique()), and then: plt.scatter('A', 'B', data=data, color='Color', label=names[i]) – Mary May 08 '18 at 18:19
  • In how far is this different from existing answers? If it's not, please mark as duplicate, if it is, better give the answer to already existing questions but at least state what is different here for future readers to understand the difference. Otherwise one ends up with hundreds of different results when searching for "scatter legend" or similar. – ImportanceOfBeingErnest May 08 '18 at 18:32
  • @ImportanceOfBeingErnest you are right. It is not a novel approach. I found [this](https://stackoverflow.com/questions/26558816/matplotlib-scatter-plot-with-legend/26559256?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa) very similar question, but none of its answers has been accepted. – harvpan May 08 '18 at 18:40
  • @ImportanceOfBeingErnest Seeing the answers to my question and your links I agree there's great similarity, but I didn't find these posts when searching for an answer (maybe because my search didn't include "classes") and also didn't know how to handle this with columns from a Pandas dataframe. So thanks again for taking the time to read my question and write a reply. – Mary May 08 '18 at 18:57
  • @ImportanceOfBeingErnest Just saw your comment, which was not there when I answered. [this](https://stackoverflow.com/questions/42056713/matplotlib-scatterplot-with-legend) is quite close to what I have done. It did not appear in my search either. I try my best to maintain good etiquette and avoid redundancy in this community. – harvpan May 08 '18 at 19:02
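Building on Mary's comment about labeling by name rather than by color, here is a minimal sketch that groups on both columns at once, so each (name, color) pair is read from the same rows instead of relying on two separate unique() lists coming out in the same order. It assumes columns A, B, Color and a Names column as in that comment; the small DataFrame below is made up for illustration.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up example data; in the question this would be the existing df.
df = pd.DataFrame({
    "A": [1, 2, 3, 4, 5, 6],
    "B": [2, 1, 4, 3, 6, 5],
    "Color": ["red", "red", "green", "green", "blue", "blue"],
    "Names": ["alpha", "alpha", "beta", "beta", "gamma", "gamma"],
})

fig, ax = plt.subplots()
# Each group key is a (name, color) tuple taken from the same rows;
# groups come out sorted by name, then color.
for (name, color), group in df.groupby(["Names", "Color"]):
    ax.scatter(group["A"], group["B"], color=color, alpha=0.8, label=name)
ax.legend()
plt.show()
```

This is still one scatter call per category, like the answer above; it only changes how the label is paired with its color.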

A simple way is to group your data by color, then plot all of the data on one plot. Pandas has a built-in groupby method. For example:

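import matplotlib.pyplot as plt

# group by the string key 'Color' rather than ['Color']: with a one-element
# list, pandas 2.x yields tuple keys such as ('red',) instead of 'red'
for color, group in df.groupby('Color'):
    plt.scatter(group['A'], group['B'], c=color, alpha=0.8, label=color)

plt.legend()
plt.show()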

Notice that we call plt.scatter once for each grouping of data. Then we only need to call plt.legend and plt.show once all of the data is in our plot.
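If you would rather keep the question's single scatter call, another option is to build the legend entries yourself as matplotlib.lines.Line2D proxy artists, one per unique (name, color) pair. This is a sketch that swaps the per-group scatter calls for hand-made legend handles; it assumes columns A, B, Color and a Names column (as in Mary's comment on another answer), and the data is made up.

```python
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

# Made-up example data; in the question this would be the existing df.
df = pd.DataFrame({
    "A": [1, 2, 3, 4, 5, 6],
    "B": [2, 1, 4, 3, 6, 5],
    "Color": ["red", "green", "blue", "red", "green", "blue"],
    "Names": ["alpha", "beta", "gamma", "alpha", "beta", "gamma"],
})

fig, ax = plt.subplots()
# One scatter call for all points, colored directly by the Color column.
ax.scatter(df["A"], df["B"], c=df["Color"], alpha=0.8)

# One proxy per unique (name, color) pair: no line, just a round marker.
pairs = df[["Names", "Color"]].drop_duplicates()
handles = [
    Line2D([], [], linestyle="", marker="o", color=color, label=name)
    for name, color in zip(pairs["Names"], pairs["Color"])
]
ax.legend(handles=handles)
plt.show()
```

drop_duplicates keeps the pairs in order of first appearance; sort `pairs` first if you want the legend entries in alphabetical order.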

SNygard