I would like to collapse the rows of a dataframe based on an ID column and count the number of records per ID using window functions. In doing so, I would like to avoid partitioning the window by ID, because that would result in a very large number of partitions.

I have a dataframe of the form

+----+-----------+-----------+-----------+
| ID | timestamp | metadata1 | metadata2 |
+----+-----------+-----------+-----------+
|  1 | 09:00     | ABC       | apple     |
|  1 | 08:00     | NULL      | NULL      |
|  1 | 18:00     | XYZ       | apple     |
|  2 | 07:00     | NULL      | banana    |
|  5 | 23:00     | ABC       | cherry    |
+----+-----------+-----------+-----------+

where I would like to keep only the records with the most recent timestamp per ID, such that I have

+----+-----------+-----------+-----------+-------+
| ID | timestamp | metadata1 | metadata2 | count |
+----+-----------+-----------+-----------+-------+
|  1 | 18:00     | XYZ       | apple     |     3 |
|  2 | 07:00     | NULL      | banana    |     1 |
|  5 | 23:00     | ABC       | cherry    |     1 |
+----+-----------+-----------+-----------+-------+

I have tried:

import sys

from pyspark.sql import Window
from pyspark.sql.functions import asc, desc, first, count, col, row_number

window = Window.orderBy(asc('ID'), desc('timestamp'))
window_count = Window.orderBy(asc('ID'), desc('timestamp')).rowsBetween(-sys.maxsize, sys.maxsize)

columns_metadata = ['metadata1', 'metadata2']

df = df.select(
              *(first(col_name, ignorenulls=True).over(window).alias(col_name) for col_name in columns_metadata),
              count(col('ID')).over(window_count).alias('count')
              )
df = df.withColumn("row_tmp", row_number().over(window)).filter(col('row_tmp') == 1).drop(col('row_tmp'))

which is in part based on "How to select the first row of each group?".

While this avoids pyspark.sql.Window.partitionBy, it does not give the desired output.
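
For comparison, that linked question also shows (if I remember correctly) a non-window pattern based on groupBy with a max over a struct. A rough sketch on the columns above (untested, and not the window-based approach I am asking about; df_collapsed is just an illustrative name):

import pyspark.sql.functions as F

# Collapse each ID to its latest row: struct comparison is driven by the first
# field (timestamp), and the row count per ID comes from the same aggregation.
df_collapsed = (
    df.groupBy('ID')
      .agg(
          F.max(F.struct('timestamp', 'metadata1', 'metadata2')).alias('latest'),
          F.count('*').alias('count'),
      )
      .select('ID', 'latest.timestamp', 'latest.metadata1', 'latest.metadata2', 'count')
)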

Wasserwaage

1 Answer


I only read that you wanted to do this without partitioning by ID after I had posted this; it is the only approach I could think of.

Your dataframe:

df = sqlContext.createDataFrame(
  [
     ('1', '09:00', 'ABC', 'apple')
    ,('1', '08:00', '', '')
    ,('1', '18:00', 'XYZ', 'apple')
    ,('2', '07:00', '', 'banana')
    ,('5', '23:00', 'ABC', 'cherry')
  ]
  ,['ID', 'timestamp', 'metadata1', 'metadata2']
)
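
Note that I used empty strings where your table has NULLs. If you want real NULLs, None in the tuples gives you that; a minimal variant:

# same data, but with real NULLs instead of empty strings
df = sqlContext.createDataFrame(
  [
     ('1', '09:00', 'ABC', 'apple')
    ,('1', '08:00', None, None)
    ,('1', '18:00', 'XYZ', 'apple')
    ,('2', '07:00', None, 'banana')
    ,('5', '23:00', 'ABC', 'cherry')
  ]
  ,['ID', 'timestamp', 'metadata1', 'metadata2']
)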

We can use rank(), partitioning by ID and ordering by timestamp descending:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

w1 = Window.partitionBy(df['ID']).orderBy(F.desc('timestamp'))
w2 = Window.partitionBy(df['ID'])

df\
  .withColumn("rank", F.rank().over(w1))\
  .withColumn("count", F.count('ID').over(w2))\
  .filter(F.col('rank') == 1)\
  .select('ID', 'timestamp', 'metadata1', 'metadata2', 'count')\
  .show()

+---+---------+---------+---------+-----+
| ID|timestamp|metadata1|metadata2|count|
+---+---------+---------+---------+-----+
|  1|    18:00|      XYZ|    apple|    3|
|  2|    07:00|         |   banana|    1|
|  5|    23:00|      ABC|   cherry|    1|
+---+---------+---------+---------+-----+
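
A caveat: rank() assigns the same rank to rows that tie on timestamp, so the rank == 1 filter can return more than one row per ID if the latest timestamp is duplicated. If you need exactly one row per ID, row_number() can be swapped in; a minimal variant reusing w1 and w2 from above:

# row_number() gives a strict 1..n numbering within each ID, so rn == 1 is a single row
df\
  .withColumn("rn", F.row_number().over(w1))\
  .withColumn("count", F.count('ID').over(w2))\
  .filter(F.col('rn') == 1)\
  .select('ID', 'timestamp', 'metadata1', 'metadata2', 'count')\
  .show()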
Luiz Viola