
I have a dataframe like the following:

+--------------------+-----------------+--------------------+
|           column1  |     column2     |       column3      |
+--------------------+-----------------+--------------------+
|                  1 |             null|                null|
|                null|   A             |                  99|
|                null|             null|                null|
|                null|             null|                null|
|                null|   B             |                 100|
|                null|             null|                null|
|                null|             null|                null|
|                null|   C             |                 101|
|                null|             null|                null|
|                null|             null|                null|
+--------------------+-----------------+--------------------+

The following is what I expect:

+--------------------+-----------------+--------------------+
|           column1  |     column2     |       column3      |
+--------------------+-----------------+--------------------+
|                1   |             null|                null|
|                1   |         A       |                  99|
|                1   |         A       |                  99|
|                1   |         A       |                  99|
|                1   |         B       |                 100|
|                1   |         B       |                 100|
|                1   |         B       |                 100|
|                1   |         C       |                 101|
|                1   |         C       |                 101|
|                1   |         C       |                 101|
+--------------------+-----------------+--------------------+

I am new to PySpark and I am not sure how I can achieve this using PySpark functions.


1 Answer


You will need to use Window functions for this.

  1. Import the required functions and create some test data:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import desc, last

spark = SparkSession.builder.getOrCreate()

columns = ["col1", "col2", "col3"]
data = [(1, None, None),
        (None, "A", 99),
        (None, None, None),
        (None, None, None),
        (None, "B", 100),
        (None, None, None),
        (None, None, None),
        (None, "C", 101),
        (None, None, None),
        (None, None, None)]


df = spark.createDataFrame(data).toDF(*columns)
  2. Create a Window. In this example I do not specify a column for partitionBy, which means the whole DataFrame ends up in a single partition (not the best idea in production; ideally you have a suitable column to partition by, see the sketch after the window definition below). With rowsBetween I specify that all rows between the first row and the current row are considered.
window = Window.partitionBy().orderBy(desc("col1")).rowsBetween(Window.unboundedPreceding, 0)
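
If your real data does have a suitable grouping column, you can partition the window by it, so Spark does not have to pull everything into a single partition. A minimal sketch, assuming a hypothetical group_id column that is not part of the test data above:

window = (
    Window.partitionBy("group_id")  # hypothetical grouping column, not in the example data
    .orderBy(desc("col1"))
    .rowsBetween(Window.unboundedPreceding, 0)
)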
  3. The last function takes the last value it has seen so far, with nulls being ignored thanks to ignorenulls=True.
df = (
    df.withColumn("col1", last("col1", ignorenulls=True).over(window))
    .withColumn("col2", last("col2", True).over(window))
    .withColumn("col3", last("col3", True).over(window))
    )

df.show()

Which will result in:

+----+----+----+                                                                
|col1|col2|col3|
+----+----+----+
|   1|null|null|
|   1|   A|  99|
|   1|   A|  99|
|   1|   A|  99|
|   1|   B| 100|
|   1|   B| 100|
|   1|   B| 100|
|   1|   C| 101|
|   1|   C| 101|
|   1|   C| 101|
+----+----+----+
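
If you have more columns to fill, the repeated withColumn calls can also be written as a loop over the column names (a small refactor of the same logic, reusing the columns list and window defined above):

for c in columns:
    # forward-fill each column with the last non-null value seen so far
    df = df.withColumn(c, last(c, ignorenulls=True).over(window))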

Reference:
Fill in null with previously known good value with pyspark
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Window.html
https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.functions.last.html
https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.Window.rowsBetween.html
