
I am working on a dataframe that looks like this:

root
 |-- _id: string (nullable = true)
 |-- positions: struct (nullable = true)
 |    |-- precise: struct (nullable = true)
 |    |    |-- lat: double (nullable = true)
 |    |    |-- lng: double (nullable = true)
 |    |-- unprecise: struct (nullable = true)
 |    |    |-- lat: double (nullable = true)
 |    |    |-- lng: double (nullable = true)

The Struct objects in the positions Struct can contain "precise", or "unprecise", or both, or several other Struct objects. So a row having both a precise and an unprecise location should be exploded into two rows.

What would be the best way to explode such dataframe? Ideally I would like to have:

root
 |-- _id: string (nullable = true)
 |-- positions_type: string (nullable = true) // "precise" or "unprecise"
 |-- lat: double (nullable = true)
 |-- lng: double (nullable = true)

I have followed Exploding nested Struct in Spark dataframe, but it is about exploding a Struct column, not a nested Struct.

Another idea would be to flatten everything and have as many columns as there are nested struct objects, but that is not really ideal as the schema would change whenever a new struct object is added.
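For illustration, that flattening idea would look something like this in PySpark (just a sketch assuming only the two known keys; flat_df and the column aliases are names I made up):

from pyspark.sql.functions import col

# Every position type becomes its own pair of columns,
# so any new struct object under "positions" means editing this select.
flat_df = df.select(
    "_id",
    col("positions.precise.lat").alias("precise_lat"),
    col("positions.precise.lng").alias("precise_lng"),
    col("positions.unprecise.lat").alias("unprecise_lat"),
    col("positions.unprecise.lng").alias("unprecise_lng"),
)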

Thanks in advance.

mich

2 Answers


You can try converting it to a map or list and then use explode to create the dataframe you need. I have done it in Scala; I hope it's helpful.

/* Creating sample data */
case class Position(lat : Double, lng : Double)
case class Positions(precise : Position, unprecise : Position)

import org.apache.spark.sql.functions.{array, explode, lit, struct}
import spark.implicits._

val list = List(("0",Positions(Position(0.1, 0.2), Position(1.1, 1.2))),
  ("1",Positions(Position(0.1, 0.2), Position(1.1, 1.2))))
val df = list.toDF("_id", "positions")
df.printSchema()

/* Wrap each position type into a (positions_type, lat, lng) struct, collect them in an array, then explode */
val resDF = df.withColumn("positions_arr",
  array(
    struct(lit("precise").as("positions_type"), $"positions.precise.lat", $"positions.precise.lng"),
    struct(lit("unprecise").as("positions_type"), $"positions.unprecise.lat", $"positions.unprecise.lng")
  )
).withColumn("position", explode($"positions_arr"))
  .withColumn("positions_type", $"position.positions_type")
  .withColumn("lat", $"position.lat")
  .withColumn("lng", $"position.lng")
  .drop("positions","positions_arr","position")

resDF.show(false)
resDF.printSchema()

+---+--------------+---+---+
|_id|positions_type|lat|lng|
+---+--------------+---+---+
|0  |precise       |0.1|0.2|
|0  |unprecise     |1.1|1.2|
|1  |precise       |0.1|0.2|
|1  |unprecise     |1.1|1.2|
+---+--------------+---+---+

Apurba Pandey
  • Thanks, I ended up using blackbishop's code snippet as it is in PySpark. Both methods are similar. – mich Jan 16 '20 at 14:57

Simply use the create_map and explode functions like this:

from pyspark.sql.functions import col, create_map, explode, lit

df = df.select("_id", explode(create_map(lit("precise"), col("positions.precise"),
                                         lit("unprecise"), col("positions.unprecise")
                                        )
                             ).alias("positions_type", "pos")
              ) \
       .select("_id", "positions_type", "pos.*")
Result schema:

root
 |-- _id: string (nullable = true)
 |-- positions_type: string (nullable = false)
 |-- lat: double (nullable = true)
 |-- lng: double (nullable = true)
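
If the set of keys under positions is not fixed, a variation (just a sketch, untested; position_types and map_args are names I'm introducing here) is to build the create_map arguments from the DataFrame schema, so new struct objects are picked up automatically:

from itertools import chain
from pyspark.sql.functions import col, create_map, explode, lit

# Field names nested under "positions", e.g. ["precise", "unprecise", ...]
position_types = [f.name for f in df.schema["positions"].dataType.fields]

# Alternating (key literal, value column) arguments expected by create_map
map_args = list(chain.from_iterable(
    (lit(t), col("positions." + t)) for t in position_types
))

df = df.select("_id", explode(create_map(*map_args)).alias("positions_type", "pos")) \
       .select("_id", "positions_type", "pos.*")

If a row is missing one of the keys, its "pos" value will be null after the explode; adding .where(col("pos").isNotNull()) before the final select drops those rows.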
blackbishop