If you specifically need to generate ROC curves for different thresholds, one approach is to build a list of the threshold values you're interested in and fit/transform on your dataset for each one. Or you could manually calculate the ROC point for each threshold using the probability field in the output of model.transform(test).
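For the manual route, a minimal pure-Python sketch might look like the following. It assumes the (score, label) pairs have already been collected to the driver (for example from the probability column) and that a row counts as positive when its score is at or above the threshold. `roc_points` is a hypothetical helper name, not part of PySpark.

```python
def roc_points(scored_labels):
    """Return ROC points as (false positive rate, true positive rate).

    scored_labels: iterable of (score, label) pairs with label in {0, 1}.
    One point is emitted per distinct score, treating that score as the
    threshold (score >= threshold is predicted positive).
    """
    pairs = sorted(scored_labels, key=lambda p: p[0], reverse=True)
    n_pos = sum(1 for _, label in pairs if label == 1)
    n_neg = len(pairs) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need at least one positive and one negative label")

    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Consume every row tied at this score before emitting a point.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


# Toy data standing in for collected (probability of class 1, label) rows.
toy = [(0.9, 1), (0.8, 1), (0.7, 0), (0.6, 1), (0.4, 0), (0.2, 0)]
for fpr, tpr in roc_points(toy):
    print(round(fpr, 3), round(tpr, 3))
```

Sorting once and sweeping the threshold downward keeps this to a single pass after the sort, but it still requires all rows on the driver, so it is only reasonable for modest test sets.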
Alternatively, you can use BinaryClassificationMetrics to extract a curve plotting various metrics (F1 score, precision, recall) by threshold.
Unfortunately it appears the PySpark version doesn't implement most of the methods the Scala version does, so you'd need to wrap the class to do it in Python.
For example:
from pyspark.mllib.evaluation import BinaryClassificationMetrics
# Scala version implements .roc() and .pr()
# Python: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/mllib/common.html
# Scala: https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
class CurveMetrics(BinaryClassificationMetrics):
    def __init__(self, *args):
        super(CurveMetrics, self).__init__(*args)

    def _to_list(self, rdd):
        points = []
        # Note this collect could be inefficient for large datasets
        # considering there may be one probability per datapoint (at most)
        # The Scala version takes a numBins parameter,
        # but it doesn't seem possible to pass this from Python to Java
        for row in rdd.collect():
            # Results are returned as type scala.Tuple2,
            # which doesn't appear to have a py4j mapping
            points += [(float(row._1()), float(row._2()))]
        return points

    def get_curve(self, method):
        rdd = getattr(self._java_model, method)().toJavaRDD()
        return self._to_list(rdd)
Usage:
import matplotlib.pyplot as plt
# Build an RDD of (score, label) pairs; index 1 is the probability of the positive class
preds = predictions.select('label', 'probability').rdd.map(
    lambda row: (float(row['probability'][1]), float(row['label'])))
# Returns a list of (false positive rate, true positive rate) points
points = CurveMetrics(preds).get_curve('roc')
plt.figure()
x_val = [x[0] for x in points]
y_val = [x[1] for x in points]
plt.title('ROC curve')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.plot(x_val, y_val)
Results in:

Here's an example of an F1 score curve by threshold value if you aren't married to ROC:

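If you'd rather compute that F1-by-threshold curve by hand, here is a minimal sketch over collected (score, label) pairs. It assumes a row is predicted positive when its score is at or above the threshold; `f1_by_threshold` is a hypothetical helper name. With the wrapper above, the Scala class's `fMeasureByThreshold` method should presumably be reachable as `get_curve('fMeasureByThreshold')`, though I haven't verified that on every Spark version.

```python
def f1_by_threshold(scored_labels, thresholds):
    """Return a list of (threshold, F1) pairs.

    scored_labels: iterable of (score, label) pairs with label in {0, 1}.
    thresholds: iterable of cut-off values to evaluate.
    """
    pairs = list(scored_labels)
    curve = []
    for t in thresholds:
        tp = sum(1 for s, y in pairs if s >= t and y == 1)
        fp = sum(1 for s, y in pairs if s >= t and y == 0)
        fn = sum(1 for s, y in pairs if s < t and y == 1)
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0.0 when there are no positives at all
        denom = 2 * tp + fp + fn
        curve.append((t, 2 * tp / denom if denom else 0.0))
    return curve


# Toy data standing in for collected (probability of class 1, label) rows.
toy = [(0.9, 1), (0.8, 1), (0.7, 0), (0.6, 1), (0.4, 0), (0.2, 0)]
for threshold, f1 in f1_by_threshold(toy, [0.3, 0.5, 0.85]):
    print(threshold, round(f1, 3))
```

This rescans the data once per threshold, which is fine for a handful of cut-offs but not for one threshold per row.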