90

I am trying to create a new column of lists in Pyspark using a groupby aggregation on an existing set of columns. An example input data frame is provided below:

------------------------
id | date        | value
------------------------
1  |2014-01-03   | 10 
1  |2014-01-04   | 5
1  |2014-01-05   | 15
1  |2014-01-06   | 20
2  |2014-02-10   | 100   
2  |2014-03-11   | 500
2  |2014-04-15   | 1500

The expected output is:

id | value_list
------------------------
1  | [10, 5, 15, 20]
2  | [100, 500, 1500]

The values within a list are sorted by the date.

I tried using collect_list as follows:

from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

But collect_list doesn't guarantee order even if I sort the input data frame by date before aggregation.

Could someone help on how to do aggregation by preserving the order based on a second (date) variable?

mtoto
Ravi

10 Answers

136
from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('id').orderBy('date')

sorted_list_df = input_df.withColumn(
            'sorted_list', F.collect_list('value').over(w)
        )\
        .groupBy('id')\
        .agg(F.max('sorted_list').alias('sorted_list'))

Window examples provided by users often don't really explain what is going on so let me dissect it for you.

As you know, using collect_list together with groupBy will result in an unordered list of values. This is because depending on how your data is partitioned, Spark will append values to your list as soon as it finds a row in the group. The order then depends on how Spark plans your aggregation over the executors.

A Window function allows you to control that situation, grouping rows by a certain value so you can perform an operation over each of the resultant groups:

w = Window.partitionBy('id').orderBy('date')
  • partitionBy - you want groups/partitions of rows with the same id
  • orderBy - you want each row in the group to be sorted by date

Once you have defined the scope of your Window - "rows with the same id, sorted by date" - you can use it to perform an operation over it, in this case a collect_list:

F.collect_list('value').over(w)

At this point you created a new column sorted_list with an ordered list of values, sorted by date, but you still have duplicated rows per id. To trim out the duplicated rows you want to groupBy id and keep the max value for each group:

.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
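
(For illustration, a minimal sketch of what the intermediate frame looks like on the question's sample data before the final groupBy collapses it, reusing input_df and w from above; the rows for id=1 are shown as comments.)

# Each row carries the list built up to that row within its window,
# so id=1 still appears four times before the groupBy.
input_df.withColumn('sorted_list', F.collect_list('value').over(w)).show()
# +---+----------+-----+---------------+
# | id|      date|value|    sorted_list|
# +---+----------+-----+---------------+
# |  1|2014-01-03|   10|           [10]|
# |  1|2014-01-04|    5|        [10, 5]|
# |  1|2014-01-05|   15|    [10, 5, 15]|
# |  1|2014-01-06|   20|[10, 5, 15, 20]|
# |...|       ...|  ...|            ...|
# +---+----------+-----+---------------+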
TMichel
  • 27
    This should be the accepted answer due to the usage of Spark-basic functions - Very nice! – Markus Apr 04 '19 at 11:15
  • How does `max` work in this context where the values are arrays? – nciao Sep 24 '19 at 14:35
  • @nciao do you mean on an `ArrayType` column? My guess is it works exactly the same way as on `ArrayType<*>`, as `max` by documentation _"returns the maximum value of the expression in a group"_. – TMichel Sep 26 '19 at 07:48
  • I meant on the Sorted_List column, which would be of type ArrayType. (my understanding is that you’re telling spark to aggregate the group, by taking the maximum value of that particular array in each group). – nciao Sep 26 '19 at 11:01
  • which led me to wonder "well, what is `max` for an array?" (in my particular case) – nciao Sep 26 '19 at 11:02
  • 10
    The max is needed, because for the same "id", a list is created for each row, in the sorted order: [10], then [10, 5], then [10, 5, 15], then [10, 5, 15, 20] for id=1. Taking the max of lists takes the longest one (here [10, 5, 15, 20]). – CharlesG Oct 17 '19 at 09:58
  • 7
    What are the memory implications of this? Is this approach better than the accepted answer when we are dealing with chaining of billion+ events when a chain can have up to 10.000 items in collected list? – Hedrack Dec 03 '19 at 19:17
  • 5
    Isn't this expensive? If I have 10 million groups, each with 24 elements, `F.collect_list('value').over(w)` would create a new list column of sizes 1 to 24, 10 million * 24 times, and then another group by just gets the largest row from each group. – Mithril Apr 24 '20 at 07:23
  • 3
    This doesn't work if you're using `collect_set` instead of `collect_list`. – Steve May 05 '20 at 20:42
  • 2
    This doesn't seem feasible if a single id has large numbers of rows. If id=1 has n rows of say long data types, then before the groupby you'll need to store (n^2)/2 longs, e.g., if n=10**7 then you'll need 400 terabytes for that column alone, and if n=10**8 you'd need 40 petabytes, etc – mathisfun Oct 21 '20 at 14:31
66

You can use the sort_array function. If you collect both dates and values as a list, you can sort the resulting column using sort_array and keep only the columns you require.

import pyspark.sql.functions as F

grouped_df = (
    input_df
    .groupby("id")
    .agg(
        F.sort_array(F.collect_list(F.struct("date", "value")))
        .alias("collected_list")
    )
    .withColumn("sorted_list", F.col("collected_list.value"))
    .drop("collected_list")
)

grouped_df.show(truncate=False)

+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+
koPytok
  • 13
    Thanks a lot. I find that Window.partitionBy followed by taking the max row does not perform well on large data. Your solution is ~200 times faster. – Phongsakorn Jul 29 '20 at 11:10
  • Yes, this is way faster in scala as well: grouped_df = input_df.groupBy("id").agg(sort_array(collect_list(struct("date", "value"))).alias("collected_list")).withColumn("sorted_list", col("collected_list.value")) .drop("collected_list") – mathisfun Oct 21 '20 at 15:07
  • 2
    I didn't know Spark understands this notion collected_list.value as an array of corresponding field values. Nice! – Alexander Pivovarov Feb 05 '21 at 21:18
51

If you collect both dates and values as a list, you can sort the resulting column according to date using a udf, and then keep only the values in the result.

import operator
import pyspark.sql.functions as F

# create list column
grouped_df = input_df.groupby("id") \
               .agg(F.collect_list(F.struct("date", "value")) \
               .alias("list_col"))

# define udf
def sorter(l):
  res = sorted(l, key=operator.itemgetter(0))
  return [item[1] for item in res]

sort_udf = F.udf(sorter)

# test
grouped_df.select("id", sort_udf("list_col") \
  .alias("sorted_list")) \
  .show(truncate = False)
+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+
abeboparebop
mtoto
  • Thanks for the detailed example... I just tried it on a bigger dataset of a few million rows and I am getting the exact same sequence as that of collect_list... Is there a way to explain why this could be happening? Also, I checked that collect_list only seems to mess up those cases with multiple values within a date... Does it mean collect_list also maintains the order? – Ravi Oct 05 '17 at 15:26
  • 1
    In your code, you sort the entire dataset before collect_list() so yes. But this is not necessary, it is more efficient to sort the resulting list of tuples after collecting both date and value in a list. – mtoto Oct 05 '17 at 15:38
  • Just to clarify...sorting the column and using collect_list on the sorted column would preserve the order? – Ravi Oct 05 '17 at 15:42
  • Depends how your data is partitioned across the nodes, but the udf method should always guarantee correct order because we sort per row. – mtoto Oct 05 '17 at 15:44
  • Can you shed some light on how partitioning affects the result in the first case? In my case, all the million+ records have the same sequence between the two methods. I didn't specify any special partitioning. – Ravi Oct 05 '17 at 15:46
  • 2
    Order in distributed systems is often meaningless, so correct order cannot be guaranteed unless the values for each id are in one partition. – mtoto Oct 05 '17 at 15:57
  • Is there a way to figure out how the data is partitioned? – Ravi Oct 05 '17 at 16:09
  • Shouldn't it be `sort_udf = F.udf(sorter, ArrayType(IntegerType()))` given that `udf` uses `StringType` as the default return type if not specified, and `sorter` returns a list of integers? – Narvarth Aug 06 '19 at 18:40
  • this fails if we use `collect_set` – Hardik Gupta Oct 10 '19 at 07:33
  • 1
    This answer is rather old by now, I think with the introduction of `array_sort` as the other answers describe, that is the best approach as it doesn't require the overhead of a UDF. – RvdV Jan 21 '21 at 13:41
22

The question was for PySpark, but it might be helpful to have this for Scala Spark as well.

Let's prepare a test dataframe:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{ Window, UserDefinedFunction}

import java.sql.Date
import java.time.LocalDate

val spark: SparkSession = ...

// Our test data set
val data: Seq[(Int, Date, Int)] = Seq(
  (1, Date.valueOf(LocalDate.parse("2014-01-03")), 10),
  (1, Date.valueOf(LocalDate.parse("2014-01-04")), 5),
  (1, Date.valueOf(LocalDate.parse("2014-01-05")), 15),
  (1, Date.valueOf(LocalDate.parse("2014-01-06")), 20),
  (2, Date.valueOf(LocalDate.parse("2014-02-10")), 100),
  (2, Date.valueOf(LocalDate.parse("2014-02-11")), 500),
  (2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500)
)

// Create dataframe
val df: DataFrame = spark.createDataFrame(data)
  .toDF("id", "date", "value")
df.show()
//+---+----------+-----+
//| id|      date|value|
//+---+----------+-----+
//|  1|2014-01-03|   10|
//|  1|2014-01-04|    5|
//|  1|2014-01-05|   15|
//|  1|2014-01-06|   20|
//|  2|2014-02-10|  100|
//|  2|2014-02-11|  500|
//|  2|2014-02-15| 1500|
//+---+----------+-----+

Use UDF

// Group by id and aggregate date and value to new column date_value
val grouped = df.groupBy(col("id"))
  .agg(collect_list(struct("date", "value")) as "date_value")
grouped.show()
grouped.printSchema()
// +---+--------------------+
// | id|          date_value|
// +---+--------------------+
// |  1|[[2014-01-03,10],...|
// |  2|[[2014-02-10,100]...|
// +---+--------------------+

// udf to extract data from Row, sort by needed column (date) and return value
val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {
  rows.map { case Row(date: Date, value: Int) => (date, value) }
    .sortBy { case (date, value) => date }
    .map { case (date, value) => value }
})

// Select id and value_list
val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list"))
r1.show()
// +---+----------------+
// | id|      value_list|
// +---+----------------+
// |  1| [10, 5, 15, 20]|
// |  2|[100, 500, 1500]|
// +---+----------------+

Use Window

val window = Window.partitionBy(col("id")).orderBy(col("date"))
val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window))
sortedDf.show()
//+---+----------+-----+---------------------+
//| id|      date|value|values_sorted_by_date|
//+---+----------+-----+---------------------+
//|  1|2014-01-03|   10|                 [10]|
//|  1|2014-01-04|    5|              [10, 5]|
//|  1|2014-01-05|   15|          [10, 5, 15]|
//|  1|2014-01-06|   20|      [10, 5, 15, 20]|
//|  2|2014-02-10|  100|                [100]|
//|  2|2014-02-11|  500|           [100, 500]|
//|  2|2014-02-15| 1500|     [100, 500, 1500]|
//+---+----------+-----+---------------------+

val r2 = sortedDf.groupBy(col("id"))
  .agg(max("values_sorted_by_date").as("value_list")) 
r2.show()
//+---+----------------+
//| id|      value_list|
//+---+----------------+
//|  1| [10, 5, 15, 20]|
//|  2|[100, 500, 1500]|
//+---+----------------+
Artavazd Balayan
  • is it possible to accomplish this without either a window or udf via combination of explode, group by, order by? – 219CID Sep 23 '21 at 23:21
7

To make sure the sort is done for each id, we can use sortWithinPartitions:

from pyspark.sql import functions as F

ordered_df = (
    input_df
    .repartition(input_df.id)
    .sortWithinPartitions(['date'])
)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))
ShadyStego
  • 7
    The group by step is happening after the sort. Will the sort order be retained in group by step? There is no such guarantee AFAIK – nish Sep 04 '18 at 09:40
4

In the Spark SQL world the answer to this would be:

SELECT
  id, max(list) AS value_list
FROM (
  SELECT
    id,
    COLLECT_LIST(value) OVER (PARTITION BY id ORDER BY date) AS list
  FROM browser_count
  GROUP BY id, value, date) t
GROUP BY id;
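
For reference, a minimal sketch of running the same idea from PySpark, assuming the question's input_df has been registered under the table name used above (the inner GROUP BY from the query above only matters if you need to de-duplicate identical rows first, so it is dropped here):

input_df.createOrReplaceTempView("browser_count")

spark.sql("""
    SELECT id, max(list) AS value_list
    FROM (
        SELECT id,
               COLLECT_LIST(value) OVER (PARTITION BY id ORDER BY date) AS list
        FROM browser_count
    ) t
    GROUP BY id
""").show()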
Fardin Abdi
2

I tried TMichel's approach and it didn't work for me. When I performed the max aggregation I wasn't getting back the longest list. So what worked for me is the following:

from pyspark.sql import functions as f
from pyspark.sql import Window

def max_n_values(df, key, col_name, number):
    '''
    Returns the max n values of a spark dataframe
    partitioned by the key and ranked by the col_name
    '''
    w2 = Window.partitionBy(key).orderBy(f.col(col_name).desc())
    output = df.select('*',
                       f.row_number().over(w2).alias('rank')).filter(
                           f.col('rank') <= number).drop('rank')
    return output

def col_list(df, key, col_to_collect, name, score):
    w = Window.partitionBy(key).orderBy(f.col(score).desc())

    list_df = df.withColumn(name, f.collect_set(col_to_collect).over(w))
    size_df = list_df.withColumn('size', f.size(name))
    output = max_n_values(df=size_df,
                               key=key,
                               col_name='size',
                               number=1)
    return output
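
A minimal usage sketch against the question's input_df (my own invocation, not part of the original answer; using date as the score column is just an assumption about how it would be called):

# Hypothetical call: collect values per id, rank rows by list size,
# and keep only the row holding the longest list for each id.
result = col_list(df=input_df,
                  key='id',
                  col_to_collect='value',
                  name='value_list',
                  score='date')

result.select('id', 'value_list').show(truncate=False)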
nvarelas
  • 1
    I think a little explanation of how this works for you, and of the difference from the accepted answer, would be useful – Sfili_81 Jan 09 '20 at 14:46
  • 1
    When I tried TMichel's approach the max value didn't work. I wasn't getting back the list with the most elements, I was getting back random lists. So what I did is I created a new column which measures the size and kept the highest value for each partition. Hope that makes sense! – nvarelas Jan 09 '20 at 15:04
2

As of Spark 2.4, the ArrayType column created by collect_list in @mtoto's answer can be post-processed with Spark SQL's builtin functions transform and array_sort (no need for a udf):

from pyspark.sql.functions import collect_list, expr, struct

df.groupby('id') \
  .agg(collect_list(struct('date','value')).alias('value_list')) \
  .withColumn('value_list', expr('transform(array_sort(value_list), x -> x.value)')) \
  .show()
+---+----------------+
| id|      value_list|
+---+----------------+
|  1| [10, 5, 15, 20]|
|  2|[100, 500, 1500]|
+---+----------------+ 

Note: if descending order is required change array_sort(value_list) to sort_array(value_list, False)
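
For the descending case mentioned in the note, only the sort call changes (a sketch using the same df and imports as above):

df.groupby('id') \
  .agg(collect_list(struct('date','value')).alias('value_list')) \
  .withColumn('value_list', expr('transform(sort_array(value_list, False), x -> x.value)')) \
  .show()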

Caveat: array_sort() and sort_array() won't work if items (in collect_list) must be sorted by multiple fields (columns) in a mixed order, e.g. orderBy('col1', desc('col2')).

jxc
2

If you want to use Spark SQL, here is how you can achieve this, assuming the table name (or temporary view) is temp_table:

select
  t1.id,
  collect_list(value) as value_list
from (select * from temp_table order by id, date) t1
group by 1
sushmit
0

Complementing what ShadyStego said, I've been testing the use of sortWithinPartitions and groupBy on Spark, and found that it performs considerably better than Window functions or a UDF. Still, there is an issue with a misordering once per partition when using this method, but it can easily be solved. I show it here: Spark (pySpark) groupBy misordering first element on collect_list.

This method is especially useful on large DataFrames, but a large number of partitions may be needed if you are short on driver memory.
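
A sketch of what I mean, essentially ShadyStego's snippet above with the partition count made explicit (the 400 is just an illustrative number; tune it to your data and memory):

from pyspark.sql import functions as F

ordered_df = (
    input_df
    .repartition(400, 'id')          # hypothetical partition count
    .sortWithinPartitions('date')
)
grouped_df = ordered_df.groupby('id').agg(F.collect_list('value').alias('value_list'))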

kubote