Here's a way to do it without a udf.
Initialize example dataframe:
nested_df1 = (spark.read.json(sc.parallelize(["""[
{ "state": {"fld": 1} },
{ "state": {"fld": 2}}
]"""])))
nested_df1.printSchema()
root
|-- state: struct (nullable = true)
| |-- fld: long (nullable = true)
Spark .read.json imports all integers as long by default.
If state.fld has to be an int, you will need to cast it.
from pyspark.sql import functions as F
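# Cast first, then alias, so the struct field keeps the name "fld"
# (aliasing before the cast loses the name and yields "col1").
nested_df1 = (nested_df1
.select( F.struct(F.col("state.fld").cast('int').alias("fld")).alias("state") ))
nested_df1.printSchema()
root
|-- state: struct (nullable = false)
| |-- fld: integer (nullable = true)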
nested_df1.show()
+-----+
|state|
+-----+
| [1]|
| [2]|
+-----+
Finally, use .select to get the nested columns you want from the existing struct with the "parent.child" notation, create the new column, then re-wrap the old columns together with the new columns in a struct.
val_a = 3
nested_df2 = (nested_df1
.select(
F.struct(
F.col("state.fld"),
F.lit(val_a).alias("a")
).alias("state")
)
)
nested_df2.printSchema()
root
|-- state: struct (nullable = false)
| |-- fld: integer (nullable = true)
| |-- a: integer (nullable = false)
nested_df2.show()
+------+
| state|
+------+
|[1, 3]|
|[2, 3]|
+------+
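If the struct has many fields, listing each one by hand gets tedious. Here is a minimal sketch of a helper that reads the field names from the schema and re-wraps them together with the new column. The names child_paths and add_struct_field are hypothetical (not part of PySpark), and the sketch assumes the parent is a top-level struct column whose field names contain no dots.

```python
def child_paths(parent, field_names):
    """Build "parent.child" paths for the given struct field names."""
    return [f"{parent}.{name}" for name in field_names]


def add_struct_field(df, parent, new_name, new_col):
    """Return df with new_col appended to the struct column `parent` as `new_name`.

    Hypothetical helper: assumes `parent` is a top-level struct column.
    """
    # Imported lazily so child_paths can be used without a Spark install.
    from pyspark.sql import functions as F

    # Field names of the existing struct, read from the DataFrame schema.
    existing = df.schema[parent].dataType.fieldNames()
    kept = [F.col(path) for path in child_paths(parent, existing)]
    # withColumn replaces `parent` and keeps all other top-level columns.
    return df.withColumn(parent, F.struct(*kept, new_col.alias(new_name)))
```

With the dataframe above, add_struct_field(nested_df1, "state", "a", F.lit(val_a)) should give the same schema as nested_df2. Unlike a bare .select, withColumn keeps any other top-level columns. On Spark 3.1 and later, F.col("state").withField("a", F.lit(val_a)) should achieve the same result without rebuilding the struct.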
Flatten if needed with "parent.*".
nested_df2.select("state.*").printSchema()
root
|-- fld: integer (nullable = true)
|-- a: integer (nullable = false)
nested_df2.select("state.*").show()
+---+---+
|fld| a|
+---+---+
| 1| 3|
| 2| 3|
+---+---+
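Note that "state.*" drops the parent name, so the flattened fields can collide with existing top-level columns of the same name. Below is a sketch of a variant that prefixes each field with the parent name. The names flat_names and flatten_struct are hypothetical, the "_" separator is an arbitrary choice, and column names are assumed to contain no dots.

```python
def flat_names(parent, field_names, sep="_"):
    """Map each "parent.child" path to a flat, prefixed column name."""
    return {f"{parent}.{name}": f"{parent}{sep}{name}" for name in field_names}


def flatten_struct(df, parent, sep="_"):
    """Replace the struct column `parent` with one prefixed column per field.

    Hypothetical helper: assumes `parent` is a top-level struct column.
    """
    # Imported lazily so flat_names can be used without a Spark install.
    from pyspark.sql import functions as F

    fields = df.schema[parent].dataType.fieldNames()
    flattened = [
        F.col(path).alias(alias)
        for path, alias in flat_names(parent, fields, sep).items()
    ]
    # Keep every other top-level column unchanged.
    others = [F.col(c) for c in df.columns if c != parent]
    return df.select(*others, *flattened)
```

For example, flatten_struct(nested_df2, "state") should produce the columns state_fld and state_a instead of fld and a.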