Using the latest matplotlib version on Python 3.7 I try to plot (and save to png) a confusion matrix. While the resulting figure in in principle fine, the cells are sized different, see here:
As you can see in the screenshot, actually only the middle cell is sized correctly, all others, i.e., all border cells in this case, seem to have only have half or even quarter the size compared to the middle cell.
The source code I'm running is simply:
import os
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
def create_save_plotted_confusion_matrix(conf_matrix, expected_labels, basepath):
ax, title = plot_confusion_matrix(conf_matrix, expected_labels, normalize=False)
filepath = os.path.join(basepath, '.png')
plt.savefig(filepath, bbox_inches='tight')
def plot_confusion_matrix(cm, classes, normalize=False, title=None, cmap=plt.cm.Blues):
if not title:
if normalize:
title = 'Normalized confusion matrix'
else:
title = 'Confusion matrix, without normalization'
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
else:
pass
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.figure.colorbar(im, ax=ax)
# We want to show all ticks...
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
# ... and label them with the respective list entries
xticklabels=classes, yticklabels=classes,
title=title,
ylabel='True label',
xlabel='Predicted label')
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
return ax, title
if __name__ == '__main__':
y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
confmat = confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])
create_save_plotted_confusion_matrix(confmat, ["ant", "bird", "cat"], '.')