I am new to Spark and need help solving the following problem. I have data like this:

Country value
India   [1,2,3,4,5]
US  [8,9,10,11,12]
US  [7,6,5,4,3]
India   [8,7,6,5,4]

The required output is the element-wise sum of the vectors for each country, computed in Spark, as shown below.

Output:
Country value
India   [9,9,9,9,9]
US  [15,15,15,15,15]    
  • And... what's the problem? – FZs Mar 30 '19 at 20:00
  • required solution or some guidance to get the desired output in spark. – tanmay verma Mar 30 '19 at 20:20
  • Hi, welcome to Stack Overflow. Please edit the question to show the source code you have so far, and to explain where you are getting stuck. If you genuinely have no idea where to start, please identify the documentation/articles/examples you have looked at so far and why these did not help. The SO community expects you to have a go, or at least to do some research, before coming here. Thanks. – MandyShaw Mar 30 '19 at 20:33
  • This could help you to get the concept https://stackoverflow.com/questions/54354915/pyspark-aggregate-sum-vector-element-wise – cph_sto Mar 30 '19 at 22:07
  • Do you always have 5 elements in your arrays? Or at least the same size of array everywhere? – Oli Mar 30 '19 at 22:59
  • @Oli: Actually, this is a problem I got in an interview and I have no idea how to solve it. Secondly, the element count of the arrays may vary, but all the arrays must have the same count. – tanmay verma Mar 31 '19 at 05:29
  • @cph_sto :Thanks for sharing such a useful article.I will check it and get back with my code. – tanmay verma Mar 31 '19 at 05:33

1 Answer

AFAIK, Spark does not provide a built-in aggregation function that sums arrays element-wise across rows. Therefore, if all the arrays have the same size, you can create one column per array element, aggregate each column, and then recreate the array.

In a generic way, this could go as follows:

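# `functions` is needed for functions.array in the last step
from pyspark.sql import functions
# note: this `sum` is Spark's column aggregate and shadows Python's built-in sum
from pyspark.sql.functions import col, sum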

# first, let's get the size of the array
size = len(df.first()['value'])

# Then, summing each element separately:
aggregation = df.groupBy("country")\
    .agg(*[sum(df.value.getItem(i)).alias("v"+str(i)) for i in range(size)])
aggregation.show()
+-------+---+---+---+---+---+                                                   
|country| v0| v1| v2| v3| v4|
+-------+---+---+---+---+---+
|  India|  9|  9|  9|  9|  9|
|     US| 15| 15| 15| 15| 15|
+-------+---+---+---+---+---+


# Finally, we recreate the array
result = aggregation.select(col("country"),
    functions.array(*[col("v"+str(i)) for i in range(size)]).alias("value"))
result.show()
+-------+--------------------+
|country|               value|
+-------+--------------------+
|  India|     [9, 9, 9, 9, 9]|
|     US|[15, 15, 15, 15, 15]|
+-------+--------------------+
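To see what this aggregation computes without needing a Spark session, here is a minimal plain-Python sketch run on the sample data from the question. It is not Spark code: the helper name `sum_by_country` is hypothetical, and it assumes, as stated in the comments on the question, that all arrays have the same length.

```python
from collections import defaultdict


def sum_by_country(rows):
    """Element-wise sum of the vectors for each country (hypothetical helper)."""
    # Group the vectors by country, like groupBy("country").
    grouped = defaultdict(list)
    for country, values in rows:
        grouped[country].append(values)

    result = {}
    for country, vectors in grouped.items():
        # Assumption: every vector of a country has the same length.
        if len({len(v) for v in vectors}) != 1:
            raise ValueError(f"vectors for {country!r} differ in length")
        # zip(*vectors) puts the i-th elements together, like the v0..v4 columns.
        result[country] = [sum(column) for column in zip(*vectors)]
    return result


rows = [
    ("India", [1, 2, 3, 4, 5]),
    ("US", [8, 9, 10, 11, 12]),
    ("US", [7, 6, 5, 4, 3]),
    ("India", [8, 7, 6, 5, 4]),
]
totals = sum_by_country(rows)
print(totals)
# {'India': [9, 9, 9, 9, 9], 'US': [15, 15, 15, 15, 15]}
```

The `zip(*vectors)` step plays the same role as splitting the array into one column per index, and the list comprehension plays the role of rebuilding the array after the per-column sums.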
Oli