
I have run into a problem plotting a confusion matrix: the top and bottom rows are displayed incorrectly. When I plot it, it looks like this:

[Image: the plotted confusion matrix]

I don't think anything is wrong with my code, since I copied it exactly from this YouTube video.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix  # assumed: scikit-learn's confusion_matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        # Divide each row (one true class) by its total so every row sums to 1
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Write each count into its cell; use white text on dark cells
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Graphical analytics
# train_set and train_preds come from the training code (not shown):
# train_set.targets holds the true labels, train_preds the model outputs
cm = confusion_matrix(train_set.targets, train_preds.argmax(dim=1))
names = ('T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot')
plt.figure(figsize=(10,10))
plot_confusion_matrix(cm, names)
plt.show()
```
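The `normalize=True` branch divides each row by that row's total, so each cell becomes the fraction of one true class that was predicted as each label. A small sketch of that expression on a made-up 2×2 matrix (the numbers are illustrative only, not from the question):

```python
import numpy as np

# Illustrative counts: rows are true labels, columns are predicted labels
cm = np.array([[8, 2],
               [1, 3]])

# Same expression as in plot_confusion_matrix.
# cm.sum(axis=1) has shape (2,); [:, np.newaxis] reshapes it to (2, 1)
# so it broadcasts across the columns of each row.
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Row 0: 8/10 and 2/10; row 1: 1/4 and 3/4
print(cm_norm)
```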

Nqsir
FrankFan

1 Answer


You can set the range of the y axis manually:

```python
plt.ylim(-0.5, len(names) - 0.5)
```

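For some reason, the heuristic that estimates the axis range does not account for the fact that you are interested not only in the points you plot but also in the ±0.5 margin around them on both axes: each cell drawn by `imshow` is centered on an integer coordinate and extends 0.5 in each direction.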

The cell centers run from y = 0 to y = len(names) - 1, so the whole matrix spans -0.5 to len(names) - 0.5.
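A minimal sketch of the same idea that reads the limits from the image itself instead of computing them from `len(names)`. The 3×3 matrix, the class names and the headless Agg backend are assumptions for illustration. `AxesImage.get_extent()` returns `(left, right, bottom, top)`; with `imshow`'s default `origin='upper'` that is `(-0.5, n - 0.5, n - 0.5, -0.5)`, so passing the last two values to `set_ylim` covers the full ±0.5 margin and keeps row 0 at the top. Note that `plt.ylim(-0.5, len(names) - 0.5)` also makes every row fully visible, but it flips the y axis, so row 0 ends up at the bottom.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Illustrative 3x3 confusion matrix and class names (not from the question)
cm = np.array([[5, 1, 0],
               [2, 7, 1],
               [0, 1, 9]])
names = ('cat', 'dog', 'bird')

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.set_xticks(np.arange(len(names)))
ax.set_xticklabels(names)
ax.set_yticks(np.arange(len(names)))
ax.set_yticklabels(names)

# get_extent() -> (left, right, bottom, top) in data coordinates.
# With origin='upper' this is (-0.5, n - 0.5, n - 0.5, -0.5).
left, right, bottom, top = im.get_extent()
ax.set_ylim(bottom, top)  # full rows visible, row 0 stays at the top
```

On a matplotlib version that already computes the limits correctly, this call simply sets the same limits again.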

Jindřich
  • Thanks for your solution! Actually, changing the second "+" to "-" works for me, i.e. `plt.ylim(-0.5, len(names) - 0.5)`. Could you also explain the reasoning? – FrankFan Oct 28 '19 at 12:25
  • Thanks for catching the typo. I don't actually know how matplotlib's heuristic works or why it gets the _ylim_ wrong, but resetting it manually works. – Jindřich Oct 28 '19 at 13:00
  • Thank you so much. One extra question: do you know how to align my colorbar (on the right side) with my matrix plot, i.e. make it shorter in this case? – FrankFan Oct 28 '19 at 13:12
  • You can play around with the `fraction` argument in the [`plt.colorbar`](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.colorbar.html) function and set it to something like 0.5. – Jindřich Oct 28 '19 at 14:55
  • You could also write `plt.ylim(sorted(plt.xlim(), reverse=True))`, and the same approach works with an `ax` object. – Magnus Berg Sletfjerding Jan 09 '20 at 13:16
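A sketch combining the comment suggestions in matplotlib's object-oriented API, again with made-up data and the Agg backend as assumptions. `ax.set_ylim(sorted(ax.get_xlim(), reverse=True))` works for a square matrix because the x range is already (-0.5, n - 0.5), and reversing it keeps row 0 at the top. For the colorbar it swaps in a different technique from the `fraction` tweak: `make_axes_locatable` from `mpl_toolkits.axes_grid1` attaches a separate colorbar axes to the right of the image axes, so the colorbar's height follows the image. The `"5%"` width and `0.05` pad are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Illustrative data, not from the question
cm = np.array([[5, 1, 0],
               [2, 7, 1],
               [0, 1, 9]])
names = ('cat', 'dog', 'bird')

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ticks = np.arange(len(names))
ax.set_xticks(ticks)
ax.set_xticklabels(names, rotation=45)
ax.set_yticks(ticks)
ax.set_yticklabels(names)

# Square matrix: the x range is (-0.5, n - 0.5); reversed, it becomes
# the y range with row 0 at the top.
ax.set_ylim(sorted(ax.get_xlim(), reverse=True))

# Colorbar in its own axes, appended to the right of the image axes,
# so its height matches the image axes.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)

ax.set_ylabel('True label')
ax.set_xlabel('Predicted label')
```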