7

Spark 2.4 introduced the new SQL function slice, which can be used to extract a certain range of elements from an array column. I want to define that range dynamically per row, based on an integer column that holds the number of elements I want to pick from the array.

However, simply passing the column to the slice function fails; the function appears to expect plain integers for the start and length arguments. Is there a way of doing this without writing a UDF?

To visualize the problem with an example: I have a dataframe with an array column arr that contains ['a', 'b', 'c'] in every row, and an end_idx column with the values 3, 1 and 2:

+---------+-------+
|arr      |end_idx|
+---------+-------+
|[a, b, c]|3      |
|[a, b, c]|1      |
|[a, b, c]|2      |
+---------+-------+

I try to create a new column arr_trimmed like this:

import pyspark.sql.functions as F

l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]
df = spark.createDataFrame(l, ["arr", "end_idx"])

df = df.withColumn("arr_trimmed", F.slice(F.col("arr"), 1, F.col("end_idx")))

I expect this code to create a new column with the values ['a', 'b', 'c'], ['a'], ['a', 'b'].

Instead I get `TypeError: Column is not iterable`.

harppu
  • 384
  • 4
  • 13
  • Possible duplicate of [Using a column value as a parameter to a spark DataFrame function](https://stackoverflow.com/questions/51140470/using-a-column-value-as-a-parameter-to-a-spark-dataframe-function) – pault Sep 04 '19 at 11:33

2 Answers

16

You can do it by passing a SQL expression as follows:

df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)"))

Here is the whole working example:

import pyspark.sql.functions as F

l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]

df = spark.createDataFrame(l, ["arr", "end_idx"])

df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)")).show(truncate=False)

+---------+-------+-----------+
|arr      |end_idx|arr_trimmed|
+---------+-------+-----------+
|[a, b, c]|3      |[a, b, c]  |
|[a, b, c]|1      |[a]        |
|[a, b, c]|2      |[a, b]     |
+---------+-------+-----------+
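
If you prefer to keep the whole projection in SQL, the same slice expression can also be used through selectExpr. A minimal sketch, assuming the same df as above:

df.selectExpr("arr", "end_idx", "slice(arr, 1, end_idx) AS arr_trimmed").show(truncate=False)

This produces the same arr_trimmed column as the withColumn version.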
David Vrba
  • 2,984
  • 12
  • 16

2

As of Spark 3.1, slice also accepts Columns for the start and length arguments, so it can be used as follows:

df.withColumn("arr_trimmed", F.slice(arr, F.lit(1), end_idx))

David Vrba's example can be rewritten this way:

import pyspark.sql.functions as F

l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]

df = spark.createDataFrame(l, ["arr", "end_idx"])

df.withColumn("arr_trimmed", F.slice("arr", F.lit(1), F.col("end_idx"))).show(truncate=False)


+---------+-------+-----------+
|arr      |end_idx|arr_trimmed|
+---------+-------+-----------+
|[a, b, c]|3      |[a, b, c]  |
|[a, b, c]|1      |[a]        |
|[a, b, c]|2      |[a, b]     |
+---------+-------+-----------+
  • This answer is correct and should be accepted as best, with the following clarification - `slice` accepts columns as arguments, as long as both `start` and `length` are given as column expressions. If for example `start` is given as an integer without `lit()`, as in the original question, I get `py4j.Py4JException: Method slice([class org.apache.spark.sql.Column, class java.lang.Integer, class org.apache.spark.sql.Column]) does not exist` (as late as Spark 3.2.1). – Vic Oct 29 '22 at 00:02
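
To illustrate the point in the comment above, a minimal sketch (assuming a Spark 3.x session and the df defined earlier). Wrapping the literal start in F.lit() is what keeps both arguments as Column expressions:

import pyspark.sql.functions as F

# Works: start and length are both Column expressions.
df.withColumn("arr_trimmed", F.slice(F.col("arr"), F.lit(1), F.col("end_idx")))

# Fails with "Method slice(...) does not exist": a bare Python int
# cannot be mixed with a Column argument here.
# df.withColumn("arr_trimmed", F.slice(F.col("arr"), 1, F.col("end_idx")))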