
I'm trying to expand a DataFrame column with a nested struct type (see below) into multiple columns. The Struct I'm working with looks something like {"foo": 3, "bar": {"baz": 2}}.

Ideally, I'd like to expand the above into two columns ("foo" and "bar.baz"). However, when I try .select("data.*") (where data is the Struct column), I only get the columns foo and bar, where bar is still a struct.

Is there a way to expand the Struct at both layers?
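
For reference, here's a minimal way to reproduce the kind of DataFrame I'm describing (a sketch, assuming an active SparkSession named spark):

from pyspark.sql import Row

# Nested Rows are inferred as nested struct types
df = spark.createDataFrame([Row(data=Row(foo=3, bar=Row(baz=2)))])
df.printSchema()
# root
#  |-- data: struct (nullable = true)
#  |    |-- foo: long (nullable = true)
#  |    |-- bar: struct (nullable = true)
#  |    |    |-- baz: long (nullable = true)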

Zz'Rot

2 Answers


You can select `data.bar.baz` as `bar.baz`:

df.show()
+-------+
|   data|
+-------+
|[3,[2]]|
+-------+

df.printSchema()
root
 |-- data: struct (nullable = false)
 |    |-- foo: long (nullable = true)
 |    |-- bar: struct (nullable = false)
 |    |    |-- baz: long (nullable = true)

In PySpark:

import pyspark.sql.functions as F
df.select(F.col("data.foo").alias("foo"), F.col("data.bar.baz").alias("bar.baz")).show()
+---+-------+
|foo|bar.baz|
+---+-------+
|  3|      2|
+---+-------+
Psidom
  • Thank you for the response! As a follow-up, is there a way to do a wildcard on `data.bar.*`, such that everything in `bar` gets automatically expanded into `bar.*` columns? – Zz'Rot Oct 24 '17 at 15:03
  • Not sure if you can keep the prefix, but you can select `data.bar.*`, which expands `bar.baz` as `baz` in this case (see the sketch after these comments). – Psidom Oct 24 '17 at 15:09
  • Thanks! I just posted the approach I'm settling on in my answer below -- it programmatically generates the list of `F.col(..).alias(..)` columns so that the code is cleaner, but it's based on your answer here. – Zz'Rot Oct 25 '17 at 22:05
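
For illustration, a quick sketch of the wildcard expansion discussed in these comments, reusing the df from the answer above; note that the expanded column loses the `bar.` prefix:

# Nested wildcard: bar.baz is expanded but comes out simply as baz
df.select("data.foo", "data.bar.*").show()
# +---+---+
# |foo|baz|
# +---+---+
# |  3|  2|
# +---+---+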

I ended up going with the following function, which recursively "unwraps" nested Structs:

Essentially, it keeps digging into Struct fields and leaves all other fields intact; this approach eliminates the need for a very long df.select(...) statement when the Struct has many fields. Here's the code:

from pyspark.sql.functions import col
from pyspark.sql.types import StructType

# Takes a StructType schema object and returns a list of column selectors
# that flattens the Struct, aliasing each leaf column with its dotted path
def flatten_struct(schema, prefix=""):
    result = []
    for elem in schema:
        if isinstance(elem.dataType, StructType):
            result += flatten_struct(elem.dataType, prefix + elem.name + ".")
        else:
            result.append(col(prefix + elem.name).alias(prefix + elem.name))
    return result


from pyspark.sql import Row

df = sc.parallelize([Row(r=Row(a=1, b=Row(foo="b", bar="12")))]).toDF()
df.show()
+----------+
|         r|
+----------+
|[1,[12,b]]|
+----------+

df_expanded = df.select("r.*")
df_flattened = df_expanded.select(flatten_struct(df_expanded.schema))

df_flattened.show()
+---+-----+-----+
|  a|b.bar|b.foo|
+---+-----+-----+
|  1|   12|    b|
+---+-----+-----+
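
As a side note, the same helper can be applied to the full schema as well, without the initial select("r.*"); the top-level struct name then simply becomes part of each column's prefix (a small sketch reusing the df above):

# Flattening from the root keeps "r" in the dotted prefix
df.select(flatten_struct(df.schema)).show()
# +---+-------+-------+
# |r.a|r.b.bar|r.b.foo|
# +---+-------+-------+
# |  1|     12|      b|
# +---+-------+-------+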
Zz'Rot