I have the following code for visualizing the decision boundary of a binary classification problem:
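import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets
from sklearn.model_selection import train_test_split

# X_C2 (features) and y_C2 (0/1 labels) are not shown in this snippet
X_train, X_test, y_train, y_test = train_test_split(X_C2, y_C2,
                                                    random_state=0)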
n_neighbors = [1, 3, 11]
weights = 'uniform'
h = .02 # step size in the mesh
# Create color maps (http://htmlcolorcodes.com/)
cmap_light = ListedColormap(['#F9F999', '#F3F3F3'])
cmap_bold = ListedColormap(['#FFFF00', '#000000'])
# For example, FF AA AA = RGB(255, 170, 170)
for n in n_neighbors:
    # we create an instance of Neighbours Classifier and fit the data.
    clf = neighbors.KNeighborsClassifier(n, weights=weights)
    clf.fit(X_train, y_train)
    # ----------------------------- Mesh color i.e. background color begins ------------------------
    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
    y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure(figsize=(7,5), dpi=100)
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)
    # ------------------------------ Mesh color i.e. background color ends -------------------------
    # Plot also the training points
    # plt.figure(figsize=(10,4), dpi=80)
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cmap_bold, s=40, label = "class 0") # scatter plot of height vs width
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cmap_bold, s=40, label = "class 1") # scatter plot of height vs width
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("2-Class classification (k = %i, weights = '%s')" % (n, weights))
    plt.legend()
    plt.show()
The best I could do is:
But the legend doesn't correspond to the actual class of the points. X_train has 2 columns, one for each feature. y_train contains the labels for this data, namely 0s and 1s. How can I obtain a legend whose entries have the same colors as the points?
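For reference, one likely cause is that both plt.scatter calls above draw every training point with c=y_train, so each legend entry describes a mixed-color collection rather than a single class. A minimal sketch of one possible fix is to draw one scatter call per class with a fixed color. The helper name scatter_by_class and the synthetic X_demo / y_demo arrays are hypothetical stand-ins introduced only for illustration (X_C2 and y_C2 are not shown in the question); the two colors are the ones from cmap_bold.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np


def scatter_by_class(ax, X, y, colors, names):
    """Draw one scatter call per class so each class gets its own legend entry.

    Hypothetical helper: colors[i] and names[i] are paired with the i-th
    sorted unique label in y.
    """
    for label, color, name in zip(np.unique(y), colors, names):
        mask = (y == label)  # select only the rows belonging to this class
        ax.scatter(X[mask, 0], X[mask, 1], color=color, s=40, label=name)
    ax.legend()


# Synthetic stand-in for X_train / y_train (assumption: 2 features, labels 0 and 1)
rng = np.random.RandomState(0)
X_demo = np.vstack([rng.normal(-1, 1, (50, 2)), rng.normal(1, 1, (50, 2))])
y_demo = np.array([0] * 50 + [1] * 50)

fig, ax = plt.subplots(figsize=(7, 5), dpi=100)
scatter_by_class(ax, X_demo, y_demo,
                 colors=["#FFFF00", "#000000"],
                 names=["class 0", "class 1"])
```

In the loop from the question, the same idea would presumably replace the two plt.scatter lines with one call per label, e.g. masking with y_train == 0 and y_train == 1, while the pcolormesh background stays unchanged.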