I'm working with some deeply nested data in a PySpark dataframe. As I'm trying to flatten the structure into rows and columns, I noticed that when I call withColumn, if the row contains null in the source column, then that row is dropped from my result dataframe. Instead I would like to find a way to retain the row and have null in the resulting column.
A sample dataframe to work with:
from pyspark.sql.functions import explode, first, col, monotonically_increasing_id
from pyspark.sql import Row
df = spark.createDataFrame([
    Row(dataCells=[
        Row(posx=0, posy=1, posz=.5, value=1.5, shape=[Row(_type='square', _len=1)]),
        Row(posx=1, posy=3, posz=.5, value=4.5, shape=[]),
        Row(posx=2, posy=5, posz=.5, value=7.5, shape=[Row(_type='circle', _len=.5)])
    ])
])
I also have a function I use to flatten structs:
def flatten_struct_cols(df):
    flat_cols = [column[0] for column in df.dtypes if 'struct' not in column[1][:6]]
    struct_columns = [column[0] for column in df.dtypes if 'struct' in column[1][:6]]
    df = df.select(flat_cols +
                   [col(sc + '.' + c).alias(sc + '_' + c)
                    for sc in struct_columns
                    for c in df.select(sc + '.*').columns])
    return df
And the schema looks like this:
df.printSchema()
root
 |-- dataCells: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- posx: long (nullable = true)
 |    |    |-- posy: long (nullable = true)
 |    |    |-- posz: double (nullable = true)
 |    |    |-- shape: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- _len: long (nullable = true)
 |    |    |    |    |-- _type: string (nullable = true)
 |    |    |-- value: double (nullable = true)
The starting dataframe:
df.show(3)
+--------------------+
|           dataCells|
+--------------------+
|[[0,1,0.5,Wrapped...|
+--------------------+
I start by exploding the outer array, since I want to turn this array of structs (each of which contains its own array of structs) into rows and columns. I then flatten the struct fields into new columns.
df = df.withColumn('dataCells', explode(col('dataCells')))
df = flatten_struct_cols(df)
df.show(3)
And my data looks like:
+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|   [[1,square]]|            1.5|
|             1|             3|           0.5|             []|            4.5|
|             2|             5|           0.5|[[null,circle]]|            7.5|
+--------------+--------------+--------------+---------------+---------------+
All is well and as expected until I try to explode the dataCells_shape column, which has an empty/null value.
df = df.withColumn('dataCells_shape', explode(col('dataCells_shape')))
df.show(3)
Which drops the second row out of the dataframe:
+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|     [1,square]|            1.5|
|             2|             5|           0.5|  [null,circle]|            7.5|
+--------------+--------------+--------------+---------------+---------------+
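To make the behaviour concrete, here is a rough pure-Python model of what I am seeing. It is only my mental picture of the semantics, not Spark's implementation: explode emits one output row per array element, so an empty (or null) array emits nothing. The function names and the simplified rows are made up for illustration.

```python
def explode_rows(rows, column):
    """Model of explode: one output row per array element.

    An empty or None array contributes zero rows, so the source row vanishes.
    """
    out = []
    for row in rows:
        for element in (row[column] or []):
            out.append({**row, column: element})
    return out


def explode_outer_rows(rows, column):
    """Model of the behaviour I want: empty/None arrays still emit one row,
    with None in the exploded column."""
    out = []
    for row in rows:
        for element in (row[column] or [None]):
            out.append({**row, column: element})
    return out


# Simplified stand-in for the three flattened rows above.
rows = [
    {"posx": 0, "shape": ["square"]},
    {"posx": 1, "shape": []},
    {"posx": 2, "shape": ["circle"]},
]

print([r["posx"] for r in explode_rows(rows, "shape")])        # [0, 2]
print([r["posx"] for r in explode_outer_rows(rows, "shape")])  # [0, 1, 2]
```

The second function is the result I am after: the row with posx=1 survives and carries None in the shape column.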
Instead I would like to keep the row and retain the empty value for that column as well as all of the values in the other columns. I've tried creating a new column instead of overwriting the old one when doing the .withColumn explode and get the same result either way.
I also tried creating a UDF that performs the explode function if the row is not empty/null, but I have run into JVM errors handling null.
from pyspark.sql.functions import udf
from pyspark.sql.types import NullType, StructType
def explode_if_not_null(trow):
    if trow:
        return explode(trow)
    else:
        return NullType
func_udf = udf(explode_if_not_null, StructType())
df = df.withColumn('dataCells_shape_test', func_udf(df['dataCells_shape']))
df.show(3)
AttributeError: 'NoneType' object has no attribute '_jvm'
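As far as I can tell, this error comes from calling explode inside the UDF body. explode builds a Column expression through the driver's JVM gateway, but the UDF body runs in a Python worker where there is no active SparkContext, hence the NoneType. A UDF also maps one input value to one output value, so it cannot add or remove rows by itself. A minimal sketch of a variant that keeps the row-generating explode outside the UDF and only pads empty arrays inside it is below. The names pad_empty and explode_padded are hypothetical, and I am assuming the UDF's return type can simply reuse the column's own array type from the schema.

```python
def pad_empty(cells):
    """Plain Python, safe inside a UDF: turn an empty or None list into [None]."""
    return cells if cells else [None]


def explode_padded(df, column):
    """Pad empty arrays with a single null element, then explode on the driver side.

    pyspark is imported lazily so that pad_empty can be tried without Spark.
    """
    from pyspark.sql.functions import col, explode, udf

    # Reuse the column's own ArrayType as the UDF return type (assumption:
    # the array allows null elements, as in the schema printed above).
    array_type = next(f.dataType for f in df.schema.fields if f.name == column)
    pad = udf(pad_empty, array_type)

    # explode is applied to the UDF's result, never called inside the UDF.
    return df.withColumn(column, explode(pad(col(column))))
```

It would be called as df = explode_padded(df, 'dataCells_shape'). A Python UDF adds serialization overhead, so this is more of a diagnostic sketch than something I would want on large data.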
Can anybody suggest a way for me to explode or flatten ArrayType columns without losing rows when the column is null?
I am using PySpark 2.2.0.
Edit:
Following the link provided as a possible dupe, I tried to implement the suggested .isNotNull().otherwise() solution, providing the struct schema to .otherwise, but the row is still dropping out of the result set.
from pyspark.sql.functions import when, array, lit

df.withColumn(
    "dataCells_shape_test",
    explode(
        when(col("dataCells_shape").isNotNull(), col("dataCells_shape"))
        .otherwise(array(lit(None).cast(
            df.select(col("dataCells_shape").getItem(0)).dtypes[0][1])))
    )
).show()
+--------------+--------------+--------------+---------------+---------------+--------------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|dataCells_shape_test|
+--------------+--------------+--------------+---------------+---------------+--------------------+
|             0|             1|           0.5|   [[1,square]]|            1.5|          [1,square]|
|             2|             5|           0.5|[[null,circle]]|            7.5|       [null,circle]|
+--------------+--------------+--------------+---------------+---------------+--------------------+
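One likely reason this attempt still drops the row: the second row's array is empty ([]), not null, so isNotNull() is true for it and explode still receives an empty array. A sketch of the same idea that tests the array length with size() instead is below. The helper name explode_keeping_empty is hypothetical, and I am assuming the element type can be read from the dataframe's schema rather than spelled out by hand.

```python
def explode_keeping_empty(df, column):
    """Explode `column`, emitting one row with null for null or empty arrays."""
    # Imported lazily so this module loads even without a Spark installation.
    from pyspark.sql.functions import array, col, explode, lit, size, when

    # Read the array's element type from the schema instead of hard-coding it.
    array_type = next(f.dataType for f in df.schema.fields if f.name == column)
    null_element = lit(None).cast(array_type.elementType)

    source = col(column)
    # isNotNull() is true for [], so emptiness has to be checked with size().
    is_missing = source.isNull() | (size(source) == 0)
    padded = when(is_missing, array(null_element)).otherwise(source)
    return df.withColumn(column, explode(padded))
```

Usage would be df = explode_keeping_empty(df, 'dataCells_shape'). On PySpark 2.3 or later, pyspark.sql.functions.explode_outer covers this case directly, but it is not exposed in the Python API of 2.2.0.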