I'm running the code below to plot a confusion matrix. The output looked great until I restarted the notebook kernel. I didn't change the code, but now the plot looks squished (Fig 1). It corrects itself when I delete the plt.yticks line (Fig 2), but I want those labels. This is probably simple, but I'm new to Python.

import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion Matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    Source: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # Plot the confusion matrix
    plt.figure(figsize = (6, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, size = 25)
    plt.colorbar(aspect=5)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, size = 12)
    plt.yticks(tick_marks, classes, size = 12)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.

    # Labeling the plot
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), fontsize = 20,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.grid(False)
    plt.tight_layout()
    plt.ylabel('Actual label', size = 15)
    plt.xlabel('Predicted label', size = 15)

cm = confusion_matrix(y_test, y_pred)
plot_confusion_matrix(cm, classes = ['Good Mental Health', 'Poor Mental Health'],
                      title = 'Confusion Matrix')

[Fig 1: the squished confusion matrix]

[Fig 2: the plot after deleting the plt.yticks line]

Dr.Data

1 Answer

Try adding these lines at the end of the code.

plt.tight_layout()
plt.show()

This alone should help a lot. Some more advice:

1) I think what is happening is that you ask for a 6x6 inch figure, and this space includes the labels. A larger figure may help.

2) You could try to use the available space more efficiently. I would definitely split each label across two lines. I guess you have

 classes = ['Good Mental Health', 'Poor Mental Health']

somewhere in your code. Doing

 classes = ['Good \nMental Health', 'Poor \nMental Health']

may help as well.

3) Another way to save space is to rotate the y tick labels:

 plt.yticks(tick_marks, classes, size = 12, rotation='vertical')

You should try different combinations and see what happens.
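
The three suggestions above can be combined in one place. This is only a minimal sketch, not the asker's exact function: the counts in `cm` are made up for illustration, the 8x8 figure size is an arbitrary choice, and it uses the `fig, ax` interface instead of the `plt.*` calls from the question.

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up counts for illustration only; not the asker's data.
cm = np.array([[50, 10],
               [5, 35]])
# Line breaks keep the long class names from eating horizontal space.
classes = ['Good\nMental Health', 'Poor\nMental Health']

# A larger figure than the original 6x6 (8x8 is an arbitrary choice).
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig.colorbar(im, ax=ax, aspect=5)

tick_marks = np.arange(len(classes))
ax.set_xticks(tick_marks)
ax.set_xticklabels(classes, rotation=45, size=12)
ax.set_yticks(tick_marks)
# Vertical y tick labels take up less width next to the image.
ax.set_yticklabels(classes, size=12, rotation='vertical', va='center')

ax.set_ylabel('Actual label', size=15)
ax.set_xlabel('Predicted label', size=15)
# Call tight_layout last so it can account for every label.
fig.tight_layout()
```

In a notebook the figure is displayed at the end of the cell; in a script you would add `plt.show()` or `fig.savefig(...)`.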

GRquanti
  • Thanks for the tips. I do have `plt.tight_layout()` & `plt.show()` in the code already. I found that adding `plt.ylim([1.5, -.5])` fixed the issue, though I'm still confused as to why it had been working fine earlier in the day and then stopped suddenly. Oh well, it's fixed! – Dr.Data Sep 15 '19 at 00:29
  • One should know what you did the first time. In general, you can trust what happens after restarting the kernel and running everything from the beginning. This will happen again and again. – GRquanti Sep 16 '19 at 02:56
  • I've encountered the same problem with the new matplotlib version, and only `plt.ylim([1.5, -.5])` solved it for me – Jenia Golbstein Nov 10 '19 at 14:54
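
The `plt.ylim([1.5, -.5])` workaround from the comments hard-codes the limits for a 2x2 matrix. Here is a minimal sketch that derives the same limits from the matrix shape, so it also works for more classes. The 3x3 counts are made up for illustration, and the `fig, ax` interface is used instead of `plt.*`.

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up 3x3 counts for illustration only.
cm = np.array([[50, 10, 3],
               [5, 35, 7],
               [2, 4, 40]])

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.set_yticks(np.arange(cm.shape[0]))

# imshow centres row i at y = i, so the rows span -0.5 .. n_rows - 0.5.
# The larger value comes first because imshow puts row 0 at the top.
n_rows = cm.shape[0]
ax.set_ylim(n_rows - 0.5, -0.5)
```

For the 2x2 case this gives `(1.5, -0.5)`, the same values used in the comments.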