
I have a Spark DataFrame in which one of the fields is an array of Row structures. I need to expand them into their own columns. One of the problems is that in the array a field is sometimes missing.

The following is an example:

from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import Row
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# data
rows = [{'status': 'active', 'member_since': 1990, 'info': [Row(tag='name', value='John'), Row(tag='age', value='50'), Row(tag='phone', value='1234567')]},
        {'status': 'inactive', 'member_since': 2000, 'info': [Row(tag='name', value='Tom'), Row(tag='phone', value='1234567')]},
        {'status': 'active', 'member_since': 2015, 'info': [Row(tag='name', value='Steve'), Row(tag='age', value='28')]}]

# create dataframe
df = spark.createDataFrame(rows)

# transform info to dict
to_dict = F.UserDefinedFunction(lambda s: dict(s), MapType(StringType(), StringType()))
df = df.withColumn("info_dict", to_dict("info"))

# extract name, NA if not exists
extract_name = F.UserDefinedFunction(lambda s: s.get("name", "NA"))
df = df.withColumn("name", extract_name("info_dict"))

# extract age, NA if not exists
extract_age = F.UserDefinedFunction(lambda s: s.get("age", "NA"))
df = df.withColumn("age", extract_age("info_dict"))

# extract phone, NA if not exists
extract_phone = F.UserDefinedFunction(lambda s: s.get("phone", "NA"))
df = df.withColumn("phone", extract_phone("info_dict"))

df.show()

You can see that for 'Tom', 'age' is missing, and for 'Steve', 'phone' is missing. As in the code snippet above, my current solution is to first transform the array into a dict and then parse each individual field into its own column. The result looks like this:

+--------------------+------------+--------+--------------------+-----+---+-------+
|                info|member_since|  status|           info_dict| name|age|  phone|
+--------------------+------------+--------+--------------------+-----+---+-------+
|[[name, John], [a...|        1990|  active|[name -> John, ph...| John| 50|1234567|
|[[name, Tom], [ph...|        2000|inactive|[name -> Tom, pho...|  Tom| NA|1234567|
|[[name, Steve], [...|        2015|  active|[name -> Steve, a...|Steve| 28|     NA|
+--------------------+------------+--------+--------------------+-----+---+-------+

I really just want the columns 'status', 'member_since', 'name', 'age' and 'phone'. This solution works but is rather slow because of the UDFs. Are there any faster alternatives? Thanks.


1 Answer

I can think of 2 ways to do this using DataFrame functions. I believe the first one should be faster, but the code is much less elegant. The second is more compact, but probably slower.

Method 1: Create Map Dynamically

The heart of this method is to turn your array of Rows into a MapType() column. This can be achieved using pyspark.sql.functions.create_map(), plus some glue from functools.reduce() and operator.add().

from functools import reduce
from operator import add
import pyspark.sql.functions as f

# interleave the (key, value) column pairs and build a single map column
f.create_map(
    *reduce(
        add,
        [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
         for k in range(3)]
    )
)
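
For reference, the reduce(add, ...) call just flattens the nested list of key/value column pairs into one flat argument list for create_map(), so the expression above is equivalent to writing out:

f.create_map(
    f.col('info')['tag'].getItem(0), f.col('info')['value'].getItem(0),
    f.col('info')['tag'].getItem(1), f.col('info')['value'].getItem(1),
    f.col('info')['tag'].getItem(2), f.col('info')['value'].getItem(2)
)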

The problem is that there isn't a way (AFAIK) to dynamically determine the length of the WrappedArray or iterate through it in an easy way. If the array has fewer elements than we index, getItem() returns null for the missing positions and this will cause an error, because map keys cannot be null. However, since we know that the list can only contain 1, 2, or 3 elements, we can just test for each of these cases.

keys = ['name', 'age', 'phone']

df.withColumn(
    'map',
    f.when(f.size(f.col('info')) == 1,
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(1)]
            )
        )
    ).otherwise(
    f.when(f.size(f.col('info')) == 2,
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(2)]
            )
        )
    ).otherwise(
    f.when(f.size(f.col('info')) == 3,
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(3)]
            )
        )
    )))
).select(
    ['member_since', 'status'] + [f.col("map").getItem(k).alias(k) for k in keys]
).show(truncate=False)

The last step turns the 'map' keys into columns by calling getItem() on the map column with each of the known keys and aliasing the result.

This produces the following output:

+------------+--------+-----+----+-------+
|member_since|status  |name |age |phone  |
+------------+--------+-----+----+-------+
|1990        |active  |John |50  |1234567|
|2000        |inactive|Tom  |null|1234567|
|2015        |active  |Steve|28  |null   |
+------------+--------+-----+----+-------+
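
If the set of tags is not known ahead of time, one possible approach (just a sketch, and it costs an extra pass over the data) is to collect the distinct tags first and use those as the keys:

# collect the distinct tags to use as map keys (assumes the tag list is small)
keys = [r['tag'] for r in
        df.select(f.explode(f.col('info')['tag']).alias('tag')).distinct().collect()]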

Method 2: Use explode, groupBy and pivot

First use pyspark.sql.functions.explode() on the column 'info', and then use the 'tag' and 'value' fields of the exploded column as arguments to create_map():

df.withColumn('id', f.monotonically_increasing_id())\
    .withColumn('exploded', f.explode(f.col('info')))\
    .withColumn(
        'map', 
        f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
    )\
    .select('id', 'member_since', 'status', 'map')\
    .show(truncate=False)
#+------------+------------+--------+---------------------+
#|id          |member_since|status  |map                  |
#+------------+------------+--------+---------------------+
#|85899345920 |1990        |active  |Map(name -> John)    |
#|85899345920 |1990        |active  |Map(age -> 50)       |
#|85899345920 |1990        |active  |Map(phone -> 1234567)|
#|180388626432|2000        |inactive|Map(name -> Tom)     |
#|180388626432|2000        |inactive|Map(phone -> 1234567)|
#|266287972352|2015        |active  |Map(name -> Steve)   |
#|266287972352|2015        |active  |Map(age -> 28)       |
#+------------+------------+--------+---------------------+

I also added a column 'id' using pyspark.sql.functions.monotonically_increasing_id() to make sure we can keep track of which rows belong to the same record.

Now we can explode the map column, groupBy(), and pivot(). We can use pyspark.sql.functions.first() as the aggregate function for the groupBy() because we know there will only be one 'value' in each group.

df.withColumn('id', f.monotonically_increasing_id())\
    .withColumn('exploded', f.explode(f.col('info')))\
    .withColumn(
        'map', 
        f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
    )\
    .select('id', 'member_since', 'status', f.explode('map'))\
    .groupBy('id', 'member_since', 'status').pivot('key').agg(f.first('value'))\
    .select('member_since', 'status', 'age', 'name', 'phone')\
    .show()
#+------------+--------+----+-----+-------+
#|member_since|  status| age| name|  phone|
#+------------+--------+----+-----+-------+
#|        1990|  active|  50| John|1234567|
#|        2000|inactive|null|  Tom|1234567|
#|        2015|  active|  28|Steve|   null|
#+------------+--------+----+-----+-------+
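
One small tweak on this that may help if the tags are known ahead of time: passing the expected values to pivot() spares Spark an extra pass over the data to determine the distinct pivot values. A sketch:

# passing the expected tags to pivot() avoids an extra job to discover the pivot values
df.withColumn('id', f.monotonically_increasing_id())\
    .withColumn('exploded', f.explode(f.col('info')))\
    .withColumn(
        'map',
        f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']])
    )\
    .select('id', 'member_since', 'status', f.explode('map'))\
    .groupBy('id', 'member_since', 'status')\
    .pivot('key', ['name', 'age', 'phone'])\
    .agg(f.first('value'))\
    .select('member_since', 'status', 'age', 'name', 'phone')\
    .show()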