
Is there a way to replace null values in a PySpark DataFrame with the last valid value? There are additional timestamp and session columns if you think you need them for window partitioning and ordering. More specifically, I'd like to achieve the following conversion:

+---------+-----------+-----------+      +---------+-----------+-----------+
| session | timestamp |         id|      | session | timestamp |         id|
+---------+-----------+-----------+      +---------+-----------+-----------+
|        1|          1|       null|      |        1|          1|       null|
|        1|          2|        109|      |        1|          2|        109|
|        1|          3|       null|      |        1|          3|        109|
|        1|          4|       null|      |        1|          4|        109|
|        1|          5|        109| =>   |        1|          5|        109|
|        1|          6|       null|      |        1|          6|        109|
|        1|          7|        110|      |        1|          7|        110|
|        1|          8|       null|      |        1|          8|        110|
|        1|          9|       null|      |        1|          9|        110|
|        1|         10|       null|      |        1|         10|        110|
+---------+-----------+-----------+      +---------+-----------+-----------+
Oleksiy

4 Answers


This uses last and ignores nulls.

Let's re-create something similar to the original data:

import sys
from pyspark.sql.window import Window
import pyspark.sql.functions as func

d = [{'session': 1, 'ts': 1}, {'session': 1, 'ts': 2, 'id': 109}, {'session': 1, 'ts': 3}, {'session': 1, 'ts': 4, 'id': 110}, {'session': 1, 'ts': 5},  {'session': 1, 'ts': 6}]
df = spark.createDataFrame(d)

df.show()
# +-------+---+----+
# |session| ts|  id|
# +-------+---+----+
# |      1|  1|null|
# |      1|  2| 109|
# |      1|  3|null|
# |      1|  4| 110|
# |      1|  5|null|
# |      1|  6|null|
# +-------+---+----+

Now, let's use the window function `last`:

df.withColumn("id", func.last('id', True).over(Window.partitionBy('session').orderBy('ts').rowsBetween(-sys.maxsize, 0))).show()

# +-------+---+----+
# |session| ts|  id|
# +-------+---+----+
# |      1|  1|null|
# |      1|  2| 109|
# |      1|  3| 109|
# |      1|  4| 110|
# |      1|  5| 110|
# |      1|  6| 110|
# +-------+---+----+
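
As noted in a comment below, with an `orderBy` the default window frame already runs from the start of the partition to the current row, so the explicit bound is optional. If you do want it spelled out, the named bounds read better than `-sys.maxsize`; a small variant of the same call (not part of the original answer):

w = Window.partitionBy('session').orderBy('ts') \
          .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn('id', func.last('id', ignorenulls=True).over(w)).show()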
ZygD
elmosca
    A word of caution: This answer will collect all rows of each session to some executor node. This will result in failed jobs if the number of rows in some session is larger than the memory of your executor nodes. – Jordan P Sep 19 '17 at 22:39
    `.rowsBetween(-sys.maxsize, 0)` may be removed – ZygD Sep 19 '22 at 18:58
  • F.first('id',True) can also be used by using " Window.partitionBy('session').orderBy(df.ts.desc())" – Hassaan Anwar Feb 21 '23 at 11:05

This seems to be doing the trick using Window functions:

import sys
from pyspark.sql.window import Window
import pyspark.sql.functions as func

def fill_nulls(df):
    # Mark nulls with a sentinel value so they can be compared
    df_na = df.na.fill(-1)

    # Previous id within the session
    lag = df_na.withColumn('id_lag', func.lag('id', default=-1)
                           .over(Window.partitionBy('session')
                                 .orderBy('timestamp')))

    # Flag rows where a new, non-sentinel id appears
    switch = lag.withColumn('id_change',
                            ((lag['id'] != lag['id_lag']) &
                             (lag['id'] != -1)).cast('integer'))

    # Running sum of the flags splits each session into sub-sessions
    switch_sess = switch.withColumn(
        'sub_session',
        func.sum("id_change")
        .over(
            Window.partitionBy("session")
            .orderBy("timestamp")
            .rowsBetween(-sys.maxsize, 0))
    )

    # The first id of each sub-session is the value to propagate
    fid = switch_sess.withColumn('nn_id',
                           func.first('id')
                           .over(Window.partitionBy('session', 'sub_session')
                                 .orderBy('timestamp')))

    # Turn the sentinel back into null and drop the helper columns
    fid_na = fid.replace(-1, None)

    ff = fid_na.drop('id').drop('id_lag') \
               .drop('id_change') \
               .drop('sub_session') \
               .withColumnRenamed('nn_id', 'id')

    return ff

Here is the full null_test.py.
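
For a quick sanity check, the function can be applied to a small DataFrame shaped like the one in the question (this snippet is only illustrative, reuses the imports above, and is not part of the original answer):

d = [{'session': 1, 'timestamp': 1}, {'session': 1, 'timestamp': 2, 'id': 109},
     {'session': 1, 'timestamp': 3}, {'session': 1, 'timestamp': 4},
     {'session': 1, 'timestamp': 5, 'id': 110}]
df = spark.createDataFrame(d)
fill_nulls(df).orderBy('timestamp').show()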

Oleksiy
    I was writing my test ! Thanks. The answer seems very clean to me. Having sessions is essential which made it possible to partition by and thus use the window function ! – eliasah Apr 04 '16 at 15:57
  • Nice solution! I'm surprised this feature doesn't exist in Spark yet. – Will Vousden Jun 15 '17 at 12:51

@Oleksiy's answer is great, but it didn't fully work for my requirements. Within a session, if multiple nulls are observed, they are all filled with the first non-null value for the session, whereas I needed the last non-null value to propagate forward.

The following tweak worked for my use case:

import sys
from pyspark.sql.window import Window
import pyspark.sql.functions as func

def fill_forward(df, id_column, key_column, fill_column):

    # Fill nulls with the last *non-null* value in the window
    ff = df.withColumn(
        'fill_fwd',
        func.last(fill_column, True)  # True = ignorenulls
        .over(
            Window.partitionBy(id_column)
            .orderBy(key_column)
            .rowsBetween(-sys.maxsize, 0))
        )

    # Drop the old column and rename the new column
    ff_out = ff.drop(fill_column).withColumnRenamed('fill_fwd', fill_column)

    return ff_out
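
Using the example DataFrame and column names from the first answer, a call would look like this (illustrative only, not part of the original answer):

fill_forward(df, 'session', 'ts', 'id').orderBy('ts').show()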
brett

Here is the trick I followed: convert the PySpark DataFrame to a pandas DataFrame, use pandas' built-in forward fill to replace null values with the previously known good value, and then convert it back to a PySpark DataFrame. Here is the code:

d = [{'session': 1, 'ts': 1}, {'session': 1, 'ts': 2, 'id': 109}, {'session': 1, 'ts': 3}, {'session': 1, 'ts': 4, 'id': 110}, {'session': 1, 'ts': 5}, {'session': 1, 'ts': 6}, {'session': 1, 'ts': 7, 'id': 110}, {'session': 1, 'ts': 8}, {'session': 1, 'ts': 9}, {'session': 1, 'ts': 10}]
dt = spark.createDataFrame(d)

import pandas as pd

# Use Arrow for a faster Spark <-> pandas conversion
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Collect to pandas, forward-fill the id column, and convert back to Spark
psdf = dt.select("*").toPandas()
psdf["id"] = psdf["id"].ffill()
dt = spark.createDataFrame(psdf)
dt.show()
    using toPandas() might not scale for a very large dataset. you might bump into driver memory issues. I did a collect and populated the rows but the approach failed when load tested on a larger dataset. the above marked as answer works perfect.. – sunil kancharlapalli Dec 01 '22 at 22:34
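
If the pandas route is appealing but the data is too large to collect to the driver, one possible middle ground (only a sketch, not from any of the answers above) is to forward-fill each session separately with a grouped pandas UDF via applyInPandas, available in Spark 3.x:

import pandas as pd

def ffill_id(pdf: pd.DataFrame) -> pd.DataFrame:
    # Forward-fill within a single session; 'Int64' keeps any leading nulls nullable
    pdf = pdf.sort_values('ts')
    pdf['id'] = pdf['id'].ffill().astype('Int64')
    return pdf

# Each session is processed as its own pandas DataFrame on the executors,
# so the full dataset never has to fit in driver memory
filled = dt.groupBy('session').applyInPandas(ffill_id, schema=dt.schema)
filled.orderBy('ts').show()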