
Input

I have a column Parameters of type map of the form:

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
d = [{'Parameters': {'foo': '1', 'bar': '2', 'baz': 'aaa'}}]
df = sqlContext.createDataFrame(d)

df.collect()
# [Row(Parameters={'foo': '1', 'bar': '2', 'baz': 'aaa'})]

df.printSchema()
# root
#  |-- Parameters: map (nullable = true)
#  |    |-- key: string
#  |    |-- value: string (valueContainsNull = true)

Output

I want to reshape it in PySpark so that all the keys (foo, bar, etc.) would become columns, namely:

[Row(foo='1', bar='2', baz='aaa')]

Using withColumn works:

(df
 .withColumn('foo', df.Parameters['foo'])
 .withColumn('bar', df.Parameters['bar'])
 .withColumn('baz', df.Parameters['baz'])
 .drop('Parameters')
).collect()

But I need a solution that doesn't explicitly mention the column names, as I have dozens of them.


3 Answers


Since the keys of a MapType are not part of the schema, you'll have to collect them first, for example like this:

from pyspark.sql.functions import explode

keys = (df
    .select(explode("Parameters"))
    .select("key")
    .distinct()
    .rdd.flatMap(lambda x: x)
    .collect())

Once you have the keys, all that is left is a simple select:

from pyspark.sql.functions import col

exprs = [col("Parameters").getItem(k).alias(k) for k in keys]
df.select(*exprs)
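
If you need this for more than one DataFrame, the two steps can be wrapped in a small helper. A minimal sketch; the function name and the single-map-column assumption are mine:

from pyspark.sql.functions import col, explode

def expand_map_column(df, map_col):
    # Collect the distinct keys of the map column, then select each key
    # as its own top-level column.
    keys = (df
        .select(explode(map_col))
        .select("key")
        .distinct()
        .rdd.flatMap(lambda x: x)
        .collect())
    return df.select(*[col(map_col).getItem(k).alias(k) for k in keys])

# expand_map_column(df, "Parameters").collect()
# [Row(foo='1', bar='2', baz='aaa')]  (column order depends on the collected keys)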
  • Thanks! This works for me, but with one exception: when I print the schema for the data frame df.select(*exprs), all the data types come back as string. One of my values is actually a struct. How can I access that? – TopCoder Sep 11 '17 at 20:47
  • @TopCoder [`topfield.nestedfield`](https://stackoverflow.com/questions/28332494/querying-spark-sql-dataframe-with-complex-types)? – zero323 Sep 14 '17 at 11:04
  • What happens if you have, say, 280 keys that you have to turn into columns? I keep getting a message that it exceeds Spark's memory overhead. – Maeror Mar 10 '21 at 07:54
  • This answer has helped me now multiple times - I keep coming back to it. Thank you! – wylie Mar 24 '23 at 17:54
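
Following up on the nested-struct comments above, a minimal sketch of reaching into struct values; the field name nested is purely hypothetical and assumes the map's values are structs:

from pyspark.sql.functions import col

# Hypothetical: if Parameters maps keys to struct values, chain getItem/getField
# to pull a field out of the struct.
df.select(col("Parameters").getItem("foo").getField("nested").alias("foo_nested"))
# equivalently: df.select(col("Parameters")["foo"]["nested"])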

Performant solution

One of the question's constraints is to determine the column names dynamically, which is fine, but be warned that this can be really slow. Here's how you can avoid repetitive typing and write code that'll execute quickly.

from pyspark.sql import functions as F

cols = list(map(
    lambda f: F.col("Parameters").getItem(f).alias(str(f)),
    ["foo", "bar", "baz"]))
df.select(cols).show()
+---+---+---+
|foo|bar|baz|
+---+---+---+
|  1|  2|aaa|
+---+---+---+

Notice that this runs a single select operation. Don't run withColumn multiple times because that's slower.

The fast solution is only possible if you know all the map keys. You'll need to revert to the slower solution if you don't know all the unique values for the map keys.

Slower solution

The accepted answer is good. My solution is a bit more performant because it doesn't call .rdd or flatMap().

import pyspark.sql.functions as F

d = [{'Parameters': {'foo': '1', 'bar': '2', 'baz': 'aaa'}}]
df = spark.createDataFrame(d)

keys_df = df.select(F.explode(F.map_keys(F.col("Parameters")))).distinct()
keys = list(map(lambda row: row[0], keys_df.collect()))
key_cols = list(map(lambda f: F.col("Parameters").getItem(f).alias(str(f)), keys))
df.select(key_cols).show()
+---+---+---+
|bar|foo|baz|
+---+---+---+
|  2|  1|aaa|
+---+---+---+

Collecting results to the driver node can be a performance bottleneck. It's good to run the list(map(lambda row: row[0], keys_df.collect())) step as a separate command, so you can make sure it isn't running too slowly.
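
For instance, you could run the driver-side collect on its own and time it (a rough sketch; the timing code is not part of the answer above):

import time

start = time.time()
keys = list(map(lambda row: row[0], keys_df.collect()))  # driver-side collect of the distinct keys
print("collected %d keys in %.1f seconds" % (len(keys), time.time() - start))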


For good performance without hard-coding the column names, use this:

from pyspark.sql import functions as F

df = df.withColumn("_c", F.to_json("Parameters"))
json_schema = spark.read.json(df.rdd.map(lambda r: r._c)).schema
df = df.withColumn("_c", F.from_json("_c", json_schema))
df = df.select("_c.*")

df.show()
# +----+----+---+
# | bar| baz|foo|
# +----+----+---+
# |   2| aaa|  1|
# |null|null|  1|
# +----+----+---+

It uses neither distinct nor collect. It calls rdd once, so that the extracted schema has a suitable format to use in from_json.
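
If you want to reuse the pattern, it can be wrapped in a small function. A sketch only; expand_map_via_json is a name I made up, and it assumes a single map column and an active spark session:

from pyspark.sql import functions as F

def expand_map_via_json(df, map_col):
    # Serialize the map to JSON, infer a struct schema from the JSON strings,
    # parse it back with from_json, then promote the struct fields to columns.
    tmp = df.withColumn("_c", F.to_json(map_col))
    json_schema = spark.read.json(tmp.rdd.map(lambda r: r._c)).schema
    return tmp.withColumn("_c", F.from_json("_c", json_schema)).select("_c.*")

# expand_map_via_json(df, "Parameters").show()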
