15

I'm looking for someone who can help me plot my confusion matrix. I need this for a term paper at university; however, I have very little experience with programming.

Below you can see the classification report and the structure of my y_test and of my predictions, dtree_predictions.

I would be happy if someone could help me, because I have tried so many things but only get error messages instead of a solution.

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics

# split the multilabel data, fit a decision tree and predict on the test set
X_train, X_test, y_train, y_test = train_test_split(X, Y_profile, test_size=0.3, random_state=30)

dtree_model = DecisionTreeClassifier().fit(X_train, y_train)
dtree_predictions = dtree_model.predict(X_test)

# note: classification_report expects (y_true, y_pred), so the arguments here are swapped
print(metrics.classification_report(dtree_predictions, y_test))
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       222
           1       1.00      1.00      1.00       211
           2       1.00      1.00      1.00       229
           3       0.96      0.97      0.96       348
           4       0.89      0.85      0.87        93
           5       0.86      0.86      0.86       105
           6       0.94      0.93      0.94       116
           7       1.00      1.00      1.00       364
           8       0.99      0.97      0.98       139
           9       0.98      0.99      0.99       159
          10       0.97      0.96      0.97       189
          11       0.92      0.92      0.92       124
          12       0.92      0.92      0.92       119
          13       0.95      0.96      0.95       230
          14       0.98      0.96      0.97       452
          15       0.91      0.96      0.93       210

   micro avg       0.96      0.96      0.96      3310
   macro avg       0.95      0.95      0.95      3310
weighted avg       0.97      0.96      0.96      3310
 samples avg       0.96      0.96      0.96      3310

Next, I print the metrics of the multilabel confusion matrix:

from sklearn.metrics import multilabel_confusion_matrix
multilabel_confusion_matrix(y_test, dtree_predictions)

array([[[440,   0],
    [  0, 222]],

   [[451,   0],
    [  0, 211]],

   [[433,   0],
    [  0, 229]],

   [[299,  10],
    [ 15, 338]],

   [[559,  14],
    [ 10,  79]],

   [[542,  15],
    [ 15,  90]],

   [[539,   8],
    [  7, 108]],

   [[297,   0],
    [  1, 364]],

   [[522,   4],
    [  1, 135]],

   [[500,   1],
    [  3, 158]],

   [[468,   8],
    [  5, 181]],

   [[528,  10],
    [ 10, 114]],

   [[534,   9],
    [  9, 110]],

   [[420,   9],
    [ 12, 221]],

   [[201,  19],
    [  9, 433]],

   [[433,   9],
    [ 19, 201]]])
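As far as I understand the scikit-learn documentation, each of these 2x2 blocks belongs to one label and is laid out as [[TN, FP], [FN, TP]]. A small check I did (mcm is just the array shown above, with label 3 picked as an example):

mcm = multilabel_confusion_matrix(y_test, dtree_predictions)
print(mcm.shape)                   # (16, 2, 2): one 2x2 matrix per label
tn, fp, fn, tp = mcm[3].ravel()    # unpack the block for label 3
print(tn, fp, fn, tp)              # 299 10 15 338, matching the output above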

And here is the structure of y_test and dtree_predictions:

print(dtree_predictions)
print(dtree_predictions.shape)

[[0. 0. 1. ... 0. 1. 0.]
[1. 0. 0. ... 0. 1. 0.]
[0. 0. 1. ... 0. 1. 0.]
 ...
[1. 0. 0. ... 0. 0. 1.]
[0. 1. 0. ... 1. 0. 1.]
[0. 1. 0. ... 1. 0. 1.]]
(662, 16)

print(y_test)

      Cooler close to failure  Cooler reduced effiency  Cooler full    effiency  \
1985                      0.0                      0.0                   1.0   
322                       1.0                      0.0                   0.0   
2017                      0.0                      0.0                   1.0   
1759                      0.0                      0.0                   1.0   
1602                      0.0                      0.0                     1.0   
...                       ...                      ...                      ...   
128                       1.0                      0.0                   0.0   
321                       1.0                      0.0                   0.0   
53                        1.0                      0.0                   0.0   
859                       0.0                      1.0                     0.0   
835                       0.0                      1.0                       0.0   

  valve optimal  valve small lag  valve severe lag  \
1985            0.0              0.0               0.0   
322             0.0              1.0               0.0   
2017            1.0              0.0               0.0   
1759            0.0              0.0               0.0   
1602            1.0              0.0               0.0   
...             ...              ...               ...   
128             1.0              0.0               0.0   
321             0.0              1.0               0.0   
53              1.0              0.0               0.0   
859             1.0              0.0               0.0   
835             1.0              0.0               0.0   

  valve close to failure  pump no leakage  pump weak leakage  \
1985                     1.0              0.0                1.0   
322                      0.0              1.0                0.0   
2017                     0.0              0.0                1.0   
1759                     1.0              1.0                0.0   
1602                     0.0              1.0                0.0   
...                      ...              ...                ...   
128                      0.0              1.0                0.0   
321                      0.0              1.0                0.0   
53                       0.0              1.0                0.0   
859                      0.0              1.0                0.0   
835                      0.0              1.0                0.0   

  pump severe leakage  accu optimal pressure  \
1985                  0.0                    0.0   
322                   0.0                    1.0   
2017                  0.0                    0.0   
1759                  0.0                    1.0   
1602                  0.0                    0.0   
...                   ...                    ...   
128                   0.0                    1.0   
321                   0.0                    1.0   
53                    0.0                    1.0   
859                   0.0                    0.0   
835                   0.0                    0.0   

  accu slightly reduced pressure  accu severly reduced pressure  \
1985                             0.0                            1.0   
322                              0.0                            0.0   
2017                             0.0                            1.0   
1759                             0.0                            0.0   
1602                             0.0                            0.0   
...                              ...                            ...   
128                              0.0                            0.0   
321                              0.0                            0.0   
53                               0.0                            0.0   
859                              0.0                            0.0   
835                              0.0                            0.0   

  accu close to failure  stable flag stable  stable flag not stable  
1985                    0.0                 1.0                     0.0  
322                     0.0                 1.0                     0.0  
2017                    0.0                 1.0                     0.0  
1759                    0.0                 1.0                     0.0  
1602                    1.0                 0.0                     1.0  
...                     ...                 ...                     ...  
128                     0.0                 0.0                     1.0  
321                     0.0                 1.0                     0.0  
53                      0.0                 0.0                     1.0  
859                     1.0                 0.0                     1.0  
835                     1.0                 0.0                     1.0  

[662 rows x 16 columns]
  • Please add your code as text so that answerers can reproduce your output and help you – Derek O Jul 03 '20 at 20:45
  • I have a question: the whole code from the beginning, or only the code from the pictures? – user13861437 Jul 03 '20 at 20:51
  • You should post as much of your code as possible if it's reasonable. Also, instead of hyperlinking to images of your code, post your code as formatted text so that people who want to help you can copy and paste it and try to run it themselves – Derek O Jul 03 '20 at 20:57
  • When you searched the help or Google for *"plot/get Confusion Matrix"*, what command did you find? Please post your actual code. I think you don't want a plot, you just want a table? Please confirm. Also, you need to tag this [tag:python] so people know what language you're using, and so it gets distributed to the thousands of users in that tag. – smci Jul 03 '20 at 21:19
  • @smci I want to show it like the one in the example, if that is possible. I don't know if "plot" is the right word for it :D https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html – user13861437 Jul 03 '20 at 21:25

2 Answers

18

You could use ConfusionMatrixDisplay from sklearn.metrics.

Example:

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_multilabel_classification
from sklearn.tree import DecisionTreeClassifier

# synthetic multilabel data with 15 binary labels, analogous to your setup
X, y = make_multilabel_classification(n_samples=1000,
                                      n_classes=15, random_state=42)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42)

tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

y_pred = tree.predict(X_test)

# one binary confusion matrix per label, arranged in a 3x5 grid
f, axes = plt.subplots(3, 5, figsize=(25, 15))
axes = axes.ravel()
for i in range(15):
    disp = ConfusionMatrixDisplay(confusion_matrix(y_test[:, i],
                                                   y_pred[:, i]),
                                  display_labels=[0, 1])
    disp.plot(ax=axes[i], values_format='.4g')
    disp.ax_.set_title(f'class {i}')
    if i < 10:
        disp.ax_.set_xlabel('')    # keep x-axis labels only on the bottom row
    if i % 5 != 0:
        disp.ax_.set_ylabel('')    # keep y-axis labels only in the first column
    disp.im_.colorbar.remove()     # drop the per-subplot colorbars

plt.subplots_adjust(wspace=0.10, hspace=0.1)
f.colorbar(disp.im_, ax=axes)      # one shared colorbar for the whole grid
plt.show()

[Output: a 3x5 grid of per-class confusion matrices, one subplot per label with a shared colorbar]
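If you are on scikit-learn 1.0 or newer, each subplot could also be built directly from the predictions with ConfusionMatrixDisplay.from_predictions, which skips the explicit confusion_matrix call. A minimal sketch, reusing y_test, y_pred and the same grid layout as above:

f, axes = plt.subplots(3, 5, figsize=(25, 15))
axes = axes.ravel()
for i in range(15):
    # from_predictions computes and plots the confusion matrix for label i in one call
    disp = ConfusionMatrixDisplay.from_predictions(
        y_test[:, i], y_pred[:, i],
        display_labels=[0, 1],
        ax=axes[i], values_format='.4g',
        colorbar=False)                 # no per-subplot colorbars
    disp.ax_.set_title(f'class {i}')

f.colorbar(disp.im_, ax=axes)           # one shared colorbar instead
plt.show()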

Venkatachalam
17

Usually, a confusion matrix is visualized via a heatmap. There is also a function on GitHub that pretty-prints a confusion matrix. Inspired by it, I have adapted it to the multilabel scenario, where each class with its binary predictions (Y, N) gets its own matrix and is visualized via a heatmap.

Here is an example taking some of the output from the posted code:

A confusion matrix is obtained for each of the labels, with each label turned into its own binary classification problem.

The multilabel confusion matrix puts TN at position (0, 0) and TP at position (1, 1); thanks @Kenneth Witham for pointing that out. A short check of this layout follows the array below.

import numpy as np

vis_arr = np.asarray([[[440,   0],
    [  0, 222]],

   [[451,   0],
    [  0, 211]],

   [[433,   0],
    [  0, 229]],

   [[299,  10],
    [ 15, 338]],

   [[559,  14],
    [ 10,  79]],

   [[542,  15],
    [ 15,  90]],

   [[539,   8],
    [  7, 108]],

   [[297,   0],
    [  1, 364]],

   [[522,   4],
    [  1, 135]],

   [[500,   1],
    [  3, 158]],

   [[468,   8],
    [  5, 181]],

   [[528,  10],
    [ 10, 114]],

   [[534,   9],
    [  9, 110]],

   [[420,   9],
    [ 12, 221]],

   [[201,  19],
    [  9, 433]],

   [[433,   9],
    [ 19, 201]]])
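To illustrate the TN/FP/FN/TP layout mentioned above, one block can be unpacked and used to recompute a metric by hand; a small sketch using label 3 as an arbitrary example:

# vis_arr[i] is [[TN, FP], [FN, TP]] for label i
tn, fp, fn, tp = vis_arr[3].ravel()
print(tn, fp, fn, tp)               # 299 10 15 338
print(round(tp / (tp + fn), 2))     # recall computed from this block: 0.96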

Manually created class labels c0 to c15.

labels = ["c" + str(i) for i in range(16)]
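Alternatively, since y_test in the question is a pandas DataFrame, the real condition names could be used as labels; a one-liner, assuming that DataFrame is still in scope:

# use the actual column names of the question's y_test as class labels
labels = list(y_test.columns)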

Multilabel adaptation of the confusion matrix visualization

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def print_confusion_matrix(confusion_matrix, axes, class_label, class_names, fontsize=14):

    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names,
    )

    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d", cbar=False, ax=axes)
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    axes.set_ylabel('True label')
    axes.set_xlabel('Predicted label')
    axes.set_title("Confusion Matrix for the class - " + class_label)

Updating for multilabel classification visualization

Extending the basic confusion matrix to a grid of subplots, one per class, with the class name as the subplot title. Here [N, Y] are the class labels used for the binary matrices and can be changed as needed.

fig, ax = plt.subplots(4, 4, figsize=(12, 7))

for axes, cfs_matrix, label in zip(ax.flatten(), vis_arr, labels):
    print_confusion_matrix(cfs_matrix, axes, label, ["N", "Y"])

fig.tight_layout()
plt.show()

Note: this plot is constructed based on the Wikipedia article on the confusion matrix.

Output:

[Output: a 4x4 grid of seaborn heatmaps, one 2x2 confusion matrix per class c0–c15 with N/Y labels]

coldy