
I'm using scikit-learn and I want to plot precision and recall curves. The classifier I'm using is RandomForestClassifier. All the resources in the scikit-learn documentation use binary classification. Also, can I plot a ROC curve for multiclass?

Also, the only multilabel example I found uses an SVM, which has a decision_function; RandomForest doesn't have one.

asked by John Sall
  • There is a paragraph with an example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html. Is that not what you want? – Yohst May 11 '19 at 13:48
  • https://scikit-learn.org/0.15/auto_examples/plot_precision_recall.html – secretive May 11 '19 at 13:56
  • @Yohst that example uses svm with decision function, and RandomForest doesn't have decision functions. – John Sall May 11 '19 at 14:18

1 Answer


From scikit-learn documentation:

Precision-recall curves are typically used in binary classification to study the output of a classifier. In order to extend the precision-recall curve and average precision to multi-class or multi-label classification, it is necessary to binarize the output. One curve can be drawn per label, but one can also draw a precision-recall curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).

ROC curves are typically used in binary classification to study the output of a classifier. In order to extend ROC curve and ROC area to multi-class or multi-label classification, it is necessary to binarize the output. One ROC curve can be drawn per label, but one can also draw a ROC curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).

Therefore, you should binarize the output and consider precision-recall and ROC curves for each class. Moreover, use predict_proba to get class probabilities.
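As a minimal sketch of the binarization step, label_binarize turns a vector of class labels into a one-column-per-class indicator matrix:

```python
from sklearn.preprocessing import label_binarize

# each row becomes a one-hot indicator for that sample's class
Y = label_binarize([0, 2, 1, 2], classes=[0, 1, 2])
print(Y.tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]
```

Each column of this matrix can then be treated as an independent binary problem, which is exactly what the per-class curves below do.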

I divide the code into three parts:

  1. general settings, learning and prediction
  2. precision-recall curve
  3. ROC curve

1. general settings, learning and prediction

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

import numpy as np
import matplotlib.pyplot as plt
#%matplotlib inline

mnist = fetch_openml("mnist_784")
y = mnist.target
y = y.astype(np.uint8)
n_classes = len(set(y))

Y = label_binarize(y, classes=[*range(n_classes)])

X_train, X_test, y_train, y_test = train_test_split(mnist.data,
                                                    Y,
                                                    random_state = 42)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                             max_depth=3,
                             random_state=0))
clf.fit(X_train, y_train)

y_score = clf.predict_proba(X_test)
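As a quick sanity check of the shapes involved, the same pattern can be run on a small synthetic dataset (make_classification here is just a stand-in for MNIST, not part of the setup above): predict_proba returns one probability column per class, aligned with the columns of the binarized target.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

# 3-class toy problem instead of MNIST, same pipeline as above
X, y = make_classification(n_samples=300, n_classes=3, n_informative=6,
                           random_state=0)
Y = label_binarize(y, classes=[0, 1, 2])
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=42)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                                                 max_depth=3,
                                                 random_state=0))
clf.fit(X_train, Y_train)
y_score = clf.predict_proba(X_test)
print(y_score.shape)  # (75, 3): 75 test samples, one probability column per class
```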

2. precision-recall curve

# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
    
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()

[plot: one precision-recall curve per class]
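The documentation quote above also mentions micro-averaging: treat each element of the label indicator matrix as one binary prediction by raveling both the target and score matrices. A sketch on synthetic data (make_classification stands in for MNIST):

```python
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

X, y = make_classification(n_samples=300, n_classes=3, n_informative=6,
                           random_state=0)
Y = label_binarize(y, classes=[0, 1, 2])
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=42)

clf = OneVsRestClassifier(RandomForestClassifier(random_state=0))
clf.fit(X_train, Y_train)
y_score = clf.predict_proba(X_test)

# flatten class columns: every (sample, class) pair becomes one binary decision
precision, recall, _ = precision_recall_curve(Y_test.ravel(), y_score.ravel())
plt.plot(recall, precision, lw=2, label="micro-average")
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.show()
```

This answers the threshold question directly: the micro-averaged curve sweeps a single threshold over all flattened probabilities at once, rather than assigning each observation to its argmax class.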

3. ROC curve

# roc curve
fpr = dict()
tpr = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
                                  y_score[:, i])
    plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()

[plot: one ROC curve per class]
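To summarize the per-class ROC curves in a single number, roc_auc_score accepts the binarized target and the score matrix directly; again sketched on synthetic data rather than MNIST:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

X, y = make_classification(n_samples=300, n_classes=3, n_informative=6,
                           random_state=0)
Y = label_binarize(y, classes=[0, 1, 2])
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=42)

clf = OneVsRestClassifier(RandomForestClassifier(random_state=0))
clf.fit(X_train, Y_train)
y_score = clf.predict_proba(X_test)

# macro: unweighted mean of per-class AUCs; micro: AUC over the raveled matrices
macro_auc = roc_auc_score(Y_test, y_score, average="macro")
micro_auc = roc_auc_score(Y_test, y_score, average="micro")
print(macro_auc, micro_auc)
```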

answered by sentence
  • Why am I using OneVsRestClassifier? Doesn't RandomForest already support multiclass? – John Sall May 11 '19 at 23:04
  • I have those errors when I run the first part: UserWarning: Label not 0 is present in all training example UserWarning: Label not 1 is present in all training example UserWarning: Label not 2 is present in all training example – John Sall May 12 '19 at 07:11
  • Please, note that a warning is NOT an error. Considering this line `Y = label_binarize(mnist.target, classes=[*range(n_classes)])`, you should provide the classes in your dataset. In my example, the classes are `[0,1,2,...,9]`. – sentence May 12 '19 at 09:57
  • How do you create a PR curve or a ROC curve with the micro-average? As far as I know, if you have 3 classes, you obtain 3 probability vectors, one with the probability of each class, and then the observation gets assigned to the class with the highest probability, independent of any threshold. But for ROC and PR curves you need a threshold, so how would you do the micro-average? How do you assign an observation to a class based on a specific threshold? – Sole Galli Jul 26 '21 at 09:02
  • I just tried to reverse calculate the precision and recall when the threshold is equal to 0 and see if it matches the one given by the classification_report() function but it returns strangely different results. I am addressing this problem here: https://stats.stackexchange.com/questions/559203/why-does-precision-recall-curve-return-similar-but-not-equal-values-than-confu?noredirect=1#comment1028225_559203 – Federico Gentile Jan 05 '22 at 13:46
  • @JohnSall Replace label_binarize line with Y = label_binarize(y, classes=[0,1, 2, 3,4,5,6,7,8,9]) – Naseeb Gill Mar 22 '22 at 19:35