
Hi everyone. I have a problem when plotting a heatmap with matplotlib.

In short, the output heatmap figure is 'cropped'. I want to know WHY this problem arises and HOW to solve it.

The problem is described as follows.

1. Environment

The program runs in the following environment:

    matplotlib==3.1.1
    numpy==1.16.4
    pandas==0.24.2

All libraries are installed with pip inside a conda virtual environment.
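Because pip and conda installs can coexist in one environment, it may be worth confirming that these are the versions the interpreter actually imports. A small sketch that prints the `__version__` attributes from inside the active environment:

```python
import matplotlib
import numpy as np
import pandas as pd

# Versions actually loaded by the interpreter of the active environment,
# to compare against the pinned versions listed above.
versions = {
    "matplotlib": matplotlib.__version__,
    "numpy": np.__version__,
    "pandas": pd.__version__,
}
for name, version in versions.items():
    print(f"{name}=={version}")
```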

2. Reference code

My code is based on the official matplotlib example: Creating annotated heatmaps.

3. My code

My code is given as follows:


    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    from mpl_toolkits.axes_grid1 import make_axes_locatable


    def heatmap(data, ax, fmt='{x:d}'):
        """Plot a heatmap with matplotlib.

        Args:
            data (pandas.DataFrame): the matrix to be labeled.
            ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
            fmt (string, optional): the format of the annotations inside the
                heatmap. This should use the string format method,
                e.g. "{x:.2f}" or "{x:d}".

        Returns:
            ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
        """

        if not isinstance(data, pd.DataFrame):
            raise TypeError('the input data should be a pandas DataFrame object')

        if data.shape[0] != data.shape[1]:
            raise ValueError('the data should be a square matrix')

        im = ax.imshow(data.values, cmap='magma', interpolation='nearest')

        # Colorbar in a thin axes appended to the right of the heatmap.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.15)
        cbar = plt.colorbar(im, cax=cax,
                            ticks=mpl.ticker.MaxNLocator(nbins=6))
        cbar.ax.tick_params(labelsize=12)
        cbar.outline.set_visible(False)

        # One tick per cell; cell (i, j) is centered at x=j, y=i.
        ax.set_xticks(np.arange(data.shape[1]))
        ax.set_yticks(np.arange(data.shape[0]))

        # Columns label the x-axis, the index labels the y-axis.
        column_labels = list(data.columns.values.astype(str))
        row_labels = list(data.index.values.astype(str))
        ax.set_xticklabels(column_labels, fontsize=12, rotation='vertical')
        ax.set_yticklabels(row_labels, fontsize=12)

        for spine in ax.spines.values():
            spine.set_visible(False)

        valfmt = mpl.ticker.StrMethodFormatter(fmt)

        # Change the text's color depending on the background.
        text_colors = ['white', '0.15']
        threshold = 0.6

        for i in range(data.shape[0]):
            for j in range(data.shape[1]):
                use_dark = im.norm(data.iloc[i, j]) > threshold
                im.axes.text(
                    j, i, valfmt(data.iloc[i, j], None),
                    ha="center", va="center", fontsize=12,
                    color=text_colors[int(use_dark)])

        return ax

4. Input data

The input data is a very simple pandas DataFrame.


    confusion_df = pd.DataFrame([[89, 4], [5, 80]], index=['dog', 'cat'], columns=['dog', 'cat'])
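The annotation step at the end of `heatmap()` can be checked on these values without drawing anything. A minimal sketch, assuming that `imshow`'s default norm autoscales to the data range (so `vmin=4`, `vmax=89` for `confusion_df`); it rebuilds that `Normalize` explicitly and applies the same `StrMethodFormatter` and 0.6 threshold:

```python
import matplotlib.colors as mcolors
import matplotlib.ticker as mticker
import numpy as np

data = np.array([[89, 4], [5, 80]])  # same values as confusion_df

# Assumption: this mirrors the norm imshow builds by autoscaling to the data.
norm = mcolors.Normalize(vmin=data.min(), vmax=data.max())
valfmt = mticker.StrMethodFormatter("{x:d}")

text_colors = ["white", "0.15"]
threshold = 0.6

labels = []
colors = []
for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        # Normalized value in [0, 1]; bright cells (high values in magma) get dark text.
        use_dark = norm(data[i, j]) > threshold
        labels.append(valfmt(data[i, j], None))
        colors.append(text_colors[int(use_dark)])

print(labels)  # ['89', '4', '5', '80']
print(colors)  # ['0.15', 'white', 'white', '0.15']
```

With this data only the two diagonal cells (89 and 80) cross the threshold, so they get the dark `'0.15'` text while the off-diagonal cells get white.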

5. Output

I create the confusion matrix figure as follows:


    fig, ax = plt.subplots(figsize=(4, 4))

    heatmap(confusion_df, ax=ax, fmt='{x:d}')

    ax.set_xlabel('Prediction', fontsize=16)
    ax.set_ylabel('Ground truth', fontsize=16)
    fig.tight_layout()

    fig.savefig('confusion_map.png')
    plt.show()

6. Result

The result is very strange: as I said before, the figure is 'cropped'.

[Image: the 'cropped' heatmap]

7. The result I want to get

The correct output should be a square heatmap like the one below. You can also compare it with the results shown in the reference document (see 2).

[Image: the figure I want to get]
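For reference, `imshow` with its default `origin='upper'` centers cell (i, j) at x = j, y = i, so a complete n x n image spans -0.5 to n - 0.5 on both axes, with the y-axis inverted. A minimal diagnostic sketch, assuming the non-interactive Agg backend, that prints the view limits right after `imshow` and again after the same tick calls `heatmap()` makes, so they can be compared against the square reference figure:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

data = np.array([[89, 4], [5, 80]])  # same values as confusion_df
n_rows, n_cols = data.shape

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(data, cmap="magma", interpolation="nearest")

# Limits set by imshow itself: (-0.5, n - 0.5) on x, inverted on y.
limits_after_imshow = (ax.get_xlim(), ax.get_ylim())
print("after imshow:", limits_after_imshow)

# The same tick calls heatmap() makes; if these limits differ from the
# ones above, the tick calls changed the visible region of the image.
ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))
limits_after_ticks = (ax.get_xlim(), ax.get_ylim())
print("after set_*ticks:", limits_after_ticks)

plt.close(fig)
```

The second pair is printed rather than asserted, because it is exactly what may depend on the installed matplotlib version.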

8. The methods I tried

I tried the following methods but they do NOT work.

(a) compared my code with the code given in the reference document (see 2)

(b) set a larger figure size, e.g. fig, ax = plt.subplots(figsize=(6, 6)) in 5

(c) removed fig.tight_layout() in 5

Thank you very much!

zchenkui