
I have this code:

import pandas as pd
import matplotlib.pyplot as plt

dataset = pd.read_csv('iris.csv', header=None, names=['sepal length', 'sepal width', 'petal length', 'petal width', 'class'])

dataset.drop(index=dataset.index[dataset['class'] == 'Iris-setosa'], inplace=True)

dataset.drop('petal width', axis='columns', inplace=True)
dataset.drop('sepal width', axis='columns', inplace=True)


fig, ax = plt.subplots()

groups = dataset.groupby('class')

colors = {'Iris-versicolor': 'blue', 'Iris-virginica': 'red'}

ax.scatter(dataset['petal length'], dataset['sepal length'], c=dataset['class'].map(colors))

plt.legend(['aaa', 'bbb'])

plt.show()

The data set is from this link: https://archive.ics.uci.edu/ml/datasets/Iris

I get a scatterplot colored by 'class', but the legend only shows one entry: 'aaa'.

How do I put a legend on a matplotlib scatter plot that shows all the colors?
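As far as I can tell, the cause is that a single ax.scatter call creates a single artist (a PathCollection), no matter how many colors its points have. plt.legend(['aaa', 'bbb']) pairs the labels with the existing artists one by one, so the second label has nothing to attach to. Here is a minimal reproduction using a few made-up points instead of the Iris file. It uses matplotlib.figure.Figure directly only so it runs without a display.

```python
from matplotlib.figure import Figure


def legend_labels_single_call():
    # Build a figure without pyplot so this runs without a display
    fig = Figure()
    ax = fig.subplots()
    # One scatter call with two colors: still only one artist on the Axes
    ax.scatter([4.7, 6.0], [7.0, 6.3], c=['blue', 'red'])
    # Two labels, but only one artist for them to be paired with
    legend = ax.legend(['aaa', 'bbb'])
    n_artists = len(ax.collections)
    labels = [text.get_text() for text in legend.get_texts()]
    return n_artists, labels


print(legend_labels_single_call())  # (1, ['aaa'])
```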

Coder88

1 Answer


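Plotting each group with its own ax.scatter call, in a for loop over the groups and passing label=name, works for me. Each call creates a separate artist, so the legend gets one entry (with its own color) per class: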

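import pandas as pd
import matplotlib.pyplot as plt

# Same data loading and filtering as in the question
dataset = pd.read_csv('iris.csv', header=None,
                      names=['sepal length', 'sepal width', 'petal length', 'petal width', 'class'])
dataset = dataset[dataset['class'] != 'Iris-setosa']

fig, ax = plt.subplots()

# One scatter call per class, so each class becomes its own artist and legend entry
for name, group in dataset.groupby('class'):
    ax.scatter(
        group['petal length'], group['sepal length'], label=name)

plt.legend()
plt.show()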

Output:

[Image: scatter plot with a legend entry for each class]
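If you would rather keep the single scatter call and the colors dict from the question, another option (a different technique from the loop above) is to build the legend entries yourself as proxy artists: one matplotlib.lines.Line2D per class that draws only a colored marker. A sketch, with a small made-up DataFrame standing in for the Iris data:

```python
import pandas as pd
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

colors = {'Iris-versicolor': 'blue', 'Iris-virginica': 'red'}

# Made-up sample rows in the same layout as the question's DataFrame
dataset = pd.DataFrame({
    'petal length': [4.7, 4.5, 6.0, 5.1],
    'sepal length': [7.0, 6.4, 6.3, 5.8],
    'class': ['Iris-versicolor', 'Iris-versicolor',
              'Iris-virginica', 'Iris-virginica'],
})


def scatter_with_manual_legend(df, colors):
    fig = Figure()
    ax = fig.subplots()
    # Single scatter call, colored per point exactly as in the question
    ax.scatter(df['petal length'], df['sepal length'],
               c=df['class'].map(colors))
    # One proxy artist per class: no line, just a marker in the class color
    handles = [
        Line2D([], [], marker='o', linestyle='none', color=color, label=name)
        for name, color in colors.items()
    ]
    legend = ax.legend(handles=handles)
    return fig, legend


fig, legend = scatter_with_manual_legend(dataset, colors)
print([text.get_text() for text in legend.get_texts()])
```

In a normal script you can use fig, ax = plt.subplots() and plt.show() instead of building a Figure directly. The trade-off is that the legend is now driven by the colors dict rather than by the plotted data, so the two have to be kept in sync by hand.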

TC Arlen