
I'm working with some deeply nested data in a PySpark dataframe. While trying to flatten the structure into rows and columns, I noticed that when I call withColumn, any row that contains null in the source column is dropped from my result dataframe. Instead, I would like to find a way to retain the row and have null in the resulting column.

A sample dataframe to work with:

from pyspark.sql.functions import explode, first, col, monotonically_increasing_id
from pyspark.sql import Row

df = spark.createDataFrame([
  Row(dataCells=[Row(posx=0, posy=1, posz=.5, value=1.5, shape=[Row(_type='square', _len=1)]), 
                 Row(posx=1, posy=3, posz=.5, value=4.5, shape=[]), 
                 Row(posx=2, posy=5, posz=.5, value=7.5, shape=[Row(_type='circle', _len=.5)])
    ])
])

I also have a function I use to flatten structs:

def flatten_struct_cols(df):
    # split the columns into non-struct columns and struct columns
    flat_cols = [column[0] for column in df.dtypes if 'struct' not in column[1][:6]]
    struct_columns = [column[0] for column in df.dtypes if 'struct' in column[1][:6]]

    # keep the flat columns and promote every field of each struct to its own column
    df = df.select(flat_cols +
                   [col(sc + '.' + c).alias(sc + '_' + c)
                    for sc in struct_columns
                    for c in df.select(sc + '.*').columns])

    return df

And the schema looks like this:

df.printSchema()

root
 |-- dataCells: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- posx: long (nullable = true)
 |    |    |-- posy: long (nullable = true)
 |    |    |-- posz: double (nullable = true)
 |    |    |-- shape: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- _len: long (nullable = true)
 |    |    |    |    |-- _type: string (nullable = true)
 |    |    |-- value: double (nullable = true)

The starting dataframe:

df.show(3)

+--------------------+
|           dataCells|
+--------------------+
|[[0,1,0.5,Wrapped...|
+--------------------+

I start by exploding the array, since I want to turn this array of structs (each containing an array of structs) into rows and columns. I then flatten the struct fields into new columns.

df = df.withColumn('dataCells', explode(col('dataCells')))
df = flatten_struct_cols(df)
df.show(3)

And my data looks like:

+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|   [[1,square]]|            1.5|
|             1|             3|           0.5|             []|            4.5|
|             2|             5|           0.5|[[null,circle]]|            7.5|
+--------------+--------------+--------------+---------------+---------------+

All is well and as expected until I try to explode the dataCells_shape column which has an empty/null value.

df = df.withColumn('dataCells_shape', explode(col('dataCells_shape')))
df.show(3)

Which drops the second row out of the dataframe:

+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|     [1,square]|            1.5|
|             2|             5|           0.5|  [null,circle]|            7.5|
+--------------+--------------+--------------+---------------+---------------+

Instead, I would like to keep the row, retaining the empty value for that column as well as all of the values in the other columns. I've tried creating a new column instead of overwriting the old one when doing the .withColumn explode, and I get the same result either way.

I also tried creating a UDF that performs the explode only if the row is not empty/null, but I ran into JVM errors handling null:

from pyspark.sql.functions import udf
from pyspark.sql.types import NullType, StructType

def explode_if_not_null(trow):
    if trow:
        return explode(trow)
    else:
        return NullType

func_udf = udf(explode_if_not_null, StructType())
df = df.withColumn('dataCells_shape_test', func_udf(df['dataCells_shape']))
df.show(3)

AttributeError: 'NoneType' object has no attribute '_jvm'

Can anybody suggest a way for me to explode or flatten ArrayType columns without losing rows when the column is null?

I am using PySpark 2.2.0

Edit:

Following the link provided as a possible duplicate, I tried to implement the suggested .isNotNull().otherwise() solution, providing the struct schema to .otherwise(), but the row still drops out of the result set.

df.withColumn("dataCells_shape_test", explode(when(col("dataCells_shape").isNotNull(), col("dataCells_shape"))
                                              .otherwise(array(lit(None).cast(df.select(col("dataCells_shape").getItem(0))
                                                                                                              .dtypes[0][1])
                                                              )
                                                        )
                                             )
             ).show()

+--------------+--------------+--------------+---------------+---------------+--------------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|dataCells_shape_test|
+--------------+--------------+--------------+---------------+---------------+--------------------+
|             0|             1|           0.5|   [[1,square]]|            1.5|          [1,square]|
|             2|             5|           0.5|[[null,circle]]|            7.5|       [null,circle]|
+--------------+--------------+--------------+---------------+---------------+--------------------+
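Presumably the row still drops because its dataCells_shape value is an empty array rather than a true null, so the isNotNull() check hands the empty array straight to explode, which then emits no rows for it. A rough, untested sketch of how the same when/otherwise approach might also catch empty arrays using size():

from pyspark.sql.functions import when, size, lit, array, explode, col

# untested sketch: treat empty arrays the same way as nulls before exploding
element_type = df.select(col("dataCells_shape").getItem(0)).dtypes[0][1]
df.withColumn(
    "dataCells_shape_test",
    explode(
        when(col("dataCells_shape").isNull() | (size(col("dataCells_shape")) == 0),
             array(lit(None).cast(element_type)))  # one-element array holding only null
        .otherwise(col("dataCells_shape"))
    )
).show()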
  • instead of using a udf can you try using spark's inbuilt `when`? it'll go something like, `df = df.withColumn('dataCells', when(col('dataCells').isNotNull),explode(col('dataCells')))` – Chitral Verma Oct 10 '18 at 19:28
  • I'll give that a try and report back. Thanks for the idea. I also just noticed in Spark 2.3 and higher there is an `explode_outer` which would probably do what I need, but I'm stuck on 2.2.x for now. – Alexander Oct 10 '18 at 19:48
  • Possible duplicate of [Spark sql how to explode without losing null values](https://stackoverflow.com/questions/39739072/spark-sql-how-to-explode-without-losing-null-values). Though that post is not for pyspark, the technique is [not language specific](https://stackoverflow.com/questions/39739072/spark-sql-how-to-explode-without-losing-null-values#comment73803221_39739218). – pault Oct 10 '18 at 19:50
  • I get the following error when trying the above and linked suggestions. `TypeError: condition should be a Column`. Written like `df.withColumn("dataCells", explode( when(col("dataCells").isNotNull, col("dataCells")) .otherwise(None)))` – Alexander Oct 10 '18 at 20:06
  • 1
    @Alexander you are missing the parentheses at the end of `isNotNull()` – pault Oct 10 '18 at 20:11
  • 1
    @Alexander I can't test this, but [`explode_outer`](http://spark.apache.org/docs/2.2.0/api/java/org/apache/spark/sql/functions.html#explode_outer-org.apache.spark.sql.Column-) is a part of spark version 2.2 (but not available in pyspark until 2.3)- can you try the following: 1) `explode_outer = sc._jvm.org.apache.spark.sql.functions.explode_outer` and then `df.withColumn("dataCells", explode_outer("dataCells")).show()` or 2) `df.createOrReplaceTempView("myTable")` and then `spark.sql("select *, explode_outer(dataCells) from myTable").show()` – pault Oct 10 '18 at 20:14
  • Thanks @pault I'm not able to run this right now, but I'll let you know how it goes when I sit down and make the update. I like the idea of pulling in the explode_outer definition, but I'll give both a try just to have options for the future. – Alexander Oct 10 '18 at 21:06
  • 1
    @Alexander related post on how to pull in java/scala functions: [Spark: How to map Python with Scala or Java User Defined Functions?](https://stackoverflow.com/questions/33233737/spark-how-to-map-python-with-scala-or-java-user-defined-functions) – pault Oct 10 '18 at 21:17
  • Let us [continue this discussion in chat](https://chat.stackoverflow.com/rooms/181685/discussion-between-alexander-and-pault). – Alexander Oct 11 '18 at 15:07

2 Answers


Thanks to pault for pointing me to the linked question about exploding without losing null values and to the question about mapping Python to Java functions. I was able to get a working solution with:

from pyspark.sql.column import Column, _to_java_column

def explode_outer(col):
    # wrap the JVM explode_outer function (available in Spark 2.2) as a Python Column
    _explode_outer = sc._jvm.org.apache.spark.sql.functions.explode_outer
    return Column(_explode_outer(_to_java_column(col)))

new_df = df.withColumn("dataCells_shape", explode_outer(col("dataCells_shape")))

+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|     [1,square]|            1.5|
|             1|             3|           0.5|           null|            4.5|
|             2|             5|           0.5|  [null,circle]|            7.5|
+--------------+--------------+--------------+---------------+---------------+

root
 |-- dataCells_posx: long (nullable = true)
 |-- dataCells_posy: long (nullable = true)
 |-- dataCells_posz: double (nullable = true)
 |-- dataCells_shape: struct (nullable = true)
 |    |-- _len: long (nullable = true)
 |    |-- _type: string (nullable = true)
 |-- dataCells_value: double (nullable = true)

It's important to note that this works for PySpark version 2.2, because explode_outer is defined in Spark 2.2 (but for some reason the API wrapper was not implemented in PySpark until version 2.3). This solution simply creates a wrapper for the already implemented Java function.
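As pault noted in the comments, the same built-in function should also be reachable from Spark SQL on 2.2 without writing a wrapper, by registering a temp view first. An untested sketch based on that comment:

df.createOrReplaceTempView("myTable")
spark.sql("SELECT *, explode_outer(dataCells_shape) FROM myTable").show()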

– Alexander

For a structure this complex, it would be easier to write a map function and use it in the flatMap method of the RDD interface. As a result you get a new flattened RDD, and then you have to create a dataframe again by applying the new schema.

from pyspark.sql.types import StructType, StructField, StringType

def flat_arr(row):
    rows = []
    # apply some logic to fill the rows list with more "rows"
    return rows

rdd = df.rdd.flatMap(flat_arr)
schema = StructType([
    StructField('field1', StringType()),
    # define more fields
])
df = df.sql_ctx.createDataFrame(rdd, schema)
df.show()

This solution looks a bit longer than applying withColumn, but it could be a first iteration of your solution, and from it you can see how to convert it into withColumn statements. In my opinion, though, a map function is appropriate here just to keep things clear.
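For the sample data in the question, a minimal version of that map function, applied to the original nested df, might look like the sketch below. The output field names and the choice to pad an empty shape array with a single None are assumptions made for illustration, not part of the answer above.

from pyspark.sql.types import StructType, StructField, LongType, DoubleType, StringType

def flat_arr(row):
    # emit one output row per (dataCell, shape) pair; pad an empty shape array
    # with a single None so the record is not lost
    rows = []
    for cell in row.dataCells:
        shapes = cell.shape if cell.shape else [None]
        for shape in shapes:
            rows.append((cell.posx, cell.posy, cell.posz, cell.value,
                         shape['_len'] if shape else None,
                         shape['_type'] if shape else None))
    return rows

schema = StructType([
    StructField('posx', LongType()),
    StructField('posy', LongType()),
    StructField('posz', DoubleType()),
    StructField('value', DoubleType()),
    StructField('shape_len', LongType()),
    StructField('shape_type', StringType()),
])

flat_df = df.sql_ctx.createDataFrame(df.rdd.flatMap(flat_arr), schema)
flat_df.show()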

– iurii_n