The Scenario
I have a dataframe containing the following data:
import pandas as pd
from pyspark.sql.types import ArrayType, StringType, IntegerType, FloatType, StructType, StructField
import pyspark.sql.functions as F
a = [1,2,3]
b = [['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']]
df = pd.DataFrame({
'id': a,
'list1': b,
})
df = spark.createDataFrame(df)
df.printSchema()
df.show()
+---+---------+
| id| list1|
+---+---------+
| 1|[a, b, c]|
| 2|[d, e, f]|
| 3|[g, h, i]|
+---+---------+
I also have a static list containing the following values:
list2 = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
What I want to do
I want to compare each value of list2 to each value of list1 in my data, and build an array of 0/1 values, with 1 indicating that the value of list2 is present in list1 and 0 indicating that it is not.
The resulting output should look like this:
+---+-----------+-----------------------------+
| id| list1 | result |
+---+-----------+-----------------------------+
| 1| [a, b, c] | [1, 1, 1, 0, 0, 0, 0, 0, 0] |
| 2| [d, e, f] | [0, 0, 0, 1, 1, 1, 0, 0, 0] |
| 3| [g, h, i] | [0, 0, 0, 0, 0, 0, 1, 1, 1] |
+---+-----------+-----------------------------+
I need the results in this format because I am eventually going to multiply the result arrays by a scaling factor.
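For context, the downstream scaling I have in mind is elementwise, roughly like the sketch below. This assumes F.transform, which the Python API provides in Spark 3.1+, and scale is just a hypothetical placeholder value:
# Multiply every element of the result array by a scalar
# (pyspark.sql.functions.transform applies a lambda over array elements)
scale = 2.5  # hypothetical scaling factor
df = df.withColumn("scaled", F.transform("result", lambda x: x * F.lit(scale)))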
My attempt
# Insert list2 into the dataframe as a literal array column
df = df.withColumn("list2", F.array([F.lit(x) for x in list2]))
# Build the 0/1 indicator arrays with a Python UDF
differencer = F.udf(lambda list1, list2: [1 if x in list1 else 0 for x in list2], ArrayType(IntegerType()))
df = df.withColumn('result', differencer('list1', 'list2'))
df.show()
However, I get the following error:
An error was encountered:
An error occurred while calling o151.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 11.0 failed 4 times, most recent failure: Lost task 0.3 in stage 11.0 (TID 287) (ip-10-0-0-142.ec2.internal executor 8): java.lang.RuntimeException: Failed to run command: /usr/bin/virtualenv -p python3 --system-site-packages virtualenv_application_1665327460183_0007_0
at org.apache.spark.api.python.VirtualEnvFactory.execCommand(VirtualEnvFactory.scala:120)
at org.apache.spark.api.python.VirtualEnvFactory.setupVirtualEnv(VirtualEnvFactory.scala:78)
at org.apache.spark.api.python.PythonWorkerFactory.<init>(PythonWorkerFactory.scala:94)
at org.apache.spark.SparkEnv.$anonfun$createPythonWorker$1(SparkEnv.scala:125)
at scala.collection.mutable.HashMap.getOrElseUpdate(HashMap.scala:86)
at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:125)
at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:162)
at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:81)
at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:130)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:863)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:863)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:133)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1474)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:750)
I've tried dozens of variations on this approach, but everything I do results in the same error.
How can I get this to work, ideally without having to insert list2 into the dataframe prior to running the comparison?
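One idea I keep coming back to is whether a pure-Column expression could sidestep the Python worker entirely, since the failure seems to happen while Spark is setting up the Python virtualenv. A minimal sketch of what I mean, assuming array_contains behaves the way I expect (one boolean check per element of list2, cast to int):
# Build the 0/1 array as a pure Column expression: one array_contains
# check per element of list2, cast from boolean to int. No UDF needed,
# and list2 never has to be inserted into the dataframe.
df = df.withColumn(
    "result",
    F.array(*[F.array_contains("list1", v).cast("int") for v in list2])
)
df.show(truncate=False)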
Thanks