I'm struggling to get hot deck imputation to work in PySpark. On a Pandas dataframe this would be trivial, but my data is stored in a PySpark dataframe.

After quite some time of research, this is the code I came up with:

from pyspark.sql.window import Window
from pyspark.sql.functions import when, lag

def impute_hot_deck(df, col, ref_col):
    window = Window.orderBy(ref_col)
    df = df.withColumn(col, when(df[col] == 'null', lag(col).over(window)).otherwise(df[col]))
    return df

Here, df is a PySpark dataframe, col is the column to impute, and ref_col is the column to sort by. Every example I found, as well as the PySpark documentation, suggests that this code should replace all 'null' values with the value found in the row above, but it simply doesn't do anything.
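To illustrate the behaviour I expect from lag, here is a minimal, self-contained sketch with a made-up two-column dataframe (the demo data and names are just for illustration):

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
demo = spark.createDataFrame([(1, 10.0), (2, None), (3, 30.0)], ["age", "bmi"])
window = Window.orderBy("age")
# lag("bmi") looks one row back in the age ordering; the first row gets null
demo.withColumn("prev_bmi", F.lag("bmi").over(window)).show()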

I am trying to impute the column bmi, sorting by age. Running this function on my data, i.e. df_imputed = impute_hot_deck(df, "bmi", "age"), doesn't raise an error; the missing values are just not replaced.

Using df_imputed.sort("age").show() reveals this sorted dataframe:

id    | gender | age | avg_glucose_level | bmi
47876 | Male   | 1   | 89.3              | 21.4
57372 | Male   | 1   | 123.21            | 15.1
12687 | Male   | 1   | 101.31            | 18.3
46035 | Male   | 1   | 84.85             | 20.3
54985 | Female | 1   | 199.83            | 24.5
54505 | Female | 1   | 70.81             | 19
61279 | Female | 1   | 123.76            | 21.4
49464 | Female | 1   | 64.13             | 20.1
43070 | Female | 1   | 118.22            | 20.2
21480 | Female | 1   | 91.03             | 19
70994 | Male   | 1   | 134.76            | 16.8
7127  | Female | 1   | 105.89            | 19
40270 | Female | 1   | 57.46             | 20.4
25954 | Female | 1   | 104.64            | 20.4
70895 | Male   | 1   | 87.27             | 18
45861 | Female | 1   | 91.5              | 18.6
31615 | Female | 1   | 91.35             | 13.6
16946 | Male   | 1   | 78.53             | 20.3
8702  | Female | 1   | 83.57             | 18
44091 | Female | 1   | 74.34             | 18
47930 | Female | 1   | 118.06            | 15.9
45860 | Female | 1   | 128.18            | 16.8
2174  | Female | 1   | 96.24             | null
...

Therefore the problem isn't that all the missing values come from the youngest persons, which was my first idea.

My guess would be that I'm not using lag and/or window correctly?

Any help would be much appreciated!

  • instead of df[col] == 'null' try df[col].isNull() – Anna K. Nov 19 '21 at 20:37
  • That worked, thanks! Although not all values are replaced: when two nulls follow each other, only the first one is replaced. Any idea how to fix this (other than looping until no more values change)? – Togepitsch Nov 20 '21 at 00:03
  • This solution seems like a smarter way to do this: https://stackoverflow.com/questions/36019847/pyspark-forward-fill-with-last-observation-for-a-dataframe – Togepitsch Nov 20 '21 at 10:10
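For reference, the approach in the question linked above sidesteps the consecutive-null problem: instead of lag, it uses last(..., ignorenulls=True) over a running window, so every null is filled with the most recent non-null value in a single pass. A minimal sketch, assuming the same df, bmi and age columns as above (the function name impute_hot_deck_ffill is just illustrative):

from pyspark.sql import Window
from pyspark.sql import functions as F

def impute_hot_deck_ffill(df, col, ref_col):
    # Running window: every row from the start of the ordering up to the
    # current row. Without partitionBy, all rows land in a single partition,
    # so Spark will warn about performance on large data.
    window = (Window.orderBy(ref_col)
                    .rowsBetween(Window.unboundedPreceding, Window.currentRow))
    # last(..., ignorenulls=True) returns the most recent non-null value
    # in the window, so runs of consecutive nulls are all filled at once.
    return df.withColumn(col, F.last(col, ignorenulls=True).over(window))

df_imputed = impute_hot_deck_ffill(df, "bmi", "age")

Note that this only fills nulls that have a non-null value somewhere before them in the ordering; a null in the very first row would stay null.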

0 Answers