With arrays of unequal size in the different groups (for you, a group is ("contigName", "start", "end", "referenceAllele"), which I'll simply rename to group), you could consider exploding the array column (the alleleFrequencies) while keeping the position each value had within its array. That gives you an additional column you can use in the grouping to compute the average you had in mind. At this point you might already have enough for further computations (see df3 in the example below).
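To make the explode-and-group idea concrete without a cluster, here is a plain-Python sketch of what that step computes, using the same toy data as the example further down. It only illustrates the logic: enumerate stands in for posexplode, and a dictionary keyed by (group, position) stands in for groupBy("group", "pos") with mean. The variable names are my own, and this is not how Spark actually executes it.

```python
from collections import defaultdict

# Same toy data as the PySpark example: (group, values) rows of unequal length.
rows = [
    ("A", [0, 1, 2]),
    ("A", [0, 3, 6]),
    ("B", [1, 2, 4, 5]),
    ("B", [1, 2, 6, 1]),
]

# Like posexplode: one (group, pos, value) record per array element.
exploded = [
    (group, pos, value)
    for group, values in rows
    for pos, value in enumerate(values)
]

# Like groupBy("group", "pos").agg(mean(...)): sum and count per key, then divide.
sums = defaultdict(float)
counts = defaultdict(int)
for group, pos, value in exploded:
    sums[(group, pos)] += value
    counts[(group, pos)] += 1

avg_of_positions = {key: sums[key] / counts[key] for key in sums}
print(avg_of_positions)
```

Positions that only exist in the longer arrays simply get averaged over the rows that have them, which is why unequal sizes are not a problem here.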
If you really must have it back in an array, that's harder, and I don't have a clean way of doing that directly. You need to keep track of the order, and I find that easiest with a map (a dictionary, if you like). To build one, I use the aggregation function collect_list on two columns. While collect_list isn't deterministic (you don't know the order in which values will be returned in the list, because rows are shuffled), collecting both columns in the same aggregation keeps the two lists aligned, because each row is shuffled as a whole: the n-th entry of the pos list belongs with the n-th entry of the avgs list (see df4 in the example below). From there, you can create a mapping of the position to the average with map_from_arrays. In the output below, the keys of group B come out as 0, 1, 3, 2, yet each position is still paired with its own average.
Example:
>>> from pyspark.sql.functions import mean, col, posexplode, collect_list, map_from_arrays
>>>
>>> df = spark.createDataFrame([
... ("A", [0, 1, 2]),
... ("A", [0, 3, 6]),
... ("B", [1, 2, 4, 5]),
... ("B", [1, 2, 6, 1])],
... schema=("group", "values"))
>>> df2 = df.select(df.group, posexplode(df.values)) # adds the "pos" and "col" columns
>>> df3 = (df2
... .groupBy("group", "pos")
... .agg(mean(col("col")).alias("avg_of_positions"))
... )
>>> df4 = (df3
... .groupBy("group")
... .agg(
... collect_list("pos").alias("pos"),
... collect_list("avg_of_positions").alias("avgs")
... )
... )
>>> df5 = df4.select(
... "group",
... map_from_arrays(col("pos"), col("avgs")).alias("positional_averages")
... )
>>> df5.show(truncate=False)
+-----+----------------------------------------+
|group|positional_averages                     |
+-----+----------------------------------------+
|B    |[0 -> 1.0, 1 -> 2.0, 3 -> 3.0, 2 -> 5.0]|
|A    |[0 -> 0.0, 1 -> 2.0, 2 -> 4.0]          |
+-----+----------------------------------------+