3

I have the following sample dataframe:

fruit_list = ['apple', 'apple', 'orange', 'apple']
qty_list = [16, 2, 3, 1]
spark_df = spark.createDataFrame([(101, 'Mark', fruit_list, qty_list)], ['ID', 'name', 'fruit', 'qty'])

and I would like to create another column containing a result similar to what I would get with a pandas groupby('fruit').sum():

        qty
fruits     
apple    19
orange    3

The above result could be stored in the new column in any form (either a string, dictionary, list of tuples...).
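
For clarity, the pandas computation I have in mind is roughly the following (just to illustrate the intent, not something I expect to run directly on the Spark arrays):

import pandas as pd

# build a small pandas frame from the same lists and group-sum it
pdf = pd.DataFrame({'fruits': fruit_list, 'qty': qty_list})
print(pdf.groupby('fruits').sum())
#         qty
# fruits
# apple    19
# orange    3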

I've tried an approach similar to the following one, which does not work:

sum_cols = udf(lambda x: pd.DataFrame({'fruits': x[0], 'qty': x[1]}).groupby('fruits').sum())
spark_df.withColumn('Result', sum_cols(F.struct('fruit', 'qty'))).show()

One example of the resulting dataframe could be:

+---+----+--------------------+-------------+-------------------------+
| ID|name|               fruit|          qty|                   Result|
+---+----+--------------------+-------------+-------------------------+
|101|Mark|[apple, apple, or...|[16, 2, 3, 1]|[(apple,19), (orange,3)] |
+---+----+--------------------+-------------+-------------------------+

Do you have any suggestions on how I could achieve that?

Thanks

Edit: running on Spark 2.4.3

crash
  • What is your desired output? It's unclear from the description, please show it explicitly. – pault Jul 31 '19 at 13:30
  • thanks for your comment, done! – crash Jul 31 '19 at 13:37
  • 1
    What version of spark? If it's spark 2.4+ you can use [`array_zip`](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.arrays_zip). Older versions make this a little more difficult. – pault Jul 31 '19 at 13:40
  • I'm running on 2.4.3 could you kindly provide me with an example usage for that in my case? – crash Jul 31 '19 at 13:52
  • In my (limited) experience, I've seen "native" pyspark code perform 10x faster than UDFs (and especially UDAFs), even when an `explode` was being used. Just something to keep in mind.. – Marco Aug 01 '19 at 21:38

3 Answers

4

As @pault mentioned, as of Spark 2.4+ you can use Spark SQL built-in functions to handle your task. Here is one way with array_distinct + transform + aggregate:

from pyspark.sql.functions import expr

# set up data
spark_df = spark.createDataFrame([
        (101, 'Mark', ['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1])
      , (102, 'Twin', ['apple', 'banana', 'avocado', 'banana', 'avocado'], [5, 2, 11, 3, 1])
      , (103, 'Smith', ['avocado'], [10])
    ], ['ID', 'name', 'fruit', 'qty']
)

>>> spark_df.show(5,0)
+---+-----+-----------------------------------------+----------------+
|ID |name |fruit                                    |qty             |
+---+-----+-----------------------------------------+----------------+
|101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|
|103|Smith|[avocado]                                |[10]            |
+---+-----+-----------------------------------------+----------------+

>>> spark_df.printSchema()
root
 |-- ID: long (nullable = true)
 |-- name: string (nullable = true)
 |-- fruit: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- qty: array (nullable = true)
 |    |-- element: long (containsNull = true)

Set up the SQL statement:

stmt = '''
    transform(array_distinct(fruit), x -> (x, aggregate(
          transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
        , 0
        , (y,z) -> int(y + z)
    ))) AS sum_fruit
'''

>>> spark_df.withColumn('sum_fruit', expr(stmt)).show(10,0)
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|ID |name |fruit                                    |qty             |sum_fruit                               |
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |[[apple, 19], [orange, 3]]              |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|[[apple, 5], [banana, 5], [avocado, 12]]|
|103|Smith|[avocado]                                |[10]            |[[avocado, 10]]                         |
+---+-----+-----------------------------------------+----------------+----------------------------------------+

Explanation:

  1. Use array_distinct(fruit) to find all distinct entries in the array fruit
  2. transform this new array, mapping each element x to a struct (x, aggregate(..x..))
  3. the aggregate(..x..) above takes the simple form of summing up all the elements in array_T:

    aggregate(array_T, 0, (y,z) -> y + z)
    

    where the array_T is from the following transformation:

    transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
    

    which iterates through the array fruit: if fruit[i] = x, it returns the corresponding qty[i], otherwise it returns 0. For example, for ID=101, when x = 'orange', it returns the array [0, 0, 3, 0] (see the small pure-Python sketch below).
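
For illustration, here is a rough pure-Python sketch of the same logic for ID=101 (just to mirror what the SQL expression does, not Spark code):

fruit = ['apple', 'apple', 'orange', 'apple']
qty = [16, 2, 3, 1]

result = []
for x in dict.fromkeys(fruit):          # mimics array_distinct(fruit), keeps first-seen order
    # transform(sequence(0, size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
    array_T = [qty[i] if fruit[i] == x else 0 for i in range(len(fruit))]
    # aggregate(array_T, 0, (y, z) -> y + z)
    result.append((x, sum(array_T)))

print(result)   # [('apple', 19), ('orange', 3)]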

jxc
3

There may be a fancy way to do this using only the API functions on Spark 2.4+, perhaps with some combination of arrays_zip and aggregate, but I can't think of any that don't involve an explode step followed by a groupBy. With that in mind, using a udf may actually be better for you in this case.

I think creating a pandas DataFrame just for the purpose of calling .groupby().sum() is overkill. Furthermore, even if you did do it that way, you'd need to convert the final output to a different data structure because a udf can't return a pandas DataFrame.

Here's one way with a udf using collections.defaultdict:

from collections import defaultdict

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

def sum_cols_func(frt, qty):
    # accumulate the total qty for each fruit
    d = defaultdict(int)
    # cast each qty to a Python int (they arrive as strings when packed via array() below)
    for x, y in zip(frt, map(int, qty)):
        d[x] += y
    # return a plain list of (fruit, qty) tuples so the result matches the declared ArrayType
    return list(d.items())

sum_cols = udf(
    lambda x: sum_cols_func(*x),
    ArrayType(
        StructType([StructField("fruit", StringType()), StructField("qty", IntegerType())])
    )
)

Then call this by passing in the fruit and qty columns:

from pyspark.sql.functions import array, col

spark_df.withColumn(
    "Result",
    sum_cols(array([col("fruit"), col("qty")]))
).show(truncate=False)
#+---+----+-----------------------------+-------------+--------------------------+
#|ID |name|fruit                        |qty          |Result                    |
#+---+----+-----------------------------+-------------+--------------------------+
#|101|Mark|[apple, apple, orange, apple]|[16, 2, 3, 1]|[[orange, 3], [apple, 19]]|
#+---+----+-----------------------------+-------------+--------------------------+
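
An alternative sketch (untested here): instead of wrapping both columns in array, which appears to coerce the qty values to strings (hence the map(int, qty) above), the udf can take the two columns as separate arguments:

from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

# same return type as above
result_type = ArrayType(
    StructType([StructField("fruit", StringType()), StructField("qty", IntegerType())])
)

# two-argument udf: fruit and qty arrive as separate Python lists
sum_cols_two_args = udf(sum_cols_func, result_type)

spark_df.withColumn(
    "Result",
    sum_cols_two_args(col("fruit"), col("qty"))
).show(truncate=False)
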
pault
  • I like your solution pault, thanks for your time. However I'm getting this error `Py4JJavaError: An error occurred while calling o3564.showString....Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)`.. does it tell you anything? – crash Jul 31 '19 at 15:01
  • just added `from pyspark.sql.types import ArrayType, StructType, StringType, IntegerType, StructField` – crash Jul 31 '19 at 15:05
  • 1
    Also try forcing Python int type as well in the output `return [(k, int(sum(v))) for k, v in d.items()]`, so that way the result of the udf will be definitely of Python native type. – Richard Nemeth Jul 31 '19 at 15:09
  • @RichardNemeth that's a great point- which makes me wonder: are you sure you're using `__builtin__.sum` and not `numpy.sum` or `pyspark.sql.functions.sum`? [Why you shouldn't use `import *`](https://stackoverflow.com/a/55711135/5858851). **Edit** If you're getting a numpy object, it suggests that you're using `numpy.sum`. – pault Jul 31 '19 at 15:11
  • 1
    yep Richard's solution seemed to solve the problem! – crash Jul 31 '19 at 15:13
  • @crash I updated the answer to using a `defaultdict` of `int` instead of a `list`. that way you can do the sum inside the loop and avoid calling `sum` at the end. This should also fix your issue! – pault Jul 31 '19 at 15:19
  • That's also cleaner! Thanks a lot @pault, this was extremely interesting and formative. – crash Jul 31 '19 at 15:35
1

If you have Spark < 2.4, use the following to explode (otherwise check this answer):

df_split = (
    spark_df.rdd
    .flatMap(lambda row: [(row.ID, row.name, f, q) for f, q in zip(row.fruit, row.qty)])
    .toDF(["ID", "name", "fruit", "qty"])
)

df_split.show()

Output:

+---+----+------+---+
| ID|name| fruit|qty|
+---+----+------+---+
|101|Mark| apple| 16|
|101|Mark| apple|  2|
|101|Mark|orange|  3|
|101|Mark| apple|  1|
+---+----+------+---+
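
On Spark 2.4+, the same df_split can also be built without dropping to the RDD API, for example with arrays_zip + explode (a sketch; it assumes the zipped struct keeps the input column names as its field names, as in the 2.4 docs):

from pyspark.sql import functions as F

df_split = (
    spark_df
    # zip the two arrays element-wise and explode one row per (fruit, qty) pair
    .select('ID', 'name', F.explode(F.arrays_zip('fruit', 'qty')).alias('z'))
    .select('ID', 'name', F.col('z.fruit').alias('fruit'), F.col('z.qty').alias('qty'))
)
df_split.show()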

Then prepare the result you want. First find the aggregated dataframe:

df_aggregated = df_split.groupby('ID', 'fruit').agg(F.sum('qty').alias('qty'))
df_aggregated.show()

Output:

+---+------+---+
| ID| fruit|qty|
+---+------+---+
|101|orange|  3|
|101| apple| 19|
+---+------+---+

And finally change it to the desired format:

df_aggregated.groupby('ID').agg(
    F.collect_list(F.struct(F.col('fruit'), F.col('qty'))).alias('Result')
).show(truncate=False)

Output:

+---+--------------------------+
|ID |Result                    |
+---+--------------------------+
|101|[[orange, 3], [apple, 19]]|
+---+--------------------------+
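
If you also want to keep the original columns next to Result, as in the example output in the question, one possible follow-up (a sketch on top of the steps above) is to join the aggregated result back onto spark_df by ID:

df_result = df_aggregated.groupby('ID').agg(
    F.collect_list(F.struct(F.col('fruit'), F.col('qty'))).alias('Result')
)

# left join keeps any ID that happened to have no fruit/qty entries
spark_df.join(df_result, on='ID', how='left').show(truncate=False)
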
Ala Tarighati