I'm really confused about matplotlib in general. I normally just use `import matplotlib.pyplot as plt` and then do everything through it: `plt.figure()`, `plt.scatter()`, `plt.xlabel()`, `plt.show()`, etc. But when I google how to do something like mapping the legend to colours, I get all these examples that use `ax`. There is `plt.legend()`, and the example in the matplotlib documentation just shows `plt.legend(handles)` without showing what `handles` is supposed to be. And if I want to use the `ax` approach, I have to rewrite all my code, when I chose `plt` because it's simpler.
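For reference, here is a minimal sketch of what `handles` can be. In pyplot, every plotting call returns an artist object (`plt.plot` returns a list of `Line2D` objects, `plt.scatter` returns a `PathCollection`), and `plt.legend` accepts a list of such artists plus the text to show for each one. The data and label strings below are made up for illustration, and the `Agg` backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt

plt.figure()
# Every plotting call returns artist objects; these are what "handles" means.
line, = plt.plot([0, 1], [0, 1], color="blue")   # plt.plot returns a list of Line2D
dots = plt.scatter([0, 1], [1, 0], color="red")  # plt.scatter returns a PathCollection

# Pass the artists, and separately the text to show for each one.
leg = plt.legend(handles=[line, dots], labels=["a line", "some dots"])
```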
Here's my code:
```python
import matplotlib.pyplot as plt

# X_train (a NumPy array) and y_train (the labels) come from my dataset, defined earlier
colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black"
}

# Look up the colour for every label
colours = []
for i in y_train:
    colours.append(colmap[i])

plt.figure(figsize=[15, 5])
plt.scatter(X_train[:, 0], X_train[:, 2], c=colours)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()
plt.show()
```
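The plotting code above can also be written with one `plt.scatter` call per category, each given a `label=`, so that a bare `plt.legend()` finds one correctly coloured entry per dictionary key. This is a minimal sketch assuming `y_train` holds the same strings used as keys in `colmap`; the random `X_train`/`y_train` below are hypothetical stand-ins for the real data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

colmap = {"domestic": "blue", "cheetah": "red", "leopard": "green", "tiger": "black"}

# Hypothetical stand-ins for X_train / y_train: 8 rows, 3 feature columns.
rng = np.random.default_rng(0)
X_train = rng.random((8, 3))
y_train = ["domestic", "cheetah", "leopard", "tiger"] * 2

labels = np.asarray(y_train)  # array form allows boolean masking
plt.figure(figsize=[15, 5])
for name, colour in colmap.items():
    mask = labels == name  # rows belonging to this category
    plt.scatter(X_train[mask, 0], X_train[mask, 2], color=colour, label=name)
plt.xlabel("weight")
plt.ylabel("height")
plt.grid()
leg = plt.legend()  # collects each scatter's label automatically
```

The trade-off is four `scatter` calls instead of one, but the legend then needs no extra arguments.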
Now I want to add a legend that shows the colours exactly as they are in my dictionary. But if I do:

```python
plt.legend(["domestic", "cheetah", "leopard", "tiger"])
```

it only shows "domestic" in the legend, and its colour is red, which doesn't match how I've colour-coded the points. Is there a way to do this without rewriting everything with `ax`? And if not, how do I adapt my code to use `ax`? Do I just write `ax = plt.scatter(....)`?
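A minimal sketch of the `ax` version, keeping the single `scatter` call coloured through `colmap`. Note that `plt.scatter` returns a `PathCollection` (the dots), not an `Axes`, so `ax = plt.scatter(...)` would not give you an Axes; instead it can come from `fig, ax = plt.subplots()` (or `plt.gca()` for the current one), and the `plt.xlabel`-style calls become `ax.set_xlabel` and so on. The legend here uses one `matplotlib.patches.Patch` proxy per dictionary entry; the same `handles` list should also work with `plt.legend(handles=handles)` in the pyplot-only code. `X_train`/`y_train` are hypothetical stand-ins.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

colmap = {"domestic": "blue", "cheetah": "red", "leopard": "green", "tiger": "black"}

# Hypothetical stand-ins for the question's X_train / y_train.
X_train = np.arange(12, dtype=float).reshape(4, 3)
y_train = ["domestic", "cheetah", "leopard", "tiger"]

colours = [colmap[label] for label in y_train]

# plt.subplots() returns the Figure and its Axes.
fig, ax = plt.subplots(figsize=(15, 5))
ax.scatter(X_train[:, 0], X_train[:, 2], c=colours)
ax.set_xlabel("weight")  # plt.xlabel -> ax.set_xlabel
ax.set_ylabel("height")  # plt.ylabel -> ax.set_ylabel
ax.grid()

# One proxy Patch per dictionary entry; these are drawn only in the legend.
handles = [Patch(color=colour, label=name) for name, colour in colmap.items()]
leg = ax.legend(handles=handles)
```

In the real script, call `plt.show()` at the end as before; the `matplotlib.use("Agg")` line is only needed to run the sketch without a display.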