
I am trying to generate a simple plot of the Iris dataset, plotting sepal length against sepal width. I want to add a colorbar next to this figure and label the colorbar with ONLY 0, 1 and 2. I succeed in generating the plot and the colorbar, but the colorbar labelling doesn't work the way I want it to:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets


# import the Iris-dataset from sklearn
iris = datasets.load_iris()
data = np.array(iris.data)
labels = np.array(iris.target)


fig = plt.figure()
img = plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='rainbow', edgecolor='k', s=50)

plt.xlabel("Sepal Length [cm]", fontname="Calibri", fontsize=14, labelpad=14)
plt.ylabel("Sepal Width [cm]", fontname="Calibri", fontsize=14, labelpad=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

cax = fig.add_axes([0.93, 0.09, 0.02, 0.8])
cax.tick_params(axis='y', labelsize=12)
cax.set_yticks(ticks=[0, 1, 2], minor=False)
cax.set_yticklabels(labels=['0', '1', '2'], minor=False)

fig.colorbar(img, cax=cax, orientation="vertical")
plt.show()

[Image: the resulting plot; the colorbar shows tick labels other than 0, 1 and 2]

Can somebody explain to me how to get rid of the labels that I don't want? I only want to see the 0, 1 and 2 on my colorbar.

Luk
  • Does this answer your question? [Matplotlib discrete colorbar](https://stackoverflow.com/questions/14777066/matplotlib-discrete-colorbar) – Alex Jun 06 '21 at 20:15
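Along the lines of the linked question, here is a minimal sketch of a discrete colorbar, where each class gets one solid colour band with its tick in the middle. Assumptions: the points are synthetic stand-ins for the Iris features (so it runs without scikit-learn), the bin edges -0.5, 0.5, 1.5, 2.5 are my choice for centring the integer labels, and the Agg backend line only exists so it runs without a display (drop it and call plt.show() interactively).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import BoundaryNorm

# Synthetic stand-in for the Iris data: 30 points for each class 0, 1, 2
rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 30)
data = rng.normal(loc=labels[:, None], scale=0.3, size=(labels.size, 2))

# A 'rainbow' colormap resampled to exactly three colours, one per class
cmap = plt.get_cmap("rainbow", 3)
# Bin edges chosen so each integer label sits in the middle of its colour band
norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N)

fig, ax = plt.subplots()
img = ax.scatter(data[:, 0], data[:, 1], c=labels, cmap=cmap, norm=norm,
                 edgecolor="k", s=50)
# Passing ticks= to colorbar() makes it use exactly these tick positions
cbar = fig.colorbar(img, ax=ax, ticks=[0, 1, 2])
```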


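I edited your code a bit according to this documentation and it did the trick. The ticks you set on cax beforehand are discarded because fig.colorbar() sets up its own ticks on that axes; passing ticks=[0, 1, 2] to fig.colorbar() instead keeps only those three: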

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

# import the Iris-dataset from sklearn
iris = datasets.load_iris()
data = np.array(iris.data)
labels = np.array(iris.target)


fig = plt.figure()
img = plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='rainbow', edgecolor='k', s=50)

plt.xlabel("Sepal Length [cm]", fontname="Calibri", fontsize=14, labelpad=14)
plt.ylabel("Sepal Width [cm]", fontname="Calibri", fontsize=14, labelpad=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

cax = fig.add_axes([0.93, 0.09, 0.02, 0.8])
cax.tick_params(axis='y', labelsize=12)

cbar = fig.colorbar(img, cax=cax, ticks=[0, 1, 2])
cbar.ax.set_yticklabels(['0', '1', '2'])

plt.show()

Resulting plot:

[Image: the resulting plot, with only 0, 1 and 2 on the colorbar]
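Another way to get the same result is to create the colorbar first and only then override its ticks and labels through the Colorbar object, using set_ticks() and set_ticklabels(). A minimal sketch, assuming synthetic stand-in data instead of the Iris dataset and the Agg backend purely so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove it and call plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in for the Iris data: 30 points for each class 0, 1, 2
rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 30)
data = rng.normal(loc=labels[:, None], scale=0.3, size=(labels.size, 2))

fig = plt.figure()
img = plt.scatter(data[:, 0], data[:, 1], c=labels, cmap="rainbow", edgecolor="k", s=50)
cax = fig.add_axes([0.93, 0.09, 0.02, 0.8])

# Create the colorbar first, so it no longer resets anything afterwards...
cbar = fig.colorbar(img, cax=cax, orientation="vertical")
# ...then replace the ticks and labels it picked by default
cbar.set_ticks([0, 1, 2])
cbar.set_ticklabels(["0", "1", "2"])
```

The order is the point here: anything set on cax before fig.colorbar() runs can be replaced by the colorbar's own setup, while calls on cbar afterwards stick.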

YevKad