
I have a pyspark dataframe. For example,

d = hiveContext.createDataFrame([("A", 1), ("B", 2), ("D", 3), ("D", 3), ("A", 4), ("D", 3)], ["Col1", "Col2"])

+----+----+
|Col1|Col2|
+----+----+
|   A|   1|
|   B|   2|
|   D|   3|
|   D|   3|
|   A|   4|
|   D|   3|
+----+----+

I want to group by Col1 and then create a list of the Col2 values for each group. I need to flatten the groups, and I have a lot of columns.

+----+---------+
|Col1|     Col2|
+----+---------+
|   A|    [1,4]|
|   B|      [2]|
|   D|  [3,3,3]|
+----+---------+

2 Answers


You can do a groupBy() and use collect_list() as your aggregate function:

import pyspark.sql.functions as f
d.groupBy('Col1').agg(f.collect_list('Col2').alias('Col2')).show()
#+----+---------+
#|Col1|     Col2|
#+----+---------+
#|   B|      [2]|
#|   D|[3, 3, 3]|
#|   A|   [1, 4]|
#+----+---------+
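
Note that the grouped rows come back in an arbitrary order, and collect_list() makes no ordering guarantee for the values within each list. If you need the lists sorted, as in the desired output above, one option (a sketch, not part of the original answer) is to wrap the aggregate in sort_array():

# sort_array() orders the collected values ascending; the order of the
# grouped rows themselves is still arbitrary
d.groupBy('Col1').agg(f.sort_array(f.collect_list('Col2')).alias('Col2')).show()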

Update

If you had multiple columns to combine, you could use collect_list() on each, and then combine the resulting lists using struct() and udf(). Consider the following example:

Create Dummy Data

from functools import reduce  # reduce is not a builtin in Python 3
from operator import add
import pyspark.sql.functions as f
from pyspark.sql.types import ArrayType, IntegerType

# create example dataframe (sqlcx is your SQLContext/HiveContext)
d = sqlcx.createDataFrame(
    [
        ("A", 1, 10),
        ("B", 2, 20),
        ("D", 3, 30),
        ("D", 3, 10),
        ("A", 4, 20),
        ("D", 3, 30)
    ],
    ["Col1", "Col2", "Col3"]
)

Collect Desired Columns into Lists

Suppose you had a list of columns, each of which you wanted to collect into a list. You could do the following:

cols_to_combine = ['Col2', 'Col3']
d.groupBy('Col1').agg(*[f.collect_list(c).alias(c) for c in cols_to_combine]).show()
#+----+---------+------------+
#|Col1|     Col2|        Col3|
#+----+---------+------------+
#|   B|      [2]|        [20]|
#|   D|[3, 3, 3]|[30, 10, 30]|
#|   A|   [4, 1]|    [20, 10]|
#+----+---------+------------+

Combine Resultant Lists into one Column

Now we want to combine the list columns into one list. If we use struct(), we will get the following:

d.groupBy('Col1').agg(*[f.collect_list(c).alias(c) for c in cols_to_combine])\
    .select('Col1', f.struct(*cols_to_combine).alias('Combined'))\
    .show(truncate=False)
#+----+------------------------------------------------+
#|Col1|Combined                                        |
#+----+------------------------------------------------+
#|B   |[WrappedArray(2),WrappedArray(20)]              |
#|D   |[WrappedArray(3, 3, 3),WrappedArray(10, 30, 30)]|
#|A   |[WrappedArray(1, 4),WrappedArray(10, 20)]       |
#+----+------------------------------------------------+

Flatten Wrapped Arrays

Almost there. We just need to combine the WrappedArrays. We can achieve this with a udf():

# udf to concatenate the arrays inside the struct into one flat list
combine_wrapped_arrays = f.udf(lambda val: reduce(add, val), ArrayType(IntegerType()))
d.groupBy('Col1').agg(*[f.collect_list(c).alias(c) for c in cols_to_combine])\
    .select('Col1', combine_wrapped_arrays(f.struct(*cols_to_combine)).alias('Combined'))\
    .show(truncate=False)
#+----+---------------------+
#|Col1|Combined             |
#+----+---------------------+
#|B   |[2, 20]              |
#|D   |[3, 3, 3, 30, 10, 30]|
#|A   |[1, 4, 10, 20]       |
#+----+---------------------+

Update 2

A simpler way, without having to deal with WrappedArrays:

from functools import reduce
from operator import add

# udf that concatenates however many list columns it receives
combine_udf = f.udf(
    lambda *args: reduce(add, args),
    ArrayType(IntegerType())
)

d.groupBy('Col1').agg(*[f.collect_list(c).alias(c) for c in cols_to_combine])\
    .select('Col1', combine_udf(*cols_to_combine).alias('Combined'))\
    .show(truncate=False)
#+----+---------------------+
#|Col1|Combined             |
#+----+---------------------+
#|B   |[2, 20]              |
#|D   |[3, 3, 3, 30, 10, 30]|
#|A   |[1, 4, 10, 20]       |
#+----+---------------------+

Note: This last step only works if all of the columns have the same data type. You cannot use this function to combine arrays with mixed types.
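
If the data types differ, one workaround (a hedged sketch, not from the original answer) is to cast every column to a common type before collecting, e.g. to strings:

from pyspark.sql.types import ArrayType, StringType

# hypothetical variant: cast to string first so all lists share one element type
combine_str_udf = f.udf(lambda *args: reduce(add, args), ArrayType(StringType()))

d.groupBy('Col1')\
    .agg(*[f.collect_list(f.col(c).cast('string')).alias(c) for c in cols_to_combine])\
    .select('Col1', combine_str_udf(*cols_to_combine).alias('Combined'))\
    .show(truncate=False)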

  • What modifications would I have to make to include multiple columns, i.e. flatten Col3, Col4, etc.? – Bryce Ramgovind Feb 05 '18 at 15:46
  • I just want to point out that using `reduce(add, ..)` is an anti-pattern (I did not know this at the time I answered). A better way would be to use `chain.from_iterable`. More info [here](https://stackoverflow.com/q/41772054/5858851). – pault Jul 02 '18 at 21:50
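
Following up on that last comment, here is what a drop-in replacement for combine_udf could look like using chain.from_iterable instead of reduce(add, ...) (a sketch based on the linked comment, not code from the original answers; chain_combine_udf is a hypothetical name):

from itertools import chain

# chain.from_iterable concatenates the argument lists without the repeated
# intermediate list copies that reduce(add, ...) performs
chain_combine_udf = f.udf(
    lambda *args: list(chain.from_iterable(args)),
    ArrayType(IntegerType())
)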

From Spark 2.4 you can use pyspark.sql.functions.flatten, which turns an array of arrays into a single array (so here Col2 is assumed to already be an array column):

import pyspark.sql.functions as f
df.groupBy('Col1').agg(f.flatten(f.collect_list('Col2')).alias('Col2')).show()
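
Applied to the multi-column case from the first answer, flatten() can replace the udf entirely (a sketch assuming Spark >= 2.4 and the cols_to_combine list defined above):

# f.array() nests the collected lists into an array of arrays,
# which f.flatten() then concatenates into one flat list
d.groupBy('Col1')\
    .agg(*[f.collect_list(c).alias(c) for c in cols_to_combine])\
    .select('Col1', f.flatten(f.array(*cols_to_combine)).alias('Combined'))\
    .show(truncate=False)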