Using the method presented in the accepted answer, I wrote a function that recursively unnests a DataFrame. It flattens structs into top-level columns and also recurses into arrays of structs, rebuilding their elements with Spark SQL's transform higher-order function (available since Spark 2.4):
from pyspark.sql.types import ArrayType, StructType

def flatten(df, sentinel="x"):
    def _gen_flatten_expr(schema, indent, parents, last, transform=False):
        def handle(field, last):
            path = parents + (field.name,)
            # Alias each flattened column with its full path joined by "_";
            # inside a transform() lambda, drop the sentinel from the name.
            alias = (
                " as "
                + "_".join(path[1:] if transform else path)
                + ("," if not last else "")
            )
            if isinstance(field.dataType, StructType):
                # Structs are flattened in place: recurse with the extended path.
                yield from _gen_flatten_expr(
                    field.dataType, indent, path, last, transform
                )
            elif (
                isinstance(field.dataType, ArrayType)
                and isinstance(field.dataType.elementType, StructType)
            ):
                # An array of structs can't be flattened away, but its element
                # structs can be rebuilt (and flattened) with transform().
                yield indent, "transform("
                yield indent + 1, ".".join(path) + ","
                yield indent + 1, sentinel + " -> struct("
                yield from _gen_flatten_expr(
                    field.dataType.elementType,
                    indent + 2,
                    (sentinel,),
                    True,
                    True,
                )
                yield indent + 1, ")"
                yield indent, ")" + alias
            else:
                # Leaf field: select it directly under its flattened alias.
                yield indent, ".".join(path) + alias

        try:
            *fields, last_field = schema.fields
        except ValueError:
            pass  # empty struct: nothing to yield
        else:
            for field in fields:
                yield from handle(field, False)
            yield from handle(last_field, last)

    # Render the generated (indent, text) pairs as a pretty-printed expression.
    lines = []
    for indent, line in _gen_flatten_expr(df.schema, 0, (), True):
        spaces = " " * 4 * indent
        lines.append(spaces + line)

    # Wrap everything in one struct so a single selectExpr covers the whole
    # schema, then unpack that struct into top-level columns.
    expr = "struct(" + "\n".join(lines) + ") as " + sentinel
    return df.selectExpr(expr).select(sentinel + ".*")
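
To illustrate, here is a minimal sketch with a made-up schema (the column names s, arr, a, b, c, d are just for demonstration). flatten lifts every struct field to a top-level column, while the array column keeps its array shape with its element structs flattened inside:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # A small nested DataFrame built with Spark SQL expressions:
    # one plain column, one struct column, one array-of-structs column.
    df = spark.range(1).selectExpr(
        "id",
        "named_struct('a', 1, 'b', 2) as s",
        "array(named_struct('c', 3, 'd', 4)) as arr",
    )

    print(flatten(df).columns)
    # ['id', 's_a', 's_b', 'arr']

Note that arr stays an array: flattening it into top-level columns would require exploding, which changes the row count, so only the structs inside its elements are flattened.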