
I am trying to fill missing (NaN/null) values with the column mean using PySpark. Below is the code I am using, followed by the error it raises:

from pyspark.sql.functions import avg


def fill_with_mean(df_1, exclude=set()):
    stats = df_1.agg(*(avg(c).alias(c) for c in df_1.columns if c not in exclude))
    return df_1.na.fill(stats.first().asDict())

res = fill_with_mean(df_1, ["MinTemp", "MaxTemp", "Evaporation", "Sunshine"])
res.show()

Error:

Py4JJavaError                             Traceback (most recent call last)
<ipython-input-35-42f4d984f022> in <module>()
      3     stats = df_1.agg(*(avg(c).alias(c) for c in df_1.columns if c not in exclude))
      4     return df_1.na.fill(stats.first().asDict())
----> 5 res = fill_with_mean(df_1, ["MinTemp", "MaxTemp", "Evaporation", "Sunshine"])
      6 res.show()

5 frames

/usr/local/lib/python3.7/dist-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o376.fill.
: java.lang.NullPointerException
at org.apache.spark.sql.DataFrameNaFunctions.$anonfun$fillMap$1(DataFrameNaFunctions.scala:418)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at org.apache.spark.sql.DataFrameNaFunctions.fillMap(DataFrameNaFunctions.scala:407)
at org.apache.spark.sql.DataFrameNaFunctions.fill(DataFrameNaFunctions.scala:232)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)

Can you tell me where I am going wrong? Is there an alternative way to fill missing values with the mean?

This is what my dataframe looks like: (screenshot not included)

I want the nulls replaced with mean values. Also, Evaporation and Sunshine are not entirely null; they contain other values too.

The dataset is a CSV file:

from pyspark.sql.functions import *
import pyspark
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","
df_1= spark.read.format("csv").option("header","true").load('/content/weatherAUS.csv')
df_1.show()

Source: https://www.kaggle.com/jsphyg/weather-dataset-rattle-package

Mykola Zotko
John
  • Hi Steven, the dataset is a CSV file. I updated the code and mentioned its source. – John Nov 14 '21 at 17:28
  • I guess you want to replace the "None" values of the numeric columns. Basically, `Location` should not be replaced even though you did not include it in your `exclude` set, right? – Steven Nov 14 '21 at 21:57
  • Ya pretty much! – John Nov 14 '21 at 22:49
  • You have 50 cities with data covering almost 10 years, and you simply want to replace missing values with the average over the whole dataframe? Does that really make sense? – Steven Nov 15 '21 at 09:09

2 Answers

Based on your input data, I create my dataframe:

from pyspark.sql import functions as F, Window

df = spark.read.csv("./weatherAUS.csv", header=True, inferSchema=True, nullValue="NA")

Then I process the whole dataframe, excluding the columns you mentioned plus the columns that cannot be replaced (Date and Location):

exclude = ["date", "location"] + ["mintemp", "maxtemp", "evaporation", "sunshine"]


df2 = df.select(
    *(
        F.coalesce(F.col(col), F.avg(col).over(Window.orderBy(F.lit(1)))).alias(col)
        if col.lower() not in exclude
        else F.col(col)
        for col in df.columns
    )
)

