I needed a generic solution that can handle an arbitrary level of nesting when casting columns. Extending the accepted answer, I came up with the following functions:
from typing import Dict

from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.functions import col, from_json, to_json
from pyspark.sql.types import (
    ArrayType,
    StringType,
    StructField,
    StructType,
    _all_atomic_types,
)
def apply_nested_column_casts(
    schema: StructType, column_cast: Dict[str, str], parent: str
) -> StructType:
    """Return a copy of *schema* with the fields named in *column_cast* re-typed.

    Field names in *column_cast* use dot notation relative to the top of the
    DataFrame schema (e.g. ``"col_b.nested_col"``); *parent* is the dotted
    prefix accumulated so far (empty string at the top level). Target types
    are atomic-type names as understood by Spark (``"string"``, ``"long"``,
    ...) and are resolved through ``_all_atomic_types``.

    NOTE(review): ``_all_atomic_types`` is a private pyspark mapping — it has
    been stable across versions, but confirm it exists in the pyspark release
    you target.
    """
    # Recursion can land on an atomic array element type (e.g. an array of
    # ints): there is nothing to rewrite, so return it unchanged. The original
    # only special-cased StringType, which crashed on any other atomic element
    # type when the loop below tried to iterate it.
    if not isinstance(schema, StructType):
        return schema

    new_fields = []
    for field in schema:
        full_field_name = f"{parent}.{field.name}" if parent else field.name

        if full_field_name in column_cast:
            # Replace this field's type with the requested atomic type,
            # preserving the field's nullability.
            cast_type = _all_atomic_types[column_cast[full_field_name]]
            new_fields.append(
                StructField(field.name, cast_type(), field.nullable)
            )
        elif isinstance(field.dataType, StructType):
            # Descend into nested structs.
            inner_schema = apply_nested_column_casts(
                field.dataType, column_cast, full_field_name
            )
            new_fields.append(
                StructField(field.name, inner_schema, field.nullable)
            )
        elif isinstance(field.dataType, ArrayType):
            # Descend into array elements; keep containsNull, which the
            # original silently reset to the default.
            inner_schema = apply_nested_column_casts(
                field.dataType.elementType, column_cast, full_field_name
            )
            new_fields.append(
                StructField(
                    field.name,
                    ArrayType(inner_schema, field.dataType.containsNull),
                    field.nullable,
                )
            )
        else:
            # Untouched atomic field: reuse it as-is.
            new_fields.append(field)

    return StructType(new_fields)
def apply_column_casts(
    df: SparkDataFrame, column_casts: Dict[str, str]
) -> SparkDataFrame:
    """Cast (possibly nested) columns of *df* to new atomic types.

    *column_casts* maps dot-notation column paths to target type names,
    e.g. ``{"col_a": "string", "col_b.nested_col": "double"}``. Top-level
    columns are cast directly; nested columns are handled by serializing the
    parent column to JSON and re-parsing it with a rewritten schema, since
    Spark has no in-place cast for nested fields.

    Raises:
        TypeError: if a dotted path's parent column is neither a struct
            nor an array (there is no nested field to cast in that case).
    """
    for col_name, cast_to in column_casts.items():
        path_parts = col_name.split(".")
        if len(path_parts) == 1:
            # Simple top-level column: a plain cast suffices.
            df = df.withColumn(col_name, col(col_name).cast(cast_to))
            continue

        parent_field = path_parts[0]
        parent_type = df.schema[parent_field].dataType
        column_cast = {col_name: cast_to}

        if isinstance(parent_type, StructType):
            new_schema = apply_nested_column_casts(
                parent_type, column_cast, parent_field
            )
        elif isinstance(parent_type, ArrayType):
            new_schema = ArrayType(
                apply_nested_column_casts(
                    parent_type.elementType, column_cast, parent_field
                )
            )
        else:
            # The original left new_schema unbound here, producing a
            # confusing NameError; fail with an explicit message instead.
            raise TypeError(
                f"Column '{parent_field}' has type {parent_type}; nested "
                f"casts require a struct or array column."
            )

        # Round-trip through JSON to apply the rewritten schema.
        tmp_json = f"{parent_field}_json"
        df = df.withColumn(tmp_json, to_json(parent_field)).drop(parent_field)
        df = df.withColumn(
            parent_field, from_json(tmp_json, new_schema)
        ).drop(tmp_json)
    return df
You can call the functions as shown below, using dot notation for nested column casts:
# Map dot-notation column paths to target atomic type names:
# top-level columns ("col_a") are cast directly, nested paths
# ("col_b.nested_col") are rewritten through the JSON round-trip.
column_casts = {
"col_a": "string",
"col_b.nested_col": "double",
"col_b.nested_struct_col.some_col": "long",
}
# Apply all casts and rebind the resulting DataFrame.
df = apply_column_casts(df, column_casts)