Hi everyone. I have a problem when plotting a heatmap with matplotlib.
In short, the output heatmap figure is 'cropped'. I want to know WHY this problem arises and HOW to solve it.
The problem is described as follows.
1. Environment
The program runs in the following environment.
matplotlib==3.1.1
numpy==1.16.4
pandas==0.24.2
All libraries are installed with pip inside a conda virtual environment.
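The versions listed above can be confirmed from inside the environment itself. This is only a minimal sketch, assuming Python 3.8 or later (for importlib.metadata); the helper name installed_versions is hypothetical and not part of any of these libraries.

```python
from importlib import metadata


def installed_versions(packages=("matplotlib", "numpy", "pandas")):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in the active environment.
            versions[name] = None
    return versions


# Print the versions in the same "name==version" form used above.
for name, version in installed_versions().items():
    print(f"{name}=={version}")
```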
2. Reference code
My code is based on the official matplotlib example: Creating annotated heatmaps
3. My code
My code is given as follows:
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable


def heatmap(data, ax, fmt='{x:d}'):
    """Plot a heatmap with matplotlib.

    Args:
        data (pandas.DataFrame): the matrix to be labeled.
        ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
        fmt (string, optional): the format of the annotations inside the
            heatmap. This should either use the string format method,
            e.g. "{x:.2f}" or "{x:d}".

    Returns:
        ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError('the input data should be a pandas DataFrame object')
    if data.shape[0] != data.shape[1]:
        raise ValueError('the data should be a square-matrix')

    im = ax.imshow(data.values, cmap='magma', interpolation='nearest')

    # Attach a colorbar axes to the right of the heatmap.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.15)
    cbar = plt.colorbar(im, cax=cax,
                        ticks=mpl.ticker.MaxNLocator(nbins=6))
    cbar.ax.tick_params(labelsize=12)
    cbar.outline.set_visible(False)

    # One tick per row/column, labeled with the DataFrame column names.
    ax.set_xticks(list(np.arange(data.shape[1])))
    ax.set_yticks(list(np.arange(data.shape[0])))
    column_labels = list(data.columns.values.astype(str))
    ax.set_xticklabels(column_labels, fontsize=12, rotation='vertical')
    ax.set_yticklabels(column_labels, fontsize=12)

    for spine in ax.spines.values():
        spine.set_visible(False)

    valfmt = mpl.ticker.StrMethodFormatter(fmt)

    # Change the text's color depending on the background.
    text_colors = ['white', '0.15']
    threshold = 0.6
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            use_dark = im.norm(data.iloc[i, j]) > threshold
            im.axes.text(
                j, i, valfmt(data.iloc[i, j], None),
                ha="center", va="center", fontsize=12,
                color=text_colors[int(use_dark)])

    return ax
4. Input data
The input data is a very simple pandas DataFrame.
confusion_df = pd.DataFrame([[89, 4], [5, 80]], index=['dog', 'cat'], columns=['dog', 'cat'])
5. Output
I create the confusion-matrix figure as follows.
fig, ax = plt.subplots(figsize=(4, 4))
heatmap(confusion_df, ax=ax, fmt='{x:d}')
ax.set_xlabel('Prediction', fontsize=16)
ax.set_ylabel('Ground truth', fontsize=16)
fig.tight_layout()
fig.savefig('confusion_map.png')
plt.show()
6. Result
The result is very strange: as I said before, the figure is 'cropped'.
7. The result I want to get
The correct output should be a square, like the one below. You can also check the results shown in the reference document in section 2.
8. The methods I tried
I tried the following methods, but none of them work.
(a) comparing my code with the code given in the reference document (see section 2)
(b) setting a larger figure size, e.g. fig, ax = plt.subplots(figsize=(6, 6)), in section 5
(c) removing fig.tight_layout() in section 5
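One further check that might help narrow this down is sketched here. It assumes the cropping comes from the y-axis limits rather than from the figure size or layout. With the default origin='upper', imshow centers cell (i, j) at (j, i), so an n-by-n image should span -0.5 to n - 0.5 on both axes. If the printed y-limits are narrower than that, the top and bottom rows would be cut in half, which would match the symptom. The helper name full_cell_limits is hypothetical, and setting the limits by hand is only a possible workaround, not a confirmed fix for the root cause.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def full_cell_limits(n_rows, n_cols):
    """Return (xlim, ylim) that cover every cell of an imshow image.

    Hypothetical helper. Assumes the default origin='upper', where the
    y-axis is inverted and the cells span -0.5 .. n - 0.5.
    """
    xlim = (-0.5, n_cols - 0.5)
    ylim = (n_rows - 0.5, -0.5)
    return xlim, ylim


data = np.array([[89, 4], [5, 80]])
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(data, cmap="magma", interpolation="nearest")
ax.set_xticks(np.arange(data.shape[1]))
ax.set_yticks(np.arange(data.shape[0]))

# Inspect the limits matplotlib chose; a cropped figure would show a
# narrower range than the one computed below.
print("y-limits before:", ax.get_ylim())

# Possible workaround: force the limits to cover the full cells.
xlim, ylim = full_cell_limits(*data.shape)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
print("y-limits after:", ax.get_ylim())
```

In my code the same two calls would go at the end of heatmap(), after the tick labels are set, so that later tick changes cannot alter the limits again.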
Thank you very much!