
I have created a confusion matrix to evaluate the performance of an MLPClassifier model on a multi-class problem. The confusion matrix should be 10x10, but at times I get an 8x8 one, as it omits rows and columns for one or two class labels, as you can see from the confusion matrix heatmap below the code whenever I rerun the whole Jupyter notebook. The true and predicted class labels range from 1 to 10 (unordered). Is this a bug in my code, or does it just depend on which random samples end up in the test set when the data is split into train and test sets? How should I fix this? The implementation looks like this:

from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, y_pred)
print(cm)

Out:    [[20  0  0  1  0  5  1  0]
         [ 3  0  0  0  0  0  0  0]
         [ 1  1  0  1  0  1  0  0]
         [ 3  0  0  0  0  3  1  1]
         [ 0  0  0  0  0  1  0  0]
         [ 3  0  0  1  0  2  1  1]
         [ 3  0  0  0  0  0  0  2]
         [ 1  0  0  0  0  0  0  1]]


import matplotlib.pyplot as plt
import seaborn as sns

side_bar = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
f, ax = plt.subplots(figsize=(12, 12))
sns.heatmap(cm, annot=True, linewidths=.5, linecolor="r", fmt=".0f", ax=ax)
ax.set_xticklabels(side_bar)
ax.set_yticklabels(side_bar)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

[confusion matrix heatmap]

  • I would guess that it is not a bug in the sklearn implementation, but something in your code, maybe related to your split. Have you checked `assert len(set(y_test)) == 10`? – flyingdutchman May 27 '21 at 12:18
  • Does this answer your question? [Stratified Train/Test-split in scikit-learn](https://stackoverflow.com/questions/29438265/stratified-train-test-split-in-scikit-learn) – flyingdutchman May 27 '21 at 12:24
  • I have checked now using `assert len(set(y_test)) == 10`, but I'm getting an error: "TypeError: unhashable type: 'numpy.ndarray'". – Srikanth Tangirala May 27 '21 at 13:21
  • I have also tried the solution suggested in the stratified train/test split link which you shared, but I get an error "ValueError: The least populated class in y has only 1 member, too few. The minimum number of groups for any class cannot be less than 2." – Srikanth Tangirala May 27 '21 at 13:28
  • What about using `side_bar = np.unique(y_test) ` as there clearly are only 8 labels used in `y_test`? What did you change versus your [previous question](https://stackoverflow.com/questions/67691498/confusion-matrix-output-missing-some-labels-for-multi-label-classification)? The confusion matrix never contains rows/columns for values that aren't part of the data. Note that if you don't have enough data to make sure that a test set contains all the labels, your classifier can't work properly. You really need enough examples for each label, both in the test and in the train data. – JohanC May 27 '21 at 14:40
  • I have tried using `side_bar = np.unique(y_test)`, but still one label is missing. I didn't change much in this post; I reposted it because I didn't get any replies until now. Compared to the previous question, I only added the question of whether this is due to insufficient input samples or to the test size, which I have set to 30%. I'm guessing the test dataset needs enough examples for every label, like you mentioned. – Srikanth Tangirala May 27 '21 at 15:18
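The label-coverage check suggested in the comments can be written with `np.unique`, which works on NumPy arrays (unlike `set()` on a 2-D array, which raises the "unhashable type" error above). A minimal sketch, with a synthetic array standing in for the real `y_test`:

```python
import numpy as np

# Synthetic stand-in for y_test; the real array comes from train_test_split.
y_test = np.array([1, 3, 4, 6, 7, 8, 9, 10, 1, 6])

present = np.unique(y_test)        # sorted array of labels that actually occur
expected = np.arange(1, 11)        # the full label set 1..10
missing = np.setdiff1d(expected, present)
print("labels present:", present)
print("labels missing from y_test:", missing)  # here: [2 5]
```

If `missing` is non-empty, the split left some classes out of the test set, which is exactly when the confusion matrix shrinks below 10x10.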

1 Answer


I think there is some confusion here! A confusion matrix is built from `set(y_test) | set(y_pred)`, i.e. the union of labels that actually occur in the true and predicted values. If that union contains 8 labels, the matrix will be 8x8. If you want rows and columns for every label, even ones that are all zeros, you need to pass the `labels` parameter when building the matrix:

from sklearn.metrics import confusion_matrix

y_true = [2, 0, 2, 2, 0, 1, 5]
y_pred = [0, 0, 2, 2, 0, 2, 4]
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3, 4, 5, 6])

As you can see, 6 does not occur in `y_true` or `y_pred`, so its row and column will be all zeros.
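Applied to the question's setup, passing `labels=list(range(1, 11))` guarantees a 10x10 matrix, so the heatmap's tick labels 1-10 always line up with the rows and columns. A sketch with made-up label vectors (the real `y_test`/`y_pred` come from the asker's model):

```python
from sklearn.metrics import confusion_matrix

# Made-up labels covering only 8 of the 10 classes, mimicking the problem.
y_test = [1, 3, 4, 6, 7, 8, 9, 10, 1, 6]
y_pred = [1, 3, 4, 6, 6, 8, 9, 10, 3, 6]

cm = confusion_matrix(y_test, y_pred, labels=list(range(1, 11)))
print(cm.shape)  # (10, 10) even though labels 2 and 5 never occur
```

With the matrix pinned to 10x10, `sns.heatmap(cm, ...)` and the `side_bar` tick labels stay consistent from run to run.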

simpleApp