Edit: As discussed in the comments, the performance issue with the original methods may come from counting with the filter or aggregate functions, which triggers unnecessary scans of the data. Below, we explode the arrays and do the aggregation (count) before creating the final array column:
from pyspark.sql.functions import collect_list, struct
df = spark.createDataFrame([(2,[1,2]), (2,[1,2]), (3,[1,2,3]), (3,[1,2])],['timestamp', 'vars'])
df.selectExpr("timestamp", "explode(vars) as var") \
    .groupby('timestamp','var') \
    .count() \
    .groupby("timestamp") \
    .agg(collect_list(struct("var","count")).alias("data")) \
    .selectExpr(
        "timestamp",
        "transform(data, x -> x.var) as indices",
        "transform(data, x -> x.count) as values"
    ).selectExpr(
        "timestamp",
        "transform(sequence(0, array_max(indices)), i -> IFNULL(values[array_position(indices,i)-1],0)) as new_vars"
    ).show(truncate=False)
+---------+------------+
|timestamp|new_vars    |
+---------+------------+
|3        |[0, 2, 2, 1]|
|2        |[0, 2, 2]   |
+---------+------------+
Where:
(1) we explode the array and do count() for each timestamp + var
(2) groupby timestamp and create an array of structs containing two fields, var and count
(3) convert the array of structs into two arrays: indices and values (similar to how we define a SparseVector)
(4) transform the sequence sequence(0, array_max(indices)): for each i in the sequence, use array_position to find the index of i in the indices array and then retrieve the value from the values array at the same position, see below:
IFNULL(values[array_position(indices,i)-1],0)
Notice that the function array_position uses a 1-based index while array indexing is 0-based, thus we have a -1 in the above expression. When i is not in indices, array_position returns 0, so the lookup index becomes -1; with ANSI mode off this yields NULL, which IFNULL turns into 0.
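To make the four steps easier to follow, here is a minimal plain-Python sketch of the same logic for a single timestamp. It is only an illustration, not Spark code: the function name dense_counts is hypothetical, and it assumes the input is non-empty and holds non-negative integers.

```python
from collections import Counter

def dense_counts(rows):
    """Hypothetical helper: rows is a non-empty list of lists of
    non-negative ints, all belonging to one timestamp."""
    # Steps (1)-(2): "explode" the lists and count each distinct value.
    counts = Counter(v for row in rows for v in row)
    # Step (3): split the counts into parallel indices/values lists.
    indices = list(counts.keys())
    values = [counts[i] for i in indices]
    # Step (4): walk 0..max(indices) and look up each position,
    # defaulting to 0 when the position never occurred.
    dense = []
    for i in range(max(indices) + 1):
        # list.index is 0-based, so no "-1" is needed here,
        # unlike array_position in Spark SQL, which is 1-based.
        dense.append(values[indices.index(i)] if i in indices else 0)
    return dense

print(dense_counts([[1, 2], [1, 2]]))     # [0, 2, 2]
print(dense_counts([[1, 2, 3], [1, 2]]))  # [0, 2, 2, 1]
```

The two calls reproduce the rows shown in the output table above for timestamps 2 and 3.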
Old methods:
(1) Use transform + filter/size:
from pyspark.sql.functions import flatten, collect_list
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
    .selectExpr(
        "timestamp",
        "transform(sequence(0, array_max(data)), x -> size(filter(data, y -> y = x))) as vars"
    ).show(truncate=False)
+---------+------------+
|timestamp|vars        |
+---------+------------+
|3        |[0, 2, 2, 1]|
|2        |[0, 2, 2]   |
+---------+------------+
(2) Use aggregate function:
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
    .selectExpr("timestamp", """
        aggregate(
            data,
            /* use an array as zero_value, size = array_max(data)+1 and all values are zero */
            array_repeat(0, int(array_max(data))+1),
            /* increment the ith value of the array by 1 if i == y */
            (acc, y) -> transform(acc, (x,i) -> IF(i=y, x+1, x))
        ) as vars
    """).show(truncate=False)
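For intuition about why the two old methods do repeated work, here is a rough plain-Python analogue of each. The names filter_counts and fold_counts are hypothetical, the input is assumed to be a non-empty list of non-negative integers, and Spark's actual execution may differ from this sketch.

```python
from functools import reduce

def filter_counts(data):
    # Analogue of method (1): for every position x, rescan all of data
    # and count the elements equal to x.
    return [sum(1 for y in data if y == x) for x in range(max(data) + 1)]

def fold_counts(data):
    # Analogue of method (2): start from an all-zero accumulator and,
    # for every element y, rebuild the whole accumulator with the
    # y-th slot incremented by 1.
    zero = [0] * (max(data) + 1)
    step = lambda acc, y: [x + 1 if i == y else x for i, x in enumerate(acc)]
    return reduce(step, data, zero)

data = [1, 2, 3, 1, 2]      # flattened arrays for timestamp 3
print(filter_counts(data))  # [0, 2, 2, 1]
print(fold_counts(data))    # [0, 2, 2, 1]
```

In both sketches the work grows with len(data) times the size of the result array, whereas counting once per distinct value (as in the explode + groupby version) touches each element only once.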