I have a dataframe as shown below:

+-----+------------------------+
|Index|   finalArray           |
+-----+------------------------+
|1    |[0, 2, 0, 3, 1, 4, 2, 7]|
|2    |[0, 4, 4, 3, 4, 2, 2, 5]|
+-----+------------------------+

I want to split the array into chunks of 2, sum each chunk, and store the resulting array in the column finalArray. The result should look like this:

+-----+---------------------+
|Index|    finalArray       |
+-----+---------------------+
|1    |[2, 3, 5, 9]         |
|2    |[4, 7, 6, 7]         |
+-----+---------------------+

I can do this with a UDF, but I am looking for a better, more optimised way. Preferably I would handle it with withColumn, passing flagArray, without having to write a UDF.

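from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType

@udf(ArrayType(DoubleType()))
def aggregate(finalArray, chunkSize):
   n = int(chunkSize)
   aggsum = []
   # split the array into consecutive chunks of n elements
   final = [finalArray[i * n:(i + 1) * n] for i in range((len(finalArray) + n - 1) // n)]
   for item in final:
      agg = 0
      for j in item:
         agg += j
      # append once per chunk, after the whole chunk has been summed
      aggsum.append(agg)
   return aggsum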

I was not able to use the expression below in the UDF, hence the loops:

[sum(finalArray[x:x+chunkSize]) for x in range(0, len(finalArray), chunkSize)]
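
For reference, the intended per-row computation can be written in plain Python. This is a minimal sketch, not the asker's code: the name chunk_sums is hypothetical, and calling builtins.sum explicitly rests on an assumption about why the comprehension failed. A wildcard import from pyspark.sql.functions shadows Python's built-in sum, which is a common cause of this symptom, but the question does not show its imports.

```python
import builtins


def chunk_sums(values, chunk_size):
    """Sum consecutive chunks of `chunk_size` elements (hypothetical helper)."""
    n = int(chunk_size)
    if n <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # builtins.sum is used explicitly so that a shadowed `sum` name
    # (for example from `from pyspark.sql.functions import *`) cannot interfere.
    return [builtins.sum(values[start:start + n]) for start in range(0, len(values), n)]


print(chunk_sums([0, 2, 0, 3, 1, 4, 2, 7], 2))  # [2, 3, 5, 9]
print(chunk_sums([0, 4, 4, 3, 4, 2, 2, 5], 3))  # [8, 9, 7]
```

A function like this is handy as a reference when checking the output of a UDF-free Spark expression against a few sample rows.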

2 Answers

For Spark 2.4+, you can try sequence + transform:

from pyspark.sql.functions import expr

df = spark.createDataFrame([
  (1, [0, 2, 0, 3, 1, 4, 2, 7]),
  (2, [0, 4, 4, 3, 4, 2, 2, 5])
], ["Index", "finalArray"])

df.withColumn("finalArray", expr("""
    transform(
      sequence(0,ceil(size(finalArray)/2)-1), 
      i -> finalArray[2*i] + ifnull(finalArray[2*i+1],0))
 """)).show(truncate=False)
+-----+------------+
|Index|finalArray  |
+-----+------------+
|1    |[2, 3, 5, 9]|
|2    |[4, 7, 6, 7]|
+-----+------------+

For an arbitrary chunk size N, use the aggregate function to compute the sub-totals:

N = 3

sql_expr = """
    transform(
      /* create a sequence from 0 to number_of_chunks-1 */
      sequence(0,ceil(size(finalArray)/{0})-1),
      /* iterate the above sequence */
      i -> 
        /* create a sequence from 0 to chunk_size-1 and
           sum the chunk_size items of chunk i, addressed by their indices
         */
        aggregate(
          sequence(0,{0}-1),
          0L, 
          (acc, y) -> acc + ifnull(finalArray[i*{0}+y],0)
        )
    )
"""
df.withColumn("finalArray", expr(sql_expr.format(N))).show()
+-----+----------+
|Index|finalArray|
+-----+----------+
|    1| [2, 8, 9]|
|    2| [8, 9, 7]|
+-----+----------+
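
The index arithmetic in the expression above can be followed in plain Python. This is only a sketch of the logic, not Spark code: the helper names are hypothetical, it models non-empty arrays only, and it assumes that indexing past the end of an array yields null (modelled as None), which is what the ifnull call in the SQL relies on.

```python
import math


def element_or_none(values, index):
    """Mimic array indexing where an out-of-range index yields null (None)."""
    return values[index] if 0 <= index < len(values) else None


def ifnull(value, default):
    """Mimic SQL ifnull: replace a null (None) with the default."""
    return default if value is None else value


def chunk_sums_by_index(values, n):
    """Python model of transform(sequence(...), i -> aggregate(sequence(...), ...))."""
    number_of_chunks = math.ceil(len(values) / n)
    result = []
    for i in range(number_of_chunks):   # outer sequence: 0 .. number_of_chunks-1
        acc = 0                         # zero value of aggregate (0L in the SQL)
        for y in range(n):              # inner sequence: 0 .. chunk_size-1
            # indices past the end contribute 0, so a short last chunk still works
            acc += ifnull(element_or_none(values, i * n + y), 0)
        result.append(acc)
    return result


print(chunk_sums_by_index([0, 2, 0, 3, 1, 4, 2, 7], 3))  # [2, 8, 9]
```

With N = 3 and eight elements, the last chunk covers indices 6, 7 and 8; index 8 does not exist and adds 0, which gives the final 9 in the first row of the output above.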
  • Is there any way I can find the max value for each chunk using sequence + transform? I can do it using a UDF, but since my dataframe contains around 12-18 million rows I want to avoid UDFs as much as possible. Any help is appreciated – Saikat Apr 21 '20 at 03:36
  • check: `sql_expr = "transform(sequence(0,ceil(size(finalArray)/{0})-1), i -> array_max(slice(finalArray,i*{0}+1,{0})))".format(N)` – jxc Apr 21 '20 at 03:47
  • It works like a charm. Thanks a lot for your prompt and efficient response – Saikat Apr 21 '20 at 03:54
  • Done, I didn't realize I hadn't done that already. Thanks a ton again – Saikat Apr 21 '20 at 03:57
  • Is there a way to make your solution work if my array contains decimal values? In the above code 0L works with integers and 0D with doubles, but my array elements are decimals of precision (38,15), and after the sum I would like the value to keep the same type and precision if possible. It would also be great to have some documentation for a better understanding – Saikat Jun 15 '20 at 14:04
  • just change `0L` to `cast(0 as decimal(38,15))` – jxc Jun 15 '20 at 14:16
  • its throwing AnalysisException: "cannot resolve 'aggregate(sequence(0, (2 - 1)), CAST(0 AS DECIMAL(38,15)), My expression is ``` transform_expr = """transform(sequence(0,ceil(size(finalArray)/{0})-1),i ->aggregate(sequence(0,{0}-1),cast(0 as decimal(38,15)),(acc, y) -> acc + ifnull(finalArray[i*{0}+y],0)))""" ``` – Saikat Jun 15 '20 at 14:25
  • You can also try forcing the data type in the third argument of the aggregate function: `(acc, y) -> cast(acc + ifnull(finalArray[i*{0}+y],0) as decimal(38,15))` – jxc Jun 15 '20 at 14:49
  • thanks @jxc for your prompt response. Thanks a lot as always – Saikat Jun 15 '20 at 20:49
  • @jxc I have tried to contact you. I changed the Spark version to 2.4 for the question: https://stackoverflow.com/questions/64660047/how-to-use-windowing-functions-efficiently-to-decide-next-n-number-of-rows-based . Please undelete your answer; the code is running now. – Smith Nov 08 '20 at 14:13
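
The max-per-chunk expression from the comments swaps the index-based aggregate for slice plus array_max. Spark's slice takes a 1-based start position, which is why the comment uses i*{0}+1. The plain-Python sketch below models that offset; the helper names are hypothetical, and it covers only a positive start and non-empty arrays.

```python
import math


def spark_slice(values, start, length):
    """Model of slice(x, start, length) for a positive, 1-based start."""
    return values[start - 1:start - 1 + length]


def chunk_max(values, n):
    """Python model of transform(sequence(...), i -> array_max(slice(...)))."""
    number_of_chunks = math.ceil(len(values) / n)
    # chunk i starts at 0-based offset i*n, which is 1-based position i*n + 1
    return [max(spark_slice(values, i * n + 1, n)) for i in range(number_of_chunks)]


print(chunk_max([0, 2, 0, 3, 1, 4, 2, 7], 3))  # [2, 4, 7]
```

Because slice simply returns fewer elements when the array ends early, the short last chunk needs no ifnull handling in this variant.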

Here is a slightly different version of @jxc's solution, using the slice function together with transform and aggregate.

The logic: for each element of the array, we check whether its index is a multiple of the chunk size and, if so, use slice to get a sub-array of chunk size starting at that element. With aggregate we sum the elements of each sub-array. Finally, filter removes the nulls (which correspond to indexes that do not satisfy i % chunk = 0).

chunk = 2

transform_expr = f"""
filter(transform(finalArray, 
                 (x, i) -> IF (i % {chunk} = 0, 
                               aggregate(slice(finalArray, i+1, {chunk}), 0L, (acc, y) -> acc + y),
                               null
                              )
                ),
      x -> x is not null)
"""

df.withColumn("finalArray", expr(transform_expr)).show()

#+-----+------------+
#|Index|  finalArray|
#+-----+------------+
#|    1|[2, 3, 5, 9]|
#|    2|[4, 7, 6, 7]|
#+-----+------------+
  • Is there a way to make your solution work if my array contains decimal values? In the above code 0L works with integers and 0D with doubles, but my array elements are decimals of precision (38,15), and after the sum I would like the value to keep the same type and precision if possible. It would also be great to have some documentation for a better understanding – Saikat Jun 15 '20 at 14:05
  • @Saikat use `CAST(0 AS DECIMAL(38,15))` as the zero-value instead of `0L`. And make sure to change `acc + y` to `CAST(acc + y AS DECIMAL(38,15))` in the third argument of aggregate function, to avoid type mismatch errors. – blackbishop Jun 15 '20 at 14:31
  • thanks @blackbishop for your prompt response. Thanks a lot – Saikat Jun 15 '20 at 20:48
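
Putting the two changes from the comment above together, the decimal variant of the expression could be assembled as follows. This is a sketch only: the helper name decimal_chunk_sum_expr and its parameters are hypothetical, the SQL text is taken from the answer and the comment, and the resulting string has not been run against a Spark session here.

```python
def decimal_chunk_sum_expr(column, chunk, precision=38, scale=15):
    """Build the slice-based chunk-sum expression with a decimal accumulator.

    Hypothetical helper: only the SQL text comes from the answer and comments.
    """
    dec = f"DECIMAL({precision},{scale})"
    return (
        f"filter(transform({column}, (x, i) -> IF(i % {chunk} = 0, "
        # zero value and merge result are both cast, so their types match
        f"aggregate(slice({column}, i+1, {chunk}), CAST(0 AS {dec}), "
        f"(acc, y) -> CAST(acc + y AS {dec})), null)), "
        f"x -> x is not null)"
    )


print(decimal_chunk_sum_expr("finalArray", 2))
```

The string would then be passed to expr in the same way as transform_expr in the answer, for example df.withColumn("finalArray", expr(decimal_chunk_sum_expr("finalArray", 2))).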