3

I have a df1 Spark dataframe

id     transactions
1      [1, 2, 3, 5]
2      [1, 2, 3, 6]
3      [1, 2, 9, 8]
4      [1, 2, 5, 6]

root
 |-- id: int (nullable = true)
 |-- transactions: array (nullable = false)
 |    |-- element: int (containsNull = true)

I have a df2 Spark dataframe

items   cost
  [1]    1.0
  [2]    1.0
 [2, 1]  2.0
 [6, 1]  2.0

root
 |-- items: array (nullable = false)
 |    |-- element: int (containsNull = true)
 |-- cost: double (nullable = true)

I want to check whether all of the array elements from the items column are present in the transactions column.

The first row of transactions ([1, 2, 3, 5]) contains [1], [2] and [2, 1] from the items column, so I need to sum up their corresponding costs: 1.0 + 1.0 + 2.0 = 4.0. Similarly, the second row ([1, 2, 3, 6]) contains [1], [2], [2, 1] and [6, 1], giving 1.0 + 1.0 + 2.0 + 2.0 = 6.0.

The output I want is

id     transactions    score
1      [1, 2, 3, 5]   4.0
2      [1, 2, 3, 6]   6.0
3      [1, 2, 9, 8]   4.0
4      [1, 2, 5, 6]   6.0
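
To make the scoring rule concrete, this is the per-row calculation in plain Python (just an illustration of what I want; the variable names are only for this example):

# items/cost pairs from df2 and the first transactions row from df1
items_costs = [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)]
transactions = [1, 2, 3, 5]

# sum the cost of every items list that is fully contained in transactions
score = sum(cost for items, cost in items_costs
            if set(items).issubset(transactions))
print(score)  # 4.0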

I tried using a loop with collect()/toLocalIterator, but it does not seem efficient; my data will be large.

I think creating a UDF like this would solve it, but it throws an error.

from pyspark.sql.functions import udf
def containsAll(x, y):
  result = all(elem in x for elem in y)

  if result:
    print("Yes, transactions contains all items")    
  else:
    print("No")

contains_udf = udf(containsAll)
dataFrame.withColumn("result", contains_udf(df2.items, df1.transactions)).show()

Is there any other way to do this?

priya
  • You need to join the two DataFrames, `groupby`, and `sum` (don't use loops or `collect`). What is the schema of your dataframes? [edit] your question with `df.printSchema()`. I assume those lists are arrays of ints - if so, here's a post on how to join the two dataframes: [PySpark Join on Values Within A List](https://stackoverflow.com/questions/36108620/pyspark-join-and-operation-on-values-within-a-list-in-column) – pault Feb 26 '19 at 21:15
  • @priya what are the relative sizes of `df1` and `df2`? – cph_sto Feb 27 '19 at 13:28
  • @cph_sto df1 may have 100,000 rows, and the number of elements in transactions could be 1,000 to 10,000. df2 can contain double or triple the number of rows of df1. – priya Feb 27 '19 at 20:21
  • What version of spark are you using? – Shaido Mar 01 '19 at 05:35
  • @Shaido spark 2.3.3 – priya Mar 01 '19 at 08:08

2 Answers

7

A valid udf before 2.4 (note that it has to return something):

from pyspark.sql.functions import udf

@udf("boolean")
def contains_all(x, y):
    if x is not None and y is not None:
        return set(y).issubset(set(x))

In 2.4 or later no udf is required:

from pyspark.sql.functions import array_intersect, size

def contains_all(x, y):
    return size(array_intersect(x, y)) == size(y)

Usage:

from pyspark.sql.functions import col, sum as sum_, when

df1 = spark.createDataFrame(
    [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
    ("id", "transactions")
)

df2 = spark.createDataFrame(
    [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
    ("items", "cost")
)


(df1
    .crossJoin(df2).groupBy("id", "transactions")
    .agg(sum_(when(
        contains_all("transactions", "items"), col("cost")
    )).alias("score"))
    .show())

The result:

+---+------------+-----+                                                        
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
+---+------------+-----+

If df2 is small, it could be preferable to collect it and use it as a local variable:

items = sc.broadcast([
    (set(items), cost) for items, cost in df2.select("items", "cost").collect()
])

def score(y):
    @udf("double")
    def _(x):
        if x is not None:
            transactions = set(x)
            return sum(
                cost for items, cost in y.value 
                if items.issubset(transactions))
    return _


df1.withColumn("score", score(items)("transactions")).show()
+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
+---+------------+-----+

Finally, it is possible to explode and join:

from pyspark.sql.functions import explode, monotonically_increasing_id, size

costs = (df1
    # Explode transactions
    .select("id", explode("transactions").alias("item"))
    .join(
        df2 
            # Add id so we can later use it to identify source
            .withColumn("_id", monotonically_increasing_id().alias("_id"))
             # Explode items
            .select(
                "_id", explode("items").alias("item"), 
                # We'll need size of the original items later
                size("items").alias("size"), "cost"), 
         ["item"])
     # Count matches in groups id, items
     .groupBy("_id", "id", "size", "cost")
     .count()
     # Compute cost
     .groupBy("id")
     .agg(sum_(when(col("size") == col("count"), col("cost"))).alias("score")))

costs.show()
+---+-----+                                                                      
| id|score|
+---+-----+
|  1|  4.0|
|  3|  4.0|
|  2|  6.0|
|  4|  6.0|
+---+-----+

and then join the result back with the original df1:

df1.join(costs, ["id"])

but that's a much less straightforward solution, and it requires multiple shuffles. It might still be preferable to the Cartesian product (crossJoin), but that will depend on the actual data.

user10938362
  • Thanks a lot for the help. With 2.4 standalone, I tried your code (contains_all with the array_intersect method), but it throws Py4JJavaError: An error occurred while calling o718.showString. Caused by: java.net.SocketTimeoutException: Accept timed out. – priya Mar 04 '19 at 10:26
  • which jdk and spark version do you use? – priya Mar 04 '19 at 10:38
  • 2.4.0, JDK 8 (the latest version supported by Apache Spark at the moment). – user10938362 Mar 04 '19 at 11:34
  • I prefer to use Explode and join method. Cartesian product and broadcasting will be too expensive for me. Thanks for all the explanations!! – priya Mar 05 '19 at 20:03
  • But the code does not work when items and transactions have duplicate entries, e.g. transactions = [1, 2, 1], items = [1, 2, 1], even though the items [1, 2, 1] are present in transactions. – priya Mar 27 '19 at 11:35
  • in broadcast method, can we use df1 as a local variable instead of df2 and achieve the result? – priya Apr 30 '19 at 11:03
2

Spark 3.0+ has one more option, using forall:

F.expr("forall(look_for, x -> array_contains(look_in, x))")

Syntax alternative for Spark 3.1+: F.forall('look_for', lambda x: F.array_contains('look_in', x))
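
A minimal sketch of how this could plug into the original score computation (assuming the df1 and df2 dataframes created in the other answer, and Spark 3.1+):

from pyspark.sql import functions as F

(df1
    .crossJoin(df2)
    .groupBy("id", "transactions")
    # keep cost only when every element of items is found in transactions
    .agg(F.sum(F.when(
        F.forall("items", lambda x: F.array_contains("transactions", x)),
        F.col("cost")
    )).alias("score"))
    .show())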


Comparing this with the Spark 2.4 option based on array_intersect:

F.size(F.array_intersect('look_for', 'look_in')) == F.size('look_for')

They differ in dealing with duplicates and null values.

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(['a', 'b', 'c'], ['a']),
     (['a', 'b', 'c'], ['d']),
     (['a', 'b', 'c'], ['a', 'b']),
     (['a', 'b', 'c'], ['c', 'd']),
     (['a', 'b', 'c'], ['a', 'b', 'c']),
     (['a', 'b', 'c'], ['a', None]),
     (['a', 'b', None], ['a', None]),
     (['a', 'b', None], ['a']),
     (['a', 'b', None], [None]),
     (['a', 'b', 'c'], None),
     (None, ['a']),
     (None, None),
     (['a', 'b', 'c'], ['a', 'a']),
     (['a', 'a', 'a'], ['a']),
     (['a', 'a', 'a'], ['a', 'a', 'a']),
     (['a', 'a', 'a'], ['a', 'a', None]),
     (['a', 'a', None], ['a', 'a', 'a']),
     (['a', 'a', None], ['a', 'a', None])],
    ['look_in', 'look_for'])
df = df.withColumn('spark_3_0', F.expr("forall(look_for, x -> array_contains(look_in, x))"))
df = df.withColumn('spark_2_4', F.size(F.array_intersect('look_for', 'look_in')) == F.size('look_for'))

[screenshot of the resulting dataframe, showing the spark_3_0 and spark_2_4 results for each test case]

Removing nulls from inside the arrays may be useful in some cases; it's most easily done using array_compact (Spark 3.4+).
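
For example, a minimal sketch (assuming the same df as above and Spark 3.4+):

# strip nulls from inside both arrays before doing the containment check
df_clean = (df
    .withColumn('look_in', F.array_compact('look_in'))
    .withColumn('look_for', F.array_compact('look_for')))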

ZygD