
After applying a RandomForestClassifier for binary classification and predicting on a dataset, I obtain a transformed dataframe df with label, prediction and probability columns.
The goal:
I want to create a new column "prob_flag" holding the probability of predicting the label '1'. It is the second element of the array of probabilities, which is itself the fourth element (index 3) of the outer array.

I looked into similar questions, but I get an error that none of them mention.

df.show()
label   prediction                 probability
  0           0           [1,2,[],[0.7558548984793847,0.2441451015206153]]
  0           0           [1,2,[],[0.5190322149055472,0.4809677850944528]]
  0           1           [1,2,[],[0.4884140358521083,0.5115859641478916]]
  0           1           [1,2,[],[0.4884140358521083,0.5115859641478916]]
  1           1           [1,2,[],[0.40305518381637956,0.5969448161836204]]
  1           1           [1,2,[],[0.40570407426458577,0.5942959257354141]]

# The probability column is a VectorUDT. It looks like an array of 4 elements, the last of which contains the class probabilities I want to retrieve
df.schema
StructType(List(StructField(label,DoubleType,true),StructField(prediction,DoubleType,false),StructField(probability,VectorUDT,true)))

# I tried this:
import pyspark.sql.functions as f

df.withColumn("prob_flag", f.array([f.col("probability")[3][1]])).show()

"Can't extract value from probability#6225: need struct type but got struct<type:tinyint,size:int,indices:array<int>,values:array<double>>;"

In other words, "prob_flag" should be the second number of each probability array, e.g. 0.24, 0.48, 0.51, 0.51, 0.59, 0.59 for the rows above.
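The error message shows how Spark stores a VectorUDT internally: a struct of (type, size, indices, values). The pure-Python sketch below mimics that layout to show why `[3][1]` was the intended path (index 3 is `values`, element 1 of it is P(label = 1)) and why that path is only valid for dense vectors. It assumes PySpark's VectorUDT encoding, where type code 1 means dense and 0 means sparse; `to_dense` and `prob_of_label` are hypothetical helpers for illustration, not Spark functions.

```python
from typing import List, Optional, Sequence, Tuple

# Type codes assumed from PySpark's VectorUDT serialization
SPARSE, DENSE = 0, 1

VectorStruct = Tuple[int, Optional[int], Optional[Sequence[int]], Sequence[float]]


def to_dense(vec: VectorStruct) -> List[float]:
    """Rebuild the full vector from a (type, size, indices, values) struct."""
    type_, size, indices, values = vec
    if type_ == DENSE:
        # Dense: every entry is in `values`; size and indices are not needed
        return [float(x) for x in values]
    if type_ == SPARSE:
        # Sparse: only the listed positions are non-zero
        dense = [0.0] * size
        for i, x in zip(indices, values):
            dense[i] = float(x)
        return dense
    raise ValueError(f"unknown vector type code: {type_}")


def prob_of_label(vec: VectorStruct, label: int = 1) -> float:
    """Probability of `label`, read from the decoded vector."""
    return to_dense(vec)[label]


# First row of df.show() above, as displayed: [1,2,[],[0.7558...,0.2441...]]
row = (1, 2, [], [0.7558548984793847, 0.2441451015206153])
print(row[3][1])           # the "[3][1]" path: only valid for dense vectors
print(prob_of_label(row))  # same value, but also correct for sparse vectors

# A sparse 2-vector whose only non-zero entry is at index 1
sparse_row = (0, 2, [1], [1.0])
print(prob_of_label(sparse_row))  # 1.0 (sparse_row[3][1] would raise IndexError)
```

This is why indexing the raw struct is fragile even where Spark allows it: the position of a probability inside `values` depends on the encoding.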

LePuppy

1 Answer

Unfortunately, you cannot extract fields of a VectorUDT as if it were an ArrayType.

You have to use a UDF instead:

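from pyspark.sql.types import DoubleType
from pyspark.sql.functions import udf, col

def extract_prob(v):
    # v is a pyspark.ml.linalg Vector, or None for a null row
    try:
        return float(v[1])  # Your vector is of length 2: index 1 is P(label = 1)
    except (IndexError, TypeError):
        # TypeError: v is None; IndexError: vector has fewer than 2 entries
        return None

extract_prob_udf = udf(extract_prob, DoubleType())

df2 = df.withColumn("prob_flag", extract_prob_udf(col("probability")))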
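Since `extract_prob` is ordinary Python, its logic can be checked without a SparkSession. The sketch below assumes plain lists can stand in for `DenseVector` (both support integer indexing). It catches `TypeError` because that is what `None[1]` raises for a null row; catching only `ValueError` would let that error propagate.

```python
def extract_prob(v):
    # v is a Vector (here a plain list stands in), or None for a null row
    try:
        return float(v[1])  # index 1 = probability of label 1
    except (IndexError, TypeError):
        # TypeError: v is None; IndexError: vector has fewer than 2 entries
        return None

# Two rows from the question, a null row, and a too-short vector
rows = [
    [0.7558548984793847, 0.2441451015206153],
    [0.40305518381637956, 0.5969448161836204],
    None,
    [1.0],
]
print([extract_prob(v) for v in rows])
```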
Pierre Gourseaud