
I am using the collect_set function over a window on a DataFrame to add 3 columns.

My df is as below:

id  acc_no  acc_name  cust_id
1   111     ABC       88
1   222     XYZ       99

Below is the code snippet:

from pyspark.sql import Window
import pyspark.sql.functions as F

w = Window.partitionBy('id').orderBy('acc_no')
df1 = df.withColumn(
    'cust_id_new',
    F.collect_set('cust_id').over(w)
).withColumn(
    'acc_no_new',
    F.collect_set('acc_no').over(w)
).withColumn(
    'acc_name_new',
    F.collect_set('acc_name').over(w)
).drop('cust_id', 'acc_no', 'acc_name')

In this case, my output is as follows:

id  acc_no     acc_name   cust_id
1   [111,222]  [XYZ,ABC]  [88,99]

So here, acc_no and cust_id are correct, but the order of acc_name is wrong: acc_no 111 corresponds to acc_name ABC, but we are getting XYZ in that position.

Can someone please explain why this is happening and what the solution would be?

I suspect this issue occurs only for string columns, but I might be wrong. Please help.

This is similar to the thread below, but I am still getting an error.

How to maintain sort order in PySpark collect_list and collect multiple lists
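A plain-Python model may help with the "why": three separate collect_set calls deduplicate (and order) each column independently, so nothing ties the n-th element of one array to the n-th element of another, whatever the column type. The rows below are hypothetical, chosen so that two accounts share a customer; in that case the per-column version even produces arrays of different lengths. This only models the semantics and is not how Spark implements collect_set.

```python
# Plain-Python model (not Spark code). Hypothetical rows for a single id,
# as (acc_no, acc_name, cust_id); both accounts share customer 88.
rows = [
    ("111", "ABC", "88"),
    ("222", "XYZ", "88"),
]

def dedupe_per_column(rows):
    """Deduplicate each column on its own, like three separate collect_set calls."""
    # dict.fromkeys keeps first-seen order so this model is deterministic;
    # Spark's collect_set makes no ordering promise at all.
    return [list(dict.fromkeys(col)) for col in zip(*rows)]

def dedupe_rows(rows):
    """Deduplicate whole rows first, then split into columns; pairs stay aligned."""
    unique_rows = list(dict.fromkeys(rows))
    return [list(col) for col in zip(*unique_rows)]

print(dedupe_per_column(rows))  # [['111', '222'], ['ABC', 'XYZ'], ['88']]
print(dedupe_rows(rows))        # [['111', '222'], ['ABC', 'XYZ'], ['88', '88']]
```

Deduplicating whole rows (in Spark, collecting a struct of all the fields) is what keeps the columns paired.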

pault
Suyog
  • `set`s are unordered by nature. – pissall Oct 28 '19 at 16:35
  • What version of PySpark? In 2.4+ you might be able to use `collect_list` with [`array_distinct`](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.array_distinct). Or zip the arrays before sorting. – pault Oct 28 '19 at 17:01
  • I am using Spark 2.3. – Suyog Oct 29 '19 at 05:54
  • One thing I would like to mention is that I have cast all the columns to string for a specific reason. – Suyog Oct 29 '19 at 07:17
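The "zip the arrays before sorting" suggestion can be modelled in plain Python as below. The input lists are hypothetical and are assumed to be aligned by position already (collected from the same rows in the same order, as collect_list would give). In Spark 2.4+ the steps would presumably correspond to `arrays_zip`, `array_distinct` and `sort_array`; none of those are used here, and the asker is on 2.3.

```python
# Plain-Python model of "zip, dedupe, sort, unzip" (not Spark code).
# Hypothetical per-column arrays, assumed aligned by position.
acc_no = ["222", "111", "222"]
acc_name = ["XYZ", "ABC", "XYZ"]
cust_id = ["99", "88", "99"]

def zip_dedupe_sort(*arrays):
    """Zip aligned arrays, drop duplicate tuples, sort, and unzip again."""
    zipped = list(zip(*arrays))            # one tuple per original row
    unique = list(dict.fromkeys(zipped))   # plays the role of array_distinct
    unique.sort()                          # tuples compare field by field, first field first
    return [list(col) for col in zip(*unique)]

print(zip_dedupe_sort(acc_no, acc_name, cust_id))
# [['111', '222'], ['ABC', 'XYZ'], ['88', '99']]
```

Because the columns were cast to string, the sort is lexicographic ("1000" sorts before "222"), so the leading sort field may need to be numeric if numeric order matters.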

1 Answer


We can use the row_number function within each id partition, then collect whole rows as structs with collect_list and order them with sort_array, so that every field keeps the same order.

from pyspark.sql import Window
import pyspark.sql.functions as F

# Number the rows within each id, ordered by cust_id
w = Window.partitionBy('id').orderBy('cust_id')

df1 = (
    df.withColumn('rn', F.row_number().over(w))
    # Collect each row as one struct so its fields stay together;
    # sort_array orders the structs by their first field, rn
    .groupBy('id')
    .agg(F.sort_array(F.collect_list(F.struct('rn', 'acc_no', 'acc_name', 'cust_id'))).alias('data'))
    # Extract each field as an array; all arrays share the same order
    .withColumn('grp_acc_no', F.col('data.acc_no'))
    .withColumn('grp_acc_name', F.col('data.acc_name'))
    .withColumn('grp_cust_id', F.col('data.cust_id'))
    .drop('data')
)
df1.show(truncate=False)
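To sanity-check what the answer's pipeline should return for the question's sample df, here is a plain-Python model of the same steps (row_number per id ordered by cust_id, collect_list of structs, sort_array, then field extraction). It is only an illustration under those assumptions, not Spark code; `group_aligned` is a hypothetical helper name.

```python
from collections import defaultdict

# Rows from the question's df, as (id, acc_no, acc_name, cust_id).
rows = [
    (1, "111", "ABC", "88"),
    (1, "222", "XYZ", "99"),
]

def group_aligned(rows):
    """Hypothetical helper modelling row_number + sort_array(collect_list(struct))."""
    by_id = defaultdict(list)
    for id_, acc_no, acc_name, cust_id in rows:
        by_id[id_].append((acc_no, acc_name, cust_id))

    result = {}
    for id_, group in by_id.items():
        # row_number() over (partition by id order by cust_id)
        ordered = sorted(group, key=lambda r: r[2])
        structs = [(rn, *r) for rn, r in enumerate(ordered, start=1)]
        # sort_array on the structs: rn is compared first and is unique per id
        structs.sort()
        # data.acc_no, data.acc_name, data.cust_id
        result[id_] = {
            "grp_acc_no": [s[1] for s in structs],
            "grp_acc_name": [s[2] for s in structs],
            "grp_cust_id": [s[3] for s in structs],
        }
    return result

print(group_aligned(rows))
```

For this data, id 1 should get grp_acc_no ['111', '222'], grp_acc_name ['ABC', 'XYZ'] and grp_cust_id ['88', '99'], with ABC correctly paired with 111.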