17

I have a Spark DataFrame with rows like this:

1   |   [a, b, c]
2   |   [d, e, f]
3   |   [g, h, i]

Now I want to keep only the first 2 elements from the array column.

1   |   [a, b]
2   |   [d, e]
3   |   [g, h]

How can that be achieved?

Note: I am not extracting a single array element here, but a part of the array, which may contain multiple elements.

Vipul Sharma
  • Possible duplicate of [How to extract an element from a array in pyspark](https://stackoverflow.com/questions/45254928/how-to-extract-an-element-from-a-array-in-pyspark) – pault Oct 24 '18 at 19:12
  • 1
    I already saw that answer, but that is not what I want. I don't want a single item from array, rather I am looking for first N elements. – Vipul Sharma Oct 25 '18 at 06:09
  • @pault interestingly enough, the linked solution does not seem to work with Spark 2.3.1 (throws an exception). Any ideas? – desertnaut Oct 25 '18 at 09:51
  • @pault mystery solved! A new user had decided to alter the OP's code in the linked answer, rendering it wrong (restored it)... – desertnaut Oct 25 '18 at 10:18
  • https://stackoverflow.com/questions/47585279/how-to-access-values-in-array-column not convinced we need to create a tempview for this – thebluephantom Oct 25 '18 at 14:00

2 Answers

30

Here's how to do it with the API functions.

Suppose your DataFrame were the following:

df.show()
#+---+---------+
#| id|  letters|
#+---+---------+
#|  1|[a, b, c]|
#|  2|[d, e, f]|
#|  3|[g, h, i]|
#+---+---------+

df.printSchema()
#root
# |-- id: long (nullable = true)
# |-- letters: array (nullable = true)
# |    |-- element: string (containsNull = true)

You can use square brackets to access elements in the letters column by index, and wrap that in a call to pyspark.sql.functions.array() to create a new ArrayType column.

import pyspark.sql.functions as f

df.withColumn("first_two", f.array([f.col("letters")[0], f.col("letters")[1]])).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

Or, if you have too many indices to list, you can use a list comprehension:

df.withColumn("first_two", f.array([f.col("letters")[i] for i in range(2)])).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

For pyspark versions 2.4+, you can also use pyspark.sql.functions.slice():

df.withColumn("first_two", f.slice("letters", start=1, length=2)).show()
#+---+---------+---------+
#| id|  letters|first_two|
#+---+---------+---------+
#|  1|[a, b, c]|   [a, b]|
#|  2|[d, e, f]|   [d, e]|
#|  3|[g, h, i]|   [g, h]|
#+---+---------+---------+

`slice` may have better performance for large arrays (note that the start index is 1, not 0).
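
To illustrate the 1-based start (a small sketch reusing the same df; not part of the original answer, output shown for illustration): asking for 2 elements starting at position 2 returns the second and third elements.

df.withColumn("middle_two", f.slice("letters", start=2, length=2)).show()
#+---+---------+----------+
#| id|  letters|middle_two|
#+---+---------+----------+
#|  1|[a, b, c]|    [b, c]|
#|  2|[d, e, f]|    [e, f]|
#|  3|[g, h, i]|    [h, i]|
#+---+---------+----------+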

Hugo Zaragoza
pault
  • 1
    Damn... As I said, I have gone rusty - could not even remember the *existence* of `pyspark.sql.functions.array`... :( – desertnaut Oct 25 '18 at 14:12
  • This doesn't work for me for a similar problem. I get the following error: "Can't extract value from probability#6225: need struct type but got struct<type:tinyint,size:int,indices:array<int>,values:array<double>>;" – LePuppy May 22 '19 at 06:45
  • 1
    @LePuppyle my guess is that you have a VectorUDT not an array. For thst, you will need a `udf` - try [this post](https://stackoverflow.com/questions/39555864/how-to-access-element-of-a-vectorudt-column-in-a-spark-dataframe). – pault May 22 '19 at 10:08
  • `AnalysisException: "Field name should be String Literal, but it's 0;"` – rjurney Sep 27 '20 at 01:02
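
Regarding the VectorUDT case raised in the comments, here is a minimal, hedged sketch of the `udf` approach pault points to; the column name `probability` and the double element type are assumptions, not taken from the original answers:

from pyspark.sql import functions as f, types as T

# Assumption: "probability" holds ML vectors (VectorUDT), e.g. classifier output.
# Convert each vector to a plain array<double>, then slice it as shown above.
vector_to_array = f.udf(lambda v: v.toArray().tolist() if v is not None else None,
                        T.ArrayType(T.DoubleType()))

df.withColumn("prob_array", vector_to_array("probability")) \
  .withColumn("first_two", f.array([f.col("prob_array")[i] for i in range(2)]))
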
4

Either my pyspark skills have gone rusty (I confess I don't hone them much anymore), or this is a tough nut indeed... The only way I managed to do it is by using SQL statements:

spark.version
#  u'2.3.1'

# dummy data:

from pyspark.sql import Row
x = [Row(col1="xx", col2="yy", col3="zz", col4=[123, 234, 456])]
rdd = sc.parallelize(x)
df = spark.createDataFrame(rdd)
df.show()
# result:
+----+----+----+---------------+
|col1|col2|col3|           col4|
+----+----+----+---------------+
|  xx|  yy|  zz|[123, 234, 456]|
+----+----+----+---------------+

df.createOrReplaceTempView("df")
df2 = spark.sql("SELECT col1, col2, col3, array(col4[0], col4[1]) as col5 FROM df")
df2.show()
# result:
+----+----+----+----------+ 
|col1|col2|col3|      col5|
+----+----+----+----------+ 
|  xx|  yy|  zz|[123, 234]|
+----+----+----+----------+
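
As a side note (a sketch in the spirit of thebluephantom's comment under the question; not part of the original answer): the same SQL expression can be applied with selectExpr, without registering a temp view:

df2 = df.selectExpr("col1", "col2", "col3", "array(col4[0], col4[1]) as col5")
df2.show()
# expected to give the same result as above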

For future questions, it would be good to follow the suggested guidelines on *How to make good reproducible Apache Spark Dataframe examples*.

desertnaut