I ended up going with the following function, which recursively "unwraps" nested structs. It keeps descending into `StructType` fields while leaving all other fields intact. This avoids writing a very long `df.select(...)` statement when a struct has many fields. Here's the code:
# Takes in a StructType schema object and returns a column selector that flattens the struct.
def flatten_struct(schema, prefix=""):
    """Return a list of Columns that selects every leaf field of ``schema``.

    Nested ``StructType`` fields are walked recursively. Every non-struct
    field becomes one column, aliased to its dotted path (e.g. ``"b.bar"``).
    Fields that are not structs are kept as-is.

    Args:
        schema: a ``StructType`` (an iterable of ``StructField``s), e.g.
            ``df.schema``.
        prefix: a column-path prefix, already in ``col()`` syntax (e.g.
            ``"r."``). It is prepended to every resolved path and to every
            alias.

    Returns:
        A list of ``Column`` objects suitable for ``df.select(...)``.
    """

    def _quote(name):
        # Backtick-quote one identifier so that a literal "." inside a field
        # name is not parsed as struct nesting. An embedded backtick is
        # escaped by doubling it, per Spark's identifier rules.
        return "`" + name.replace("`", "``") + "`"

    def _walk(struct, alias_prefix, path_prefix):
        # alias_prefix: human-readable dotted name used for the output alias.
        # path_prefix: quoted path actually handed to col() for resolution.
        columns = []
        for field in struct:
            alias = alias_prefix + field.name
            path = path_prefix + _quote(field.name)
            if isinstance(field.dataType, StructType):
                columns += _walk(field.dataType, alias + ".", path + ".")
            else:
                columns.append(col(path).alias(alias))
        return columns

    # The caller-supplied prefix is used verbatim for both the alias and the
    # path, matching how it was used before quoting was introduced.
    return _walk(schema, prefix, prefix)
# A single row whose column "r" holds a struct that nests another struct.
sample_rows = [Row(r=Row(a=1, b=Row(foo="b", bar="12")))]
df = sc.parallelize(sample_rows).toDF()
df.show()
+----------+
| r|
+----------+
|[1,[12,b]]|
+----------+
# Promote r's fields to top level, then flatten whatever structs remain.
df_expanded = df.select("r.*")
leaf_columns = flatten_struct(df_expanded.schema)
df_flattened = df_expanded.select(leaf_columns)
df_flattened.show()
+---+-----+-----+
| a|b.bar|b.foo|
+---+-----+-----+
| 1| 12| b|
+---+-----+-----+