
I would like to get the first and last row of each partition in Spark (I'm using PySpark). How do I go about this? In my code I repartition my dataset based on a key column using:

mydf.repartition(keyColumn).sortWithinPartitions(sortKey)

Is there a way to get the first row and last row for each partition? Thanks

Java Developr
  • Not sure why my question has been down-voted. Is there something wrong with the post? It would be really helpful if you could leave a comment about what you think is wrong with the post if you down-vote. – Java Developr Feb 20 '20 at 19:00
  • Why do you want the first and last row of each partition? You can use foreachPartition; it will give you an iterator. – maxime G Feb 20 '20 at 19:19
  • You probably mean `mapPartitions` @maximeG; `foreachPartition` will not allow you to modify the final output. – abiratsis May 15 '20 at 09:53
  • OP didn't say that he wants to modify the output. – maxime G May 16 '20 at 09:28
  • How can you extract first/last using `foreachPartition`? In my understanding, the question is about extracting the first/last items of each partition (i.e. 4 partitions -> 8 items) via the Spark API, with no hidden storage or third-party library. – abiratsis May 17 '20 at 20:04

3 Answers


I would highly advise against working with partitions directly. Spark does a lot of DAG optimisation, so when you run specific logic on each partition, your assumptions about the partitions and how the data is distributed across them might turn out to be completely false.

However, you seem to have a keyColumn and a sortKey, so I'd suggest the following:

from pyspark.sql import Window
import pyspark.sql.functions as f

# keyColumn and sortKey are the column names from the question
w_asc = Window.partitionBy(keyColumn).orderBy(f.asc(sortKey))
w_desc = Window.partitionBy(keyColumn).orderBy(f.desc(sortKey))

res_df = (
    mydf
    # rn_asc = 1 marks the first row of each key, rn_desc = 1 the last one
    .withColumn("rn_asc", f.row_number().over(w_asc))
    .withColumn("rn_desc", f.row_number().over(w_desc))
    .where("rn_asc = 1 or rn_desc = 1")
)

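The resulting DataFrame keeps only the first and last row of each key and has two additional columns: rn_asc = 1 marks the first row and rn_desc = 1 marks the last row. A key with a single row appears once, with both columns equal to 1.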
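To make the semantics concrete, here is a plain-Python model (not Spark code) of what the two row_number() windows compute. The helper names `tag_row_numbers` and `first_and_last` and the sample rows are hypothetical, and the model assumes sort values are distinct within each key (Spark breaks row_number() ties arbitrarily, which this does not reproduce).

```python
from itertools import groupby
from operator import itemgetter


def tag_row_numbers(rows, key, sort_key):
    """Return (row, rn_asc, rn_desc) triples, mimicking row_number() over
    partitionBy(key).orderBy(asc(sort_key)) and ...orderBy(desc(sort_key)).

    Assumes sort_key values are distinct within each key.
    """
    tagged = []
    # groupby needs its input sorted by the grouping key
    for _, group in groupby(sorted(rows, key=itemgetter(key)), key=itemgetter(key)):
        ordered = sorted(group, key=itemgetter(sort_key))
        n = len(ordered)
        for i, row in enumerate(ordered):
            # i + 1 counts from the smallest sort_key, n - i from the largest
            tagged.append((row, i + 1, n - i))
    return tagged


def first_and_last(rows, key, sort_key):
    # Same filter as where("rn_asc = 1 or rn_desc = 1")
    return [
        row
        for row, rn_asc, rn_desc in tag_row_numbers(rows, key, sort_key)
        if rn_asc == 1 or rn_desc == 1
    ]


rows = [
    {"k": "a", "s": 1},
    {"k": "a", "s": 3},
    {"k": "a", "s": 2},
    {"k": "b", "s": 7},
]
print(first_and_last(rows, "k", "s"))
# [{'k': 'a', 's': 1}, {'k': 'a', 's': 3}, {'k': 'b', 's': 7}]
```

Key "b" has a single row, so it is numbered 1 from both ends and survives the filter only once.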

Richard Nemeth
  • There is a good reason that the Spark devs exposed partitions through the Spark API: to make cases similar to this one possible. We don't need a window function here, since it introduces unnecessary overhead. Spark provides an iterator through the mapPartitions method precisely because working directly with iterators is very efficient. The orderBy or partitionBy will cause data shuffling, which is what we always want to avoid. If I understood correctly, the OP is asking not to touch the current partitions, just to get the first/last element of the existing ones. – abiratsis May 17 '20 at 20:33
  • Good point Alexandros :) And totally agreed. The reason I'm suggesting a window function is that I don't believe the OP has the partitions in place (since they are repartitioning the input dataframe), so the reshuffle is necessary either way. – Richard Nemeth May 18 '20 at 06:05

Scala: `repartition` can take a number of partitions, key columns, or both. Rather than relying on the resulting partitions, I select the first and last row of each key by using Spark's Window functions.

First, this is my test data.

+---+-----+
| id|value|
+---+-----+
|  1|    1|
|  1|    2|
|  1|    3|
|  1|    4|
|  2|    1|
|  2|    2|
|  2|    3|
|  3|    1|
|  3|    3|
|  3|    5|
+---+-----+

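Then, I use two Window specifications: one ordered ascending to find the first row, and one ordered descending to find the last row. The last row is hard to identify directly, but it is simply the first row of the reversed order.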

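import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, rank, when}

// Ascending order puts the first row at rank 1; descending puts the last row at rank 1
val a = Window.partitionBy("id").orderBy("value")
val d = Window.partitionBy("id").orderBy(col("value").desc)

// inferSchema reads "value" as a number, so it sorts numerically rather than as text
val df = spark.read.option("header", "true").option("inferSchema", "true").csv("test.csv")
// Note: rank() gives every tied row rank 1; row_number() would keep exactly one
df.withColumn("marker", when(rank().over(a) === 1, "Y").otherwise("N"))
  // Keep an existing "Y" and also mark the last row of each id
  .withColumn("marker", when(rank().over(d) === 1, "Y").otherwise(col("marker")))
  .filter(col("marker") === "Y")
  .drop("marker")
  .show()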

The final result is:

+---+-----+
| id|value|
+---+-----+
|  3|    5|
|  3|    1|
|  1|    4|
|  1|    1|
|  2|    3|
|  2|    1|
+---+-----+
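The rows come back in no particular order, because Spark gives no ordering guarantee for show() without an orderBy. As a plain-Python cross-check (not Spark code; the helper `first_last_per_id` is hypothetical and assumes no tied values within an id), the same set of rows can be derived from the test data:

```python
from collections import defaultdict

# The (id, value) rows of test.csv shown above
rows = [
    (1, 1), (1, 2), (1, 3), (1, 4),
    (2, 1), (2, 2), (2, 3),
    (3, 1), (3, 3), (3, 5),
]


def first_last_per_id(rows):
    """Return the set of (id, value) pairs that are the first or last
    row of their id when ordered by value."""
    by_id = defaultdict(list)
    for id_, value in rows:
        by_id[id_].append(value)
    ends = set()
    for id_, values in by_id.items():
        values.sort()
        ends.add((id_, values[0]))   # rank 1 in the ascending window
        ends.add((id_, values[-1]))  # rank 1 in the descending window
    return ends


print(sorted(first_last_per_id(rows)))
# [(1, 1), (1, 4), (2, 1), (2, 3), (3, 1), (3, 5)]
```

Comparing as a set rather than a list is what makes the check independent of the order Spark happens to print.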
Lamanus

Here is another approach using mapPartitions from the RDD API. Each partition is scanned once from start to end, and only its two edge elements are kept. I would expect this to be fast: no shuffle is introduced, and memory use stays constant per partition because only the first and the current element are held. Here is the code:

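from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
  ["Tom", "a"],
  ["Dick", "b"],
  ["Harry", "c"],
  ["Elvis", "d"],
  ["Elton", "e"],
  ["Sandra", "f"]
], ["name", "toy"])

def get_first_last(it):
    # Empty partitions yield nothing (a bare next(it) would raise StopIteration)
    first = next(it, None)
    if first is None:
        return []

    # Walk to the end; `last` is left bound to the final row
    last = first
    for last in it:
        pass

    # Attention: in a single-row partition first and last are the same object,
    # so return it only once
    if first is last:
        return [first]

    return [first, last]

# coalesce here is just for demonstration
first_last_rdd = df.coalesce(2).rdd.mapPartitions(get_first_last)

spark.createDataFrame(first_last_rdd, ["name", "toy"]).show()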

# +------+---+
# |  name|toy|
# +------+---+
# |   Tom|  a|
# | Harry|  c|
# | Elvis|  d|
# |Sandra|  f|
# +------+---+

PS: As long as every partition has at least two rows, odd positions (1st, 3rd, ...) contain the first element of a partition and even positions the last one. The number of results is (numNonEmptyPartitions * 2) - numPartitionsWithOneItem, which I expect to be relatively small, so you shouldn't need to worry about the cost of the extra createDataFrame call.
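Because the function only sees a Python iterator, its edge cases can be checked locally without a cluster. The sketch below is a plain-Python stand-in, not Spark: `map_partitions` is a hypothetical helper that mimics rdd.mapPartitions by calling the function on each partition's iterator and concatenating the results, and `get_first_last` is a variant that returns nothing for an empty partition (which coalesce or repartition can produce). With three non-empty partitions, one of them holding a single row, the result should have 2 * 3 - 1 = 5 rows.

```python
def get_first_last(it):
    # next() with a default avoids StopIteration on an empty partition
    first = next(it, None)
    if first is None:
        return []
    last = first
    # Walk to the end; `last` is left bound to the final element
    for last in it:
        pass
    # Single-element partition: first and last are the same object
    if first is last:
        return [first]
    return [first, last]


def map_partitions(partitions, f):
    """Hypothetical local stand-in for rdd.mapPartitions(f): call f on an
    iterator over each partition and concatenate the results in order."""
    out = []
    for part in partitions:
        out.extend(f(iter(part)))
    return out


partitions = [
    [("Tom", "a"), ("Dick", "b"), ("Harry", "c")],
    [],                                   # empty partition
    [("Elvis", "d")],                     # single-row partition
    [("Elton", "e"), ("Sandra", "f")],
]
result = map_partitions(partitions, get_first_last)
print(result)
# [('Tom', 'a'), ('Harry', 'c'), ('Elvis', 'd'), ('Elton', 'e'), ('Sandra', 'f')]
```

Note that the single-row partition breaks the odd/even pattern: from "Elvis" onward, positions no longer alternate between first and last elements.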

abiratsis