2

As you may know, a DataFrame can contain fields of complex types, such as structures (StructType) or arrays (ArrayType). You may need, as I did, to map all the DataFrame data to a Hive table with simple-typed fields (String, Integer, ...). I struggled with this issue for a long time and finally found a solution I want to share. I'm sure it could be improved, so feel free to reply with your own suggestions.

It's based on this thread, but it also handles ArrayType fields, not only StructType ones. It is a tail-recursive function that takes a DataFrame and returns it flattened.

/**
 * Flattens every complex column of `df` until only non-struct, non-array
 * columns remain.
 *
 *  - A StructType column is replaced by its child fields, selected as
 *    parent.child and appended after the remaining columns.
 *  - An ArrayType column is exploded into one row per element. If the element
 *    type is a struct, its fields are expanded in place of the array;
 *    otherwise the exploded values keep the array column's name.
 *
 * Each call rewrites only the first complex column (in schema order) and then
 * recurses on the result, so nested structs/arrays surfaced by one step are
 * handled by the next. A DataFrame with no complex columns, including one with
 * an empty schema, is returned unchanged.
 *
 * NOTE(review): struct children keep their leaf names, so two children with the
 * same name (or a child named like an existing top-level column) produce
 * duplicate column names, and a later selectExpr fails with an ambiguous
 * reference. Alias them with a prefix if that can happen in your data.
 * NOTE(review): Spark's explode emits no rows for null or empty arrays; use
 * explode_outer if such rows must be kept.
 *
 * @param df the DataFrame to flatten
 * @return a DataFrame whose columns are all of simple (non-struct, non-array) types
 */
def flattenDf(df: DataFrame): DataFrame = {
  // Backtick-quote an identifier so names containing dots, spaces or other
  // special characters are taken literally by selectExpr.
  def quote(name: String): String = "`" + name.replace("`", "``") + "`"

  val fieldNames = df.schema.fieldNames

  // Every column except `name`, quoted for selectExpr (original order kept).
  def otherColumns(name: String): Array[String] =
    fieldNames.filter(_ != name).map(quote)

  // Rewrite the first complex column found; None when the schema is already flat.
  val flattenedOnce: Option[DataFrame] =
    df.schema.fields.iterator.map(f => (f.name, f.dataType)).collectFirst {
      case (name, st: StructType) =>
        val childColumns = st.fieldNames.map(child => s"${quote(name)}.${quote(child)}")
        df.selectExpr((otherColumns(name) ++ childColumns): _*)

      case (name, at: ArrayType) =>
        val others = otherColumns(name)
        // Alias the exploded value with the array's own name: that name was just
        // removed from `others`, so it cannot clash with another column.
        val exploded = df.selectExpr((others :+ s"explode(${quote(name)}) AS ${quote(name)}"): _*)
        at.elementType match {
          // `.*` is only valid on struct values; arrays of primitives keep the
          // exploded column as-is, and nested arrays are exploded on the next pass.
          case _: StructType => exploded.selectExpr((others :+ s"${quote(name)}.*"): _*)
          case _             => exploded
        }
    }

  flattenedOnce.fold(df)(flattenDf)
}
Torcuete
  • 21
  • 3

1 Answer

0

// Sample row mixing a string, a nested struct (a tuple inside a tuple) and an
// integer array.
// NOTE(review): toDF() on a local Seq needs Spark's implicits in scope
// (e.g. import spark.implicits._), which are not shown here.
val df = Seq(("1", (2, (3, 4)), Seq(1, 2))).toDF()

// printSchema is side-effecting and declared with (), so call it with parentheses.
df.printSchema()

root
 |-- _1: string (nullable = true)
 |-- _2: struct (nullable = true)
 |    |-- _1: integer (nullable = false)
 |    |-- _2: struct (nullable = true)
 |    |    |-- _1: integer (nullable = false)
 |    |    |-- _2: integer (nullable = false)
 |-- _3: array (nullable = true)
 |    |-- element: integer (containsNull = false)


/**
 * Builds the columns needed to flatten `schema` in a single select.
 *
 *  - Struct fields are expanded recursively into their leaf columns, referenced
 *    by their dotted path (e.g. "_2._1"). Each resulting column keeps only its
 *    leaf name, so leaves with equal names yield duplicate column names.
 *  - Array fields become explode(path); Spark names the generated column "col".
 *  - Any other field is selected as-is.
 *
 * NOTE(review): Spark accepts only one generator per select clause, so a schema
 * with more than one array field (including arrays nested in structs) fails
 * when the result is passed to select. Array elements are not flattened further.
 *
 * @param schema    the (sub)schema to flatten
 * @param fieldName dotted path of the struct that `schema` belongs to, or null
 *                  for the top-level schema
 * @return the columns to pass to DataFrame.select
 */
def flattenSchema(schema: StructType, fieldName: String = null): Array[Column] = {
  // Wrap the nullable parameter once instead of null-checking for every field.
  val prefix: Option[String] = Option(fieldName)
  schema.fields.flatMap { f =>
    val path = prefix.fold(f.name)(p => s"$p.${f.name}")
    f.dataType match {
      case structType: StructType => flattenSchema(structType, path)
      case _: ArrayType           => Array(explode(col(path)))
      case _                      => Array(col(path))
    }
  }
}

// Select the flattened columns in one pass and show the resulting schema.
df.select(flattenSchema(df.schema): _*).printSchema()

root
 |-- _1: string (nullable = true)
 |-- _1: integer (nullable = true)
 |-- _1: integer (nullable = true)
 |-- _2: integer (nullable = true)
 |-- col: integer (nullable = false)
silentshadow
  • 63
  • 1
  • 2
  • 9