I have two DataFrames in PySpark, df1:
+---+-----------------+
|id1| items1|
+---+-----------------+
| 0| [B, C, D, E]|
| 1| [E, A, C]|
| 2| [F, A, E, B]|
| 3| [E, G, A]|
| 4| [A, C, E, B, D]|
+---+-----------------+
and df2:
+---+-----------------+
|id2| items2|
+---+-----------------+
|001| [B]|
|002| [A]|
|003| [C]|
|004| [E]|
+---+-----------------+
I would like to create a new column in df1 that updates the values of the items1 column so that it keeps only the values that also appear in items2 (in any row) of df2, preserving their original order. The result should look as follows:
+---+-----------------+----------------------+
|id1| items1| items1_updated|
+---+-----------------+----------------------+
| 0| [B, C, D, E]| [B, C, E]|
| 1| [E, A, C]| [E, A, C]|
| 2| [F, A, E, B]| [A, E, B]|
| 3| [E, G, A]| [E, A]|
| 4| [A, C, E, B, D]| [A, C, E, B]|
+---+-----------------+----------------------+
Normally I would use collect() to get a list of all values in the items2 column and then apply a UDF to each row of items1 to compute the intersection. But the data is extremely large (over 10 million rows), so I cannot use collect() to build such a list. Is there a way to do this while keeping the data in DataFrame format, or some other way that avoids collect()?