44

I have a table of two string-type columns (username, friend), and for each username I want to collect all of its friends on one row, concatenated as a single string. For example: ('username1', 'friends1, friends2, friends3')

I know MySQL does this with GROUP_CONCAT. Is there any way to do this with Spark SQL?

Nick Chammas
Zahra I.S
  • If you are using Spark 2.4+, you can do this with a combination of `collect_list()` and `array_join()`. No need for UDFs. For the details, [see my answer](https://stackoverflow.com/a/59472764/877069). – Nick Chammas Jul 10 '20 at 18:23

10 Answers

47

Before you proceed: This operation is yet another groupByKey. While it has multiple legitimate applications, it is relatively expensive, so be sure to use it only when required.


Not an exactly concise or efficient solution, but you can use the UserDefinedAggregateFunction introduced in Spark 1.5.0:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.mutable.ArrayBuffer

object GroupConcat extends UserDefinedAggregateFunction {
    def inputSchema = new StructType().add("x", StringType)
    def bufferSchema = new StructType().add("buff", ArrayType(StringType))
    def dataType = StringType
    def deterministic = true 

    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, ArrayBuffer.empty[String])
    }

    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0)) 
        buffer.update(0, buffer.getSeq[String](0) :+ input.getString(0))
    }

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1.update(0, buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0))
    }

    def evaluate(buffer: Row) = UTF8String.fromString(
      buffer.getSeq[String](0).mkString(","))
}

Example usage:

val df = sc.parallelize(Seq(
  ("username1", "friend1"),
  ("username1", "friend2"),
  ("username2", "friend1"),
  ("username2", "friend3")
)).toDF("username", "friend")

df.groupBy($"username").agg(GroupConcat($"friend").alias("friends")).show

## +---------+---------------+
## | username|        friends|
## +---------+---------------+
## |username1|friend1,friend2|
## |username2|friend1,friend3|
## +---------+---------------+

You can also create a Python wrapper as shown in Spark: How to map Python with Scala or Java User Defined Functions?

In practice it can be faster to extract the RDD, groupByKey, mkString and rebuild the DataFrame.
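A rough sketch of that RDD-based route, reusing the df from the example above (the comma separator and output column names here are just for illustration):

import sqlContext.implicits._

val concatenated = df.rdd
  .map(row => (row.getString(0), row.getString(1)))  // (username, friend)
  .groupByKey()
  .mapValues(_.mkString(","))
  .toDF("username", "friends")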

You can get a similar effect by combining the collect_list function (Spark >= 1.6.0) with concat_ws:

import org.apache.spark.sql.functions.{collect_list, concat_ws}

df.groupBy($"username")
  .agg(concat_ws(",", collect_list($"friend")).alias("friends"))
zero323
  • What if I want to use it in SQL? How can I register this UDF in Spark SQL? – Murtaza Kanchwala Jun 06 '16 at 11:50
  • @MurtazaKanchwala [There is `register` method which accepts UDAFS](https://github.com/apache/spark/blob/37c617e4f580482b59e1abbe3c0c27c7125cf605/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala#L63-L69) so it should work as a standard UDF (see the sketch after these comments). – zero323 Jun 06 '16 at 19:16
  • @zero323 any approach to do the same in spark sql 1.4.1 – undefined_variable Oct 19 '16 at 06:55
  • Can't you remove ` UTF8String.fromString()` in evaluate function? – Danny Wang Oct 30 '16 at 03:29
  • This is a v. good solution. I tried it after a couple of modifications and worked fine *except* I was getting compatibility issues with the resulting DF. I could not compare the columns produced with other columns without getting UTF exceptions. I changed to converting the DF to an RDD; doing what I wanted and then converting it back to a DF. This fixed all problems and, in addition, the solution was 10x faster. I think that it is safe to say that `udfs` should be avoided if and when possible. – Christos Hadjinikolis Dec 14 '16 at 11:27
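Following up on the registration question in the comments: a minimal sketch of registering the UDAF above for use from SQL (the function name group_concat and the temporary table name friends_table are assumptions):

sqlContext.udf.register("group_concat", GroupConcat)

df.registerTempTable("friends_table")
sqlContext.sql(
  "SELECT username, group_concat(friend) AS friends FROM friends_table GROUP BY username"
).show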
25

You can try the collect_list function

sqlContext.sql("select A, collect_list(B), collect_list(C) from Table1 group by A

Or you can register a UDF, something like:

sqlContext.udf.register("myzip",(a:Long,b:Long)=>(a+","+b))

and you can use this function in the query

sqlContext.sql("select A, collect_list(myzip(B,C)) from tbl group by A")
iec2011007
  • `collect_set` will work too; it will return only unique values. – Shir May 02 '18 at 08:11
  • `collect_list` and `collect_set` are awesome Spark SQL functions! [spark-sql > sql-ref-functions-builtin](https://learn.microsoft.com/en-us/azure/databricks/spark/latest/spark-sql/language-manual/sql-ref-functions-builtin) – SherlockSpreadsheets Mar 16 '21 at 21:30
24

In Spark 2.4+ this has become simpler with the help of collect_list() and array_join().

Here's a demonstration in PySpark, though the code should be very similar for Scala too:

from pyspark.sql.functions import array_join, collect_list

friends = spark.createDataFrame(
    [
        ('jacques', 'nicolas'),
        ('jacques', 'georges'),
        ('jacques', 'francois'),
        ('bob', 'amelie'),
        ('bob', 'zoe'),
    ],
    schema=['username', 'friend'],
)

(
    friends
    .orderBy('friend', ascending=False)
    .groupBy('username')
    .agg(
        array_join(
            collect_list('friend'),
            delimiter=', ',
        ).alias('friends')
    )
    .show(truncate=False)
)

In Spark SQL the solution looks like this:

SELECT
    username,
    array_join(collect_list(friend), ', ') AS friends
FROM friends
GROUP BY username;

The output:

+--------+--------------------------+
|username|friends                   |
+--------+--------------------------+
|jacques |nicolas, georges, francois|
|bob     |zoe, amelie               |
+--------+--------------------------+

This is similar to MySQL's GROUP_CONCAT() and Redshift's LISTAGG().

Nick Chammas
12

Here is a function you can use in PySpark:

import pyspark.sql.functions as F
from pyspark.sql.types import StringType

def group_concat(col, distinct=False, sep=','):
    if distinct:
        collect = F.collect_set(col.cast(StringType()))
    else:
        collect = F.collect_list(col.cast(StringType()))
    return F.concat_ws(sep, collect)


table.groupby('username').agg(group_concat(F.col('friends')).alias('friends'))

In SQL:

select username, concat_ws(',', collect_list(friends)) as friends
from table
group by username
rikturr
4

-- the Spark SQL solution with collect_set

SELECT id, concat_ws(', ', sort_array( collect_set(colors))) as csv_colors
FROM ( 
  VALUES ('A', 'green'),('A','yellow'),('B', 'blue'),('B','green') 
) as T (id, colors)
GROUP BY id
Krzysztof Madej
3

One way to do it with pyspark < 1.6, which unfortunately doesn't support user-defined aggregate functions:

byUsername = df.rdd.reduceByKey(lambda x, y: x + ", " + y)

and if you want to make it a dataframe again:

sqlContext.createDataFrame(byUsername, ["username", "friends"])

As of 1.6, you can use collect_list and then join the created list:

from pyspark.sql import functions as F
from pyspark.sql.types import StringType
join_ = F.udf(lambda x: ", ".join(x), StringType())
df.groupBy("username").agg(join_(F.collect_list("friend").alias("friends"))
Kamil Sindi
2

Language: Scala. Spark version: 1.5.2.

I had the same issue and also tried to resolve it using UDFs, but unfortunately this led to more problems later in the code due to type inconsistencies. I was able to work around this by first converting the DF to an RDD, then grouping by and manipulating the data in the desired way, and then converting the RDD back to a DF, as follows:

val df = sc
     .parallelize(Seq(
        ("username1", "friend1"),
        ("username1", "friend2"),
        ("username2", "friend1"),
        ("username2", "friend3")))
     .toDF("username", "friend")

+---------+-------+
| username| friend|
+---------+-------+
|username1|friend1|
|username1|friend2|
|username2|friend1|
|username2|friend3|
+---------+-------+

val dfGRPD = df.map(row => (row.getString(0), row.getString(1)))
     .groupByKey()
     .map { case (username: String, groupOfFriends: Iterable[String]) => (username, groupOfFriends.mkString(",")) }
     .toDF("username", "groupOfFriends")

+---------+---------------+
| username| groupOfFriends|
+---------+---------------+
|username1|friend2,friend1|
|username2|friend3,friend1|
+---------+---------------+
Christos Hadjinikolis
0

Below is Python-based code that achieves the group_concat functionality.

Input Data:

Cust_No, Cust_Cars
1, Toyota
2, BMW
1, Audi
2, Hyundai

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf
import pyspark.sql.functions as F

spark = SparkSession.builder.master('yarn').getOrCreate()

# Build the input DataFrame from the data above
car_list_per_customer = spark.createDataFrame(
  [(1, 'Toyota'), (2, 'BMW'), (1, 'Audi'), (2, 'Hyundai')],
  ['Cust_No', 'Cust_Cars'])

# Udf to join all list elements with "|"
def combine_cars(car_list, sep='|'):
  collect = sep.join(car_list)
  return collect

test_udf = udf(combine_cars, StringType())

(car_list_per_customer
  .groupBy("Cust_No")
  .agg(F.collect_list("Cust_Cars").alias("car_list"))
  .select("Cust_No", test_udf("car_list").alias("Final_List"))
  .show(20, False))

Output Data:

Cust_No, Final_List
1, Toyota|Audi
2, BMW|Hyundai

Akshay Patel
0

You can also use the Spark SQL function collect_list; afterwards you will need to cast the result to string and use regexp_replace to replace the special characters.

regexp_replace(regexp_replace(regexp_replace(cast(collect_list((column)) as string), ' ', ''), ',', '|'), '[^A-Z0-9|]', '')

It's an easier way.

Kevin Giediel
0

The functions concat_ws() and collect_list() can be a good alternative, used along with groupBy():

import pyspark.sql.functions as F

df_grp = df.groupby("agg_col").agg(
    F.concat_ws("#;", F.collect_list(df.time)).alias("time"),
    F.concat_ws("#;", F.collect_list(df.status)).alias("status"),
    F.concat_ws("#;", F.collect_list(df.llamaType)).alias("llamaType"),
)

Sample Output

+-------+------------------+----------------+---------------------+
|agg_col|time              |status          |llamaType            |
+-------+------------------+----------------+---------------------+
|1      |5-1-2020#;6-2-2020|Running#;Sitting|red llama#;blue llama|
+-------+------------------+----------------+---------------------+
dsk