I have a Spark DataFrame df1:

id    transactions
1     [1, 2, 3, 5]
2     [1, 2, 3, 6]
3     [1, 2, 9, 8]
4     [1, 2, 5, 6]

root
 |-- id: integer (nullable = true)
 |-- transactions: array (nullable = false)
 |    |-- element: integer (containsNull = true)
I also have a second Spark DataFrame df2:

items     cost
[1]       1.0
[2]       1.0
[2, 1]    2.0
[6, 1]    2.0

root
 |-- items: array (nullable = false)
 |    |-- element: integer (containsNull = true)
 |-- cost: double (nullable = true)
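For reproducibility, this is how I build the sample DataFrames (a minimal sketch):

from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType, StructField, IntegerType, ArrayType, DoubleType,
)

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame(
    [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
    StructType([
        StructField("id", IntegerType()),
        StructField("transactions", ArrayType(IntegerType()), nullable=False),
    ]),
)

df2 = spark.createDataFrame(
    [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
    StructType([
        StructField("items", ArrayType(IntegerType()), nullable=False),
        StructField("cost", DoubleType()),
    ]),
)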
For each row of df1, I want to find every array in the items column whose elements are all contained in the transactions array, and sum up the corresponding costs. For example, the first row ([1, 2, 3, 5]) contains the item sets [1], [2], and [2, 1], so its score is 1.0 + 1.0 + 2.0 = 4.0.
The output I want is:

id    transactions    score
1     [1, 2, 3, 5]    4.0
2     [1, 2, 3, 6]    6.0
3     [1, 2, 9, 8]    4.0
4     [1, 2, 5, 6]    6.0
I tried looping over the rows with collect()/toLocalIterator, but that does not seem efficient, and I will have large data. I think a UDF like the following should solve it, but it throws an error:
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

def containsAll(x, y):
    # True if every element of y (the items) is present in x (the transactions)
    return all(elem in x for elem in y)

contains_udf = udf(containsAll, BooleanType())

# this line fails: transactions and items live in two different DataFrames
df1.withColumn("result", contains_udf(df1.transactions, df2.items)).show()
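I assume this is because transactions and items come from two different DataFrames, so I also tried combining them first with a cross join (a sketch; it produces every id/items pair, which I worry will be expensive on large data):

# join first so both columns are visible to the UDF
joined = df1.crossJoin(df2)
joined.withColumn("result", contains_udf(joined.transactions, joined.items)).show()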
Is there any other way to do this efficiently?
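For example, could the built-in array functions from Spark 2.4+ replace the Python UDF entirely? Something along these lines (an untested sketch; array_except(items, transactions) is empty exactly when all items are present in transactions):

from pyspark.sql import functions as F

scores = (
    df1.crossJoin(df2)
    # keep only the pairs where every item occurs in the transaction
    .where(F.size(F.array_except("items", "transactions")) == 0)
    # sum the costs of the matching item sets per transaction
    .groupBy("id", "transactions")
    .agg(F.sum("cost").alias("score"))
)
scores.show()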