
I have a DataFrame with a single column which is an array of structs

df.printSchema()
root
 |-- dataCells: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- label: string (nullable = true)
 |    |    |-- value: string (nullable = true)

Some sample data might look like this:

df.first()
Row(dataCells=[Row(label="firstName", value="John"), Row(label="lastName", value="Doe"), Row(label="Date", value="1/29/2018")])

I'm trying to figure out how to reformat this DataFrame by turning each struct into a named column. I want to have a DataFrame like this:

------------------------------------
| firstName | lastName | Date      |
------------------------------------
| John      | Doe      | 1/29/2018 |
| ....      | ...      | ...       |

I've tried everything I can think of but haven't figured this out.


2 Answers


Just explode and select *

from pyspark.sql import Row
from pyspark.sql.functions import explode, first, col, monotonically_increasing_id

df = spark.createDataFrame([
  Row(dataCells=[Row(label="firstName", value="John"), Row(label="lastName", value="Doe"), Row(label="Date", value="1/29/2018")])
])

long = (df
   .withColumn("id", monotonically_increasing_id())
   .select("id", explode("dataCells").alias("col"))
   .select("id", "col.*"))

and pivot:

long.groupBy("id").pivot("label").agg(first("value")).show()
# +-----------+---------+---------+--------+                                      
# |         id|     Date|firstName|lastName|
# +-----------+---------+---------+--------+
# |25769803776|1/29/2018|     John|     Doe|
# +-----------+---------+---------+--------+
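
If the set of labels is known up front, pivot also accepts an explicit value list, which skips the extra pass that collects the distinct labels (and sidesteps the 10000-value limit mentioned in the comments below); a minimal sketch:

long.groupBy("id").pivot("label", ["firstName", "lastName", "Date"]).agg(first("value")).show()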

You can also skip explode and pivot entirely by converting the array to a map with a UDF:

from pyspark.sql.functions import col, udf

@udf("map<string,string>")
def as_map(x):
    return dict(x)

cols = [col("dataCells")[c].alias(c) for c in ["Date", "firstName", "lastName"]]
df.select(as_map("dataCells").alias("dataCells")).select(cols).show()

# +---------+---------+--------+
# |     Date|firstName|lastName|
# +---------+---------+--------+
# |1/29/2018|     John|     Doe|
# +---------+---------+--------+
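
On Spark 2.4+ the UDF can be replaced with the built-in map_from_entries, which builds a map from an array of two-field structs (a sketch, assuming the first struct field is taken as the key):

from pyspark.sql.functions import col, map_from_entries

cols = [col("dataCells")[c].alias(c) for c in ["Date", "firstName", "lastName"]]
df.select(map_from_entries("dataCells").alias("dataCells")).select(cols).show()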


  • Great answer. I ran into the following error "The pivot column label has more than 10000 distinct values", which gave me some concerns about the performance of this approach in the long run. – Burke Jan 29 '18 at 23:15
  • It is a concern. Spark doesn't handle wide data well. In that case I'd recommend the second solution - both `explode` and `pivot` are on the expensive side. – Alper t. Turker Jan 29 '18 at 23:17
  • This works perfectly, with a slight caveat: if a record has an empty array and you explode it, the row is eliminated altogether, which might be a problem if you want to preserve empties. I suggest using explode_outer instead; after pivoting, the result will have a null column, which you can subsequently drop. – Shivam Jul 12 '23 at 14:09
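
A sketch of the explode_outer variant suggested in the last comment (assuming Spark 2.2+, where explode_outer is available):

from pyspark.sql.functions import explode_outer, monotonically_increasing_id

long = (df
   .withColumn("id", monotonically_increasing_id())
   .select("id", explode_outer("dataCells").alias("col"))
   .select("id", "col.*"))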

An alternative approach I tried, without a UDF:

>>> df.show()
+--------------------+
|           dataCells|
+--------------------+
|[[firstName,John]...|
+--------------------+

>>> from pyspark.sql import functions as F

## size of array with maximum length in column 
>>> arr_len = df.select(F.max(F.size('dataCells')).alias('len')).first().len

## get values from struct 
>>> df1 = df.select([df.dataCells[i].value for i in range(arr_len)])
>>> df1.show()
+------------------+------------------+------------------+
|dataCells[0].value|dataCells[1].value|dataCells[2].value|
+------------------+------------------+------------------+
|              John|               Doe|         1/29/2018|
+------------------+------------------+------------------+

>>> oldcols = df1.columns

## get the labels from struct
>>> cols = df.select([df.dataCells[i].label.alias('col_%s'%i) for i in range(arr_len)]).dropna().first()
>>> cols
Row(dataCells[0].label=u'firstName', dataCells[1].label=u'lastName', dataCells[2].label=u'Date')
>>> newcols = [cols[i] for i in range(arr_len)]
>>> newcols
[u'firstName', u'lastName', u'Date']

## use the labels to rename the columns (on Python 3, reduce must be imported)
>>> from functools import reduce
>>> df2 = reduce(lambda data, idx: data.withColumnRenamed(oldcols[idx], newcols[idx]), range(len(oldcols)), df1)
>>> df2.show()
+---------+--------+---------+
|firstName|lastName|     Date|
+---------+--------+---------+
|     John|     Doe|1/29/2018|
+---------+--------+---------+
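
A shorter equivalent of the rename step, assuming the labels appear in the same order in every row, relies on DataFrame.toDF accepting the new column names positionally:

>>> df2 = df1.toDF(*newcols)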