df2.show(5)
+-------------------+----------+-------+-------+--------+-----------+--------+-----------+-------------+----------+----------+------------+------------+-----------+-----------+-----------+-----------+--------+--------+-------+-------+---------+------------+
|               Date|  Location|MinTemp|MaxTemp|Rainfall|Evaporation|Sunshine|WindGustDir|WindGustSpeed|WindDir9am|WindDir3pm|WindSpeed9am|WindSpeed3pm|Humidity9am|Humidity3pm|Pressure9am|Pressure3pm|Cloud9am|Cloud3pm|Temp9am|Temp3pm|RainToday|RainTomorrow|
+-------------------+----------+-------+-------+--------+-----------+--------+-----------+-------------+----------+----------+------------+------------+-----------+-----------+-----------+-----------+--------+--------+-------+-------+---------+------------+
|2012-07-02 22:00:00|Townsville|   12.4|   23.3|     0.0|        6.0|    10.8|        SSW|         33.0|        SE|         S|         7.0|        20.0|       34.0|       28.0|     1019.5|     1015.5|     1.0|     2.0|   17.5|   23.0|       No|          No|
|2012-07-03 22:00:00|Townsville|    9.1|   21.7|     0.0|        5.0|    10.9|         SE|         39.0|       SSW|       SSE|        17.0|        20.0|       26.0|       14.0|     1021.7|     1018.4|     1.0|     0.0|   16.4|   21.2|       No|          No|
|2012-07-04 22:00:00|Townsville|    8.2|   23.4|     0.0|        5.2|    10.6|        SSW|         30.0|       SSW|        NE|        22.0|        13.0|       34.0|       40.0|     1021.7|     1018.5|     2.0|     2.0|   17.1|   22.3|       No|          No|
|2012-07-05 22:00:00|Townsville|   10.5|   24.5|     0.0|        6.0|    10.2|          E|         39.0|       SSW|        SE|        11.0|        17.0|       48.0|       31.0|     1021.2|     1017.2|     1.0|     2.0|   17.9|   23.8|       No|          No|
|2012-07-06 22:00:00|Townsville|   17.7|   24.1|     0.0|        6.8|     0.5|         SE|         54.0|        SE|       ESE|        19.0|        31.0|       69.0|       58.0|     1019.2|     1017.0|     8.0|     7.0|   20.1|   23.2|       No|          No|
+-------------------+----------+-------+-------+--------+-----------+--------+-----------+-------------+----------+----------+------------+------------+-----------+-----------+-----------+-----------+--------+--------+-------+-------+---------+------------+
only showing top 5 rows
Steven
  • Hi Steven, thank you for your valuable input. I cannot see any column values change with either of the code snippets above. I have given a sample dataframe in the question for your reference. – John Nov 14 '21 at 01:25
  • Please keep in mind that using withColumn in a for loop can cause errors. You can use select to get the same result. – Bibzon Nov 14 '21 at 12:48
  • @Bibzon, if I go with your approach, it returns a dataframe with Yes/No values. – John Nov 14 '21 at 16:34
  • @John I updated my answer with a more compact version. I tested it and it works fine, but I still think that replacing missing values with an average over the whole dataframe is absurd. What about an average for the city, over the X previous and/or following days? – Steven Nov 15 '21 at 09:18
  • Yes Steven, I agree with you. I am not trying to replace all column values with the mean, only a few of them. I wanted to know how to impute the mean for the missing values. – John Nov 15 '21 at 13:36

You can use the Imputer estimator from pyspark.ml.feature:

df = spark.createDataFrame([(1.0, float("nan")),
                            (2.0, float("nan")),
                            (float("nan"), 3.0),
                            (4.0, 4.0),
                            (5.0, 5.0)], ["a", "b"])
df.show()

+---+---+
|  a|  b|
+---+---+
|1.0|NaN|
|2.0|NaN|
|NaN|3.0|
|4.0|4.0|
|5.0|5.0|
+---+---+

import pyspark.ml.feature as MF

imputer = MF.Imputer(strategy='mean', inputCols=['a', 'b'], outputCols=['out_a', 'out_b'])
model = imputer.fit(df)
model.transform(df).show()

+---+---+-----+-----+
|  a|  b|out_a|out_b|
+---+---+-----+-----+
|1.0|NaN|  1.0|  4.0|
|2.0|NaN|  2.0|  4.0|
|NaN|3.0|  3.0|  3.0|
|4.0|4.0|  4.0|  4.0|
|5.0|5.0|  5.0|  5.0|
+---+---+-----+-----+

Or, using method chaining:

(MF.Imputer().
 setStrategy('mean').
 setInputCols(['a', 'b']).
 setOutputCols(['out_a', 'out_b']).
 fit(df).
 transform(df).
 show())
Mykola Zotko