Precision, recall and F1-score depend on the probability threshold: the cut-off at or above which a sample is assigned to the positive class. Changing that threshold changes precision and recall, and therefore the F1-score. Here is my attempt to plot precision, recall and F1-score (together with the false positive and false negative rates) as functions of the discrimination threshold. The plot also marks the optimal threshold, for the given dataset and model, at which to classify a sample as a member of the positive class. By default, the optimal threshold is the one at which the F1-score is highest.
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from sklearn.metrics import confusion_matrix as cm_sklearn
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
def plot_discrimination_threshold(clf, X_test, y_test, argmax='f1', title='Metrics vs Discriminant Threshold', fig_size=(10, 8), dpi=100, save_fig_path=None):
    """
    Plot precision, recall, f1-score, FPR and FNR vs discriminant threshold for a fitted binary classifier.

    The probability in column 1 of ``clf.predict_proba(X_test)`` is thresholded at 100 evenly
    spaced values in [0, 1]. A sample is predicted positive (1) when its probability is greater
    than or equal to the threshold. The threshold that maximizes the ``argmax`` metric is marked
    with a dashed vertical line.

    Parameters
    ----------
    clf : estimator instance (either sklearn.Pipeline, imblearn.Pipeline or a classifier)
        PRE-FITTED classifier or a PRE-FITTED Pipeline in which the last estimator is a classifier.
        Must implement ``predict_proba``.
    X_test : pandas.DataFrame of shape (n_samples, n_features)
        Test features.
    y_test : pandas.Series of shape (n_samples,)
        Binary target values encoded as 0 (negative) and 1 (positive).
    argmax : str, default='f1'
        Metric whose maximizing threshold is annotated. Intended options: 'f1', 'precision',
        'recall'. 'fpr' and 'fnr' are also accepted, but maximizing an error rate is rarely useful.
    title : str, default='Metrics vs Discriminant Threshold'
        Plot title.
    fig_size : tuple, default=(10, 8)
        Size (inches) of the plot.
    dpi : int, default=100
        Image DPI, used both for display and for the saved file.
    save_fig_path : str, default=None
        Full path where to save the plot. Parent folders are created if they don't exist already.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure from matplotlib.
    ax : matplotlib.axes.Axes
        Axes object from matplotlib.
    disc_threshold : float
        Threshold that maximizes ``argmax``, rounded to 2 decimals. On ties, the lowest such
        threshold is returned (pandas ``idxmax`` semantics).

    Raises
    ------
    ValueError
        If ``argmax`` is not one of 'precision', 'recall', 'f1', 'fpr', 'fnr'.
    """
    metric_names = ['precision', 'recall', 'f1', 'fpr', 'fnr']
    # Fail fast on a typo instead of after computing every metric and opening a figure.
    if argmax not in metric_names:
        raise ValueError(f"argmax must be one of {metric_names}, got {argmax!r}")

    thresholds = np.linspace(0, 1, 100)
    # Probability of the class in column 1 of predict_proba (the positive class for 0/1 targets).
    probs = clf.predict_proba(X_test)[:, 1]

    rows = []
    for threshold in thresholds:
        # Class prediction based on the threshold.
        y_predictions = np.where(probs >= threshold, 1, 0)
        # labels=[0, 1] keeps the matrix 2x2 even when y_test and the predictions contain a
        # single class (e.g. every sample predicted positive at threshold 0), so unpacking is safe.
        tn, fp, fn, tp = cm_sklearn(y_test, y_predictions, labels=[0, 1]).ravel()
        # A rate is undefined (NaN) when the class it is conditioned on is absent from y_test.
        fpr = fp / (tn + fp) if (tn + fp) else np.nan
        fnr = fn / (tp + fn) if (tp + fn) else np.nan
        # zero_division=0 yields the same 0.0 sklearn returns by default when a metric is
        # undefined (e.g. no positive predictions at high thresholds), without the warning spam.
        precision = precision_score(y_test, y_predictions, average='binary', zero_division=0)
        recall = recall_score(y_test, y_predictions, average='binary', zero_division=0)
        f1 = f1_score(y_test, y_predictions, average='binary', zero_division=0)
        rows.append((precision, recall, f1, fpr, fnr))

    # One row per threshold; the index is the threshold, so ax.plot(series) uses it as x.
    metrics = pd.DataFrame(rows, index=thresholds, columns=metric_names)

    # Set a white background on this figure only, instead of mutating global plt.rcParams.
    fig, ax = plt.subplots(1, 1, figsize=fig_size, dpi=dpi, facecolor='white')
    ax.set_facecolor('white')
    ax.plot(metrics['precision'], label='Precision')
    ax.plot(metrics['recall'], label='Recall')
    ax.plot(metrics['f1'], label='f1')
    ax.plot(metrics['fpr'], label='False Positive Rate (FPR)', linestyle='dotted')
    ax.plot(metrics['fnr'], label='False Negative Rate (FNR)', linestyle='dotted')

    # Draw a threshold line at the metric's maximum (NaNs are skipped by idxmax).
    best_threshold = metrics[argmax].idxmax()
    disc_threshold = round(best_threshold, 2)
    ax.axvline(x=best_threshold, color='black', linestyle='dashed', label="$t_r$=" + str(disc_threshold))

    ax.xaxis.set_major_locator(MultipleLocator(0.1))
    ax.xaxis.set_major_formatter('{x:.1f}')
    ax.yaxis.set_major_locator(MultipleLocator(0.1))
    ax.yaxis.set_major_formatter('{x:.1f}')
    ax.xaxis.set_minor_locator(MultipleLocator(0.05))
    ax.yaxis.set_minor_locator(MultipleLocator(0.05))
    ax.tick_params(which='both', width=2)
    ax.tick_params(which='major', length=7)
    ax.tick_params(which='minor', length=4, color='black')
    ax.grid(True)
    ax.set_xlabel('Probability Threshold', fontsize=18)
    ax.set_ylabel('Scores', fontsize=18)
    ax.set_title(title, fontsize=18)
    leg = ax.legend(loc='best', frameon=True, framealpha=0.7)
    leg.get_frame().set_color('gold')

    # Save before showing: with some backends plt.show() blocks or closes the figure.
    if save_fig_path is not None:
        path = pathlib.Path(save_fig_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=dpi, facecolor='white')
    plt.show()
    return fig, ax, disc_threshold
