
I have figure instances, and I want to plot them side by side (e.g. two figures in a grid of one row and two columns). Below is sample code that returns such a figure instance.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

def build_confusion_matrix_test():
    cm = np.array([[379,  49],
                   [ 18, 261]])
    print(cm)

    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=[0, 1])

    title_font = {'size': 13.5}  # Adjust to fit

    disp.plot()  # no ax given, so sklearn creates a new figure and axes
    disp.ax_.set_title("title", fontdict=title_font)

    return disp.figure_


# Function call
test_plot = build_confusion_matrix_test()
test_plot  # in a notebook, this displays the figure

I have many figure instances from different functions. I was expecting something like the code below, where I try to plot the same figure twice in one row, but I'm not sure how to make it work:

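fig = plt.figure()

ax1 = fig.add_subplot(1, 2, 1)  # left axes of a 1x2 grid
test_plot  # ??? how do I put this existing figure into ax1?

ax2 = fig.add_subplot(1, 2, 2)  # right axes of a 1x2 grid
test_plot  # ??? and the same figure again into ax2?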
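If the functions that build the figures cannot be changed, one workaround (not from the original thread) is to render each finished Figure to an RGBA pixel array and show those images side by side with `imshow`. The helpers `figure_to_rgba` and `show_figures_side_by_side` are hypothetical names, and the sketch assumes the figures use an Agg-based canvas (true for the most common backends), since it relies on `canvas.buffer_rgba()`:

```python
import numpy as np
import matplotlib.pyplot as plt


def figure_to_rgba(fig):
    """Hypothetical helper: render an existing Figure and return an (H, W, 4) uint8 array."""
    fig.canvas.draw()  # make sure everything on the figure has been drawn
    return np.asarray(fig.canvas.buffer_rgba()).copy()


def show_figures_side_by_side(figures, titles=None):
    """Hypothetical helper: place already-built figures in one row as raster images."""
    images = [figure_to_rgba(f) for f in figures]
    combined, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 4))
    axes = np.atleast_1d(axes)  # a single figure gives one Axes, not an array
    for k, (ax, img) in enumerate(zip(axes, images)):
        ax.imshow(img)
        ax.axis("off")  # hide ticks and spines of the host axes
        if titles is not None:
            ax.set_title(titles[k])
    for f in figures:
        plt.close(f)  # the originals are no longer needed as separate figures
    return combined
```

With the question's function this would be `show_figures_side_by_side([build_confusion_matrix_test(), build_confusion_matrix_test()])` followed by `plt.show()`. The result is a raster copy, so text is no longer editable and sharpness depends on the source figures' dpi; drawing directly into shared axes avoids that.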
Taha
  • Not sure why this question was closed. I don't find any of the above links related. This is a figure instance; otherwise, I know how to plot it side by side. – Taha Aug 18 '22 at 11:22
  • I agree with you, but your question wasn't super clear. `ConfusionMatrixDisplay` puts the confusion matrix in a new figure. There is no way to then mash that into a new figure. You should open a feature request on `scikit-learn` and ask them to allow `CMD` to accept an `ax` kwarg so you can pass the axes to the method. Or plot the confusion matrix manually (it's not that hard) – Jody Klymak Aug 18 '22 at 15:17
  • You can pass an `ax` object to plot on via `disp.plot(ax=ax)` (default `None`). If it is `None`, a new figure and axes are created. Once your function accepts an `ax` and passes it on to `disp.plot`, `fig, (ax1, ax2) = plt.subplots(1, 2); build_confusion_matrix_test(ax=ax1); build_confusion_matrix_test(ax=ax2)` should give you the desired plot – Agaz Wani Aug 19 '22 at 19:42
  • @JodyKlymak `ax` is already supported (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ConfusionMatrixDisplay.html#sklearn.metrics.ConfusionMatrixDisplay.from_predictions) – Agaz Wani Aug 19 '22 at 19:56
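A minimal sketch of the comment's suggestion, assuming a scikit-learn version whose `ConfusionMatrixDisplay.plot` accepts `ax` as a keyword argument. The function gains a hypothetical `ax` parameter (plus an illustrative `title` parameter) and returns the display object instead of the figure, so the caller can reach both `disp.ax_` and `disp.figure_`:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay


def build_confusion_matrix_test(ax=None, title="title"):
    """Draw the matrix on `ax`; with ax=None, sklearn creates a new figure as before."""
    cm = np.array([[379, 49],
                   [18, 261]])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
    disp.plot(ax=ax)  # keyword argument; plot() does not take ax positionally
    disp.ax_.set_title(title, fontdict={"size": 13.5})
    return disp  # gives access to disp.ax_ and disp.figure_


# One figure with two axes in a single row; each call draws into its own axes.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
build_confusion_matrix_test(ax=ax1, title="left")
build_confusion_matrix_test(ax=ax2, title="right")
fig.tight_layout()
```

Call `plt.show()` afterwards (or let the notebook render `fig`). The same pattern extends to any other plotting function: accept an `ax`, draw into it, and let the caller decide the layout.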

1 Answer


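Try this: draw the matrices yourself with `imshow` on an `ImageGrid`, which also lets both panels share a single colorbar: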

import itertools

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

classes = ["0", "1"]
cm = np.array([[379,  49],
               [ 18, 261]])

fig = plt.figure()
grid = ImageGrid(fig, 111,
                 nrows_ncols=(1, 2),
                 axes_pad=0.15,
                 cbar_location="right",
                 cbar_mode="single",  # one colorbar shared by both axes
                 cbar_size="7%",
                 cbar_pad=0.15,
                 )


for idx, ax in enumerate(grid[:2]):
    # Fixed vmin/vmax so both panels match the shared colorbar
    im = ax.imshow(cm, vmin=0, vmax=400)
    ax.set_title("title {}".format(idx))
    tick_marks = np.arange(len(classes))

    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    # Annotate each cell with its count (integers, so 'd' rather than '.5f')
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd'),
                horizontalalignment="center",
                color="white")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

fig.tight_layout()
fig.subplots_adjust(right=0.8)
fig.colorbar(im, cax=grid.cbar_axes[0])  # the grid's single colorbar axes

plt.show()
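The same shared-colorbar layout can also be sketched with scikit-learn's own display instead of hand-written `imshow` code, assuming a scikit-learn version whose `ConfusionMatrixDisplay.plot` accepts `ax`, `colorbar` and `values_format`. Each panel is drawn into its own axes without a colorbar, both images get the answer's 0 to 400 color range via `set_clim`, and a single `fig.colorbar` call spans both axes:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

cm = np.array([[379, 49],
               [18, 261]])

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for idx, ax in enumerate(axes):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["0", "1"])
    # Draw into the given axes and skip the per-axes colorbar.
    disp.plot(ax=ax, colorbar=False, values_format="d")
    disp.im_.set_clim(0, 400)  # same color scale on both panels, like vmin/vmax above
    ax.set_title(f"title {idx}")

# One colorbar that takes its space from both axes.
fig.colorbar(disp.im_, ax=axes, shrink=0.8)
```

Because every panel uses the same limits, the one colorbar is valid for all of them; with different matrices, choose limits that cover the largest count.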
AfterFray