
I'm really confused about matplotlib in general. I normally just use import matplotlib.pyplot as plt.

And then I do everything with calls like plt.figure(), plt.scatter(), plt.xlabel() and plt.show(). But when I google how to do something like making the legend match my colours, all the examples I find use ax. There is a plt.legend(), but the example in the matplotlib documentation just shows plt.legend(handles) without explaining what handles is supposed to be. And if I want to use the ax approach, I have to rewrite all my code, even though I chose plt because it's simpler.
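For reference, the handles that plt.legend(handles) expects are artist objects, which are what plotting calls return. The short sketch below (the two lines and their "rising"/"falling" labels are made up for illustration) passes two such artists to plt.legend, then uses plt.gca(), which returns the current Axes that pyplot draws on, to show that plt.* and ax.* calls can be mixed without rewriting anything.

```python
import matplotlib.pyplot as plt

plt.figure()
# plt.plot returns a list of Line2D artists; unpack the single line from each call
rising, = plt.plot([0, 1, 2], [0, 1, 2], color="blue")
falling, = plt.plot([0, 1, 2], [2, 1, 0], color="red")

# "handles" are simply these artist objects; labels are paired with them in order
leg = plt.legend(handles=[rising, falling], labels=["rising", "falling"])

# pyplot always draws on the "current" Axes, which plt.gca() returns,
# so plt.* calls and ax.* calls can be mixed freely
ax = plt.gca()
ax.set_title("plt.* and ax.* target the same Axes")
plt.show()
```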

Here's my code:

import matplotlib.pyplot as plt

colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black"
}

# look up the colour for each label
colours = [colmap[i] for i in y_train]

plt.figure(figsize=(15, 5))
plt.scatter(X_train[:,0], X_train[:,2],c=colours)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()
plt.show()

Now I want to add a legend that shows each class in the same colour as in my dictionary. But if I do:

plt.legend(["domestic","cheetah","leopard","tiger"])

it only shows "domestic" in the legend, and its colour is red, which doesn't match my colour coding. Is there a way to do this without rewriting everything with the "ax" approach? And if not, how do I adapt my code to use ax? Do I just write ax = plt.scatter(....)?
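The legend ends up with a single entry because one plt.scatter call creates one artist, so plt.legend only has one thing to label. A minimal sketch that stays entirely in pyplot, assuming the same colmap: build one proxy handle per dictionary entry with matplotlib.lines.Line2D (an empty, marker-only line that appears only in the legend) and pass them as handles. The X_train and y_train values below are hypothetical placeholders for the asker's data.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Hypothetical stand-ins for the asker's data; only columns 0 and 2 are plotted
X_train = np.array([[4.0, 0.0, 25.0],
                    [50.0, 0.0, 80.0],
                    [60.0, 0.0, 70.0],
                    [200.0, 0.0, 100.0]])
y_train = ["domestic", "cheetah", "leopard", "tiger"]

colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black"
}
colours = [colmap[i] for i in y_train]

plt.figure(figsize=(15, 5))
plt.scatter(X_train[:, 0], X_train[:, 2], c=colours)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()

# One proxy marker per class, coloured from the same dictionary;
# these are never drawn on the plot itself, only in the legend
handles = [Line2D([], [], marker='o', linestyle='', color=colour, label=name)
           for name, colour in colmap.items()]
leg = plt.legend(handles=handles)  # labels are taken from each handle's label
plt.show()
```

Because the handles are built from colmap itself, the legend colours stay in sync with the point colours if the dictionary changes.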

yami123
Please provide some sample data and your desired output. (See [How to make good reproducible pandas examples](https://stackoverflow.com/questions/20109391/how-to-make-good-reproducible-pandas-examples)). – RoseGod Dec 12 '21 at 19:09

1 Answer


No data was provided, but this code can help you understand how to add color to a scatter plot in matplotlib:

import matplotlib.pyplot as plt
import numpy as np

# data for scatter plots
x = list(range(0,30))
y = [i**2 for i in x]

# data for mapping class to color
y_train = ['domestic','cheetah', 'cheetah', 'tiger', 'domestic',
           'leopard', 'tiger', 'domestic', 'cheetah', 'domestic',
           'leopard', 'leopard', 'domestic', 'domestic', 'domestic',
           'domestic', 'cheetah', 'tiger', 'cheetah', 'cheetah',
           'domestic', 'domestic', 'domestic', 'cheetah', 'leopard',
           'cheetah', 'domestic', 'cheetah', 'tiger', 'domestic']

# color mapper
colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black"
}

# create color array
colors = [colmap[i] for i in y_train]

# plot scatter    
plt.figure(figsize=(15,5))
plt.scatter(x, y, c=colors)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()
plt.show()

Output: [image of the resulting scatter plot]
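The example above colours the points but does not add a legend yet. One alternative, sketched below with the answer's x, y, y_train and colmap, also covers the "how do I adapt this to ax?" part: get the Axes from plt.subplots() instead of writing ax = plt.scatter(...) (plt.scatter returns the scatter's PathCollection, not an Axes), then call ax.scatter once per class with a label so that ax.legend() finds one labelled artist per class. Splitting the data per class is a swapped-in approach, not something the answer above does.

```python
import matplotlib.pyplot as plt

# Same sample data and colour mapping as in the answer
x = list(range(0, 30))
y = [i**2 for i in x]
y_train = ['domestic', 'cheetah', 'cheetah', 'tiger', 'domestic',
           'leopard', 'tiger', 'domestic', 'cheetah', 'domestic',
           'leopard', 'leopard', 'domestic', 'domestic', 'domestic',
           'domestic', 'cheetah', 'tiger', 'cheetah', 'cheetah',
           'domestic', 'domestic', 'domestic', 'cheetah', 'leopard',
           'cheetah', 'domestic', 'cheetah', 'tiger', 'domestic']
colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black"
}

# plt.subplots() returns the Figure and the Axes to draw on
fig, ax = plt.subplots(figsize=(15, 5))

# One scatter call (one artist) per class, each with its own colour and label
for name, colour in colmap.items():
    xs = [xi for xi, cls in zip(x, y_train) if cls == name]
    ys = [yi for yi, cls in zip(y, y_train) if cls == name]
    ax.scatter(xs, ys, c=colour, label=name)

# plt.xlabel/plt.ylabel/plt.grid become ax.set_xlabel/ax.set_ylabel/ax.grid
ax.set_xlabel('weight')
ax.set_ylabel('height')
ax.grid()
ax.legend()  # collects the labelled scatter artists automatically
plt.show()
```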

RoseGod