
I have a DataFrame like this:

+------------+-----------------+-------------------------------+
| Name       |   Age           | Answers                       |
+------------+-----------------+-------------------------------+
| Maria      | 23              | [apple, mango, orange, banana]| 
| John       | 55              | [apple, orange, banana]       |
| Brad       | 44              | [banana]                      |
+------------+-----------------+-------------------------------+

The Answers column contains an array of elements.
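
For reference, the DataFrame can be reproduced with something like this (a minimal sketch; the column names come from the table above and the types are assumed):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        ("Maria", 23, ["apple", "mango", "orange", "banana"]),
        ("John", 55, ["apple", "orange", "banana"]),
        ("Brad", 44, ["banana"]),
    ],
    ["Name", "Age", "Answers"],
)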

Expected Output

+------------+-----------------+-------+-------+--------+--------+
| Name       |   Age           | apple | mango | orange | banana |
+------------+-----------------+-------+-------+--------+--------+
| Maria      | 23              | True  | True  | True   | True   |
| John       | 55              | True  | False | True   | True   |
| Brad       | 44              | False | False | False  | True   |
+------------+-----------------+-------+-------+--------+--------+

Is there a way I can convert the array column into True and False columns?

Thanks in advance.

ar_mm18

2 Answers


If you don't know in advance all the possible values of the Answers array, you can resort to the following solution that uses explode + pivot.

from pyspark.sql import functions as F

result = df \
    .withColumn("answer", F.explode("Answers")) \
    .drop("Answers") \
    .groupBy("Name", "Age") \
    .pivot("answer") \
    .agg(F.first("answer").isNotNull()) \
    .na.fill(False)

It is heavy, though, as the pivot is expensive when you don't know the possible values in advance.
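
If you do know the possible values up front, a possible variant (not part of the original chain) is to pass them explicitly to pivot, so Spark can skip the extra job that computes the distinct values:

known_answers = ["apple", "mango", "orange", "banana"]  # assumed to be known in advance

result = df \
    .withColumn("answer", F.explode("Answers")) \
    .drop("Answers") \
    .groupBy("Name", "Age") \
    .pivot("answer", known_answers) \
    .agg(F.first("answer").isNotNull()) \
    .na.fill(False)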


Solution explained

1. Array explode

.withColumn("answer", F.explode("Answers")) \
.drop("Answers")

The explode function creates N rows for each input row, one for each value in the Answers array.

+-----+---+--------+                                                              
| Name|Age|  answer|
+-----+---+--------+
|Maria| 23|   apple|
|Maria| 23|   mango|
|Maria| 23|  orange|
|Maria| 23|  banana|
| John| 55|   apple|
| John| 55|  orange|
| John| 55|  banana|
| Brad| 44|  banana|
+-----+---+--------+
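
Note that explode drops rows whose Answers array is empty or null, so such a person would disappear from the result. If you need to keep those rows, a possible alternative (a sketch, not part of the original chain) is explode_outer, which produces a null answer for them; that null would then need to be handled before or after the pivot:

exploded = df \
    .withColumn("answer", F.explode_outer("Answers")) \
    .drop("Answers")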

2. Pivot

.groupBy("Name", "Age") \
.pivot("answer") \
.agg(F.first("answer").isNotNull())

Transpose rows into columns by creating a column for each distinct value in the answer column.

It would be safer if you had a column that uniquely identifies each row (i.e. an id) and you used that in the groupBy (see the sketch after the table below).

+-----+---+-----+------+-----+------+
| Name|Age|apple|banana|mango|orange|
+-----+---+-----+------+-----+------+
| John| 55| true|  true| null|  true|
|Maria| 23| true|  true| true|  true|
| Brad| 44| null|  true| null|  null|
+-----+---+-----+------+-----+------+
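
For example (a sketch), you could add a synthetic key with monotonically_increasing_id (any existing unique column works just as well) and drop it at the end:

result = df \
    .withColumn("id", F.monotonically_increasing_id()) \
    .withColumn("answer", F.explode("Answers")) \
    .drop("Answers") \
    .groupBy("id", "Name", "Age") \
    .pivot("answer") \
    .agg(F.first("answer").isNotNull()) \
    .na.fill(False) \
    .drop("id")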

3. Fix missing values

Missing values are reported as null during the pivot. Replace all nulls with false as required.

.na.fill(False)

+-----+---+-----+------+-----+------+
| Name|Age|apple|banana|mango|orange|
+-----+---+-----+------+-----+------+
| John| 55| true|  true|false|  true|
|Maria| 23| true|  true| true|  true|
| Brad| 44|false|  true|false| false|
+-----+---+-----+------+-----+------+
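
Since Name and Age are not boolean columns, fill(False) leaves them untouched. If you prefer to be explicit, the fill can be restricted to the answer columns (a sketch; the column list is assumed to be known):

.na.fill(False, subset=["apple", "banana", "mango", "orange"])
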
vinsce
  • @vinsce why did you use the groupBy statement? I am asking this because my original data frame has 400 columns and I just want to explode the one array column so that I could use it further – ar_mm18 Nov 01 '22 at 19:29
  • Also, I get this error when I recreate the code you provided - `A DataFrame object does not have an attribute pivot. Please check the spelling and/or the datatype of the object.` Do I have to import something for pivot to work? – ar_mm18 Nov 01 '22 at 19:30
  • `groupBy` is required in order to use `pivot`. If you don't use `groupBy` you get the error you just observed because `pivot` is part of the `GroupedData` class and not the `DataFrame` class. Take a look [here](https://stackoverflow.com/questions/49392683/transpose-dataframe-without-aggregation-in-spark-with-scala) for ideas on how to pivot without grouping. – vinsce Nov 01 '22 at 19:36
  • Thank you very much for the detailed explanation. Really appreciate your help @vinsce. – ar_mm18 Nov 01 '22 at 19:47
  • `Column name "temp column" contains invalid character(s). Please use alias to rename it.` I get this error when I implement this code. I am unable to figure out the issue. `Step 1` works perfectly fine. To use groupBy I created an id column so that I can use the pivot function. – ar_mm18 Nov 01 '22 at 21:23
  • If you have invalid characters in the values in the array (that now are used as column names) you can transform them (transform the `answer` column) just before the `groupBy` to fix them. You can use the escape strategy as suggested in another answer or just replace invalid characters using `regexp_replace` or a custom UDF – vinsce Nov 02 '22 at 19:14
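
For illustration, the regexp_replace option from the last comment could look like this, applied just before the groupBy (a sketch; the set of characters to replace is an assumption):

df \
    .withColumn("answer", F.explode("Answers")) \
    .withColumn("answer", F.regexp_replace("answer", r"[./()\s]", "_")) \
    .drop("Answers") \
    .groupBy("Name", "Age") \
    .pivot("answer") \
    .agg(F.first("answer").isNotNull()) \
    .na.fill(False)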

A possible solution, if you know the list of all the possible answers, is to create a column for each of them, stating whether the 'Answers' column contains that particular answer for that row.

Suppose that the list of possible answers is called possible_answers (so, in your case, ['orange', 'apple', 'mango', 'banana']); then the following code produces the DataFrame you want (suppose that df is your input DataFrame):

import re
from pyspark.sql import functions as F

def normalize_column_name(name):
    """Normalize column name backticking names with invalid characters"""
    return (f'`{name}`' if re.search(r'[_|\.|\(|\/]', name) else name)
    # if you prefer, you can replace the invalid characters with a valid one, e.g., '_'
    # return re.sub(r"(_|\.|\(|\/)", "_", name)

for c in sorted(possible_answers):  # sorted is optional, but guarantees the order of the columns
    df = df.withColumn(normalize_column_name(c), F.array_contains('Answers', c))
df = df.drop('Answers')

If you do not know beforehand all the possible answers, you can infer them from the input DataFrame:

possible_answers = [r[0] for r in df.select(F.explode('Answers')).distinct().collect()]

Here is an example with a DataFrame containing answers with characters that are invalid for a column name:

+-----+---+----------------------------------------+
|Name |Age|Answers                                 |
+-----+---+----------------------------------------+
|Maria|23 |[apple, mango, orange, banana]          |
|John |55 |[apple, orange, banana, lime.strawberry]|
|Brad |44 |[banana, nut/ pear]                     |
+-----+---+----------------------------------------+

and here is the result:

+-----+---+-----+------+-----------------+-----+-----------+------+
|Name |Age|apple|banana|`lime.strawberry`|mango|`nut/ pear`|orange|
+-----+---+-----+------+-----------------+-----+-----------+------+
|Maria|23 |true |true  |false            |true |false      |true  |
|John |55 |true |true  |true             |false|false      |true  |
|Brad |44 |false|true  |false            |false|true       |false |
+-----+---+-----+------+-----------------+-----+-----------+------+
PieCot
  • I get this error - `Column name "nut/ pear" contains invalid character(s). Please use alias to rename it` - I have a value in the answers array like that. Is it because of the `/` character? How can I overcome this? – ar_mm18 Nov 01 '22 at 21:27
  • If your answers contain characters that are not allowed in a column name, you have to change it. You can use `replace`, using something like this: `df.withColumn(c.replace('/', '-'), ...)`. – PieCot Nov 01 '22 at 22:21
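
Applied to the loop from the answer, that replacement could look like this (a sketch; replacing '/' and '.' with '-' is an assumption):

for c in sorted(possible_answers):
    new_name = c.replace('/', '-').replace('.', '-')
    df = df.withColumn(new_name, F.array_contains('Answers', c))
df = df.drop('Answers')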