
I have a PySpark dataframe as follows:

import pyspark.sql.functions as F
import pyspark.sql.types as T

schema = T.StructType([
    T.StructField("id", T.StringType(), True),
    T.StructField("code", T.ArrayType(T.StringType()), True)
])
df = spark.createDataFrame([{"id": "1", "code": ["a1", "a2", "a3", "a4"]},
                            {"id": "2", "code": ["b1", "b2"]},
                            {"id": "3", "code": ["c1", "c2", "c3"]},
                            {"id": "4", "code": ["d1", "b3"]}],
                           schema=schema)

which gives the output:

df.show()


+---+----------------+
| id|            code|
+---+----------------+
|  1|[a1, a2, a3, a4]|
|  2|        [b1, b2]|
|  3|    [c1, c2, c3]|
|  4|        [d1, b3]|
+---+----------------+

I would like to filter rows by supplying a column and a list to a function that returns true if there is any intersection (using isdisjoint here, since most rows will not match):

def lst_intersect(data_lst, query_lst):
    # True if the row's array shares at least one element with the query list
    return not set(data_lst).isdisjoint(query_lst)

lst_intersect_udf = F.udf(lst_intersect, T.BooleanType())

When I try to apply this:

query_lst = ['a1','b3']
df = df.withColumn("code_found", lst_intersect_udf(F.col('code'),F.lit(query_lst)))

I get the following error:

Unsupported literal type class java.util.ArrayList [a1, b3]

I can solve it by changing the function, but I am wondering: is there something fundamental I am doing wrong with F.lit(query_lst)?


1 Answer


lit only accepts a single value, not a Python list. You need to pass an array column built from the literal values in your list, for example with a list comprehension:

df2 = df.withColumn(
    "code_found", 
    lst_intersect_udf(
        F.col('code'),
        F.array(*[F.lit(i) for i in query_lst])
    )
)

df2.show()
+---+----------------+----------+
| id|            code|code_found|
+---+----------------+----------+
|  1|[a1, a2, a3, a4]|      true|
|  2|        [b1, b2]|     false|
|  3|    [c1, c2, c3]|     false|
|  4|        [d1, b3]|      true|
+---+----------------+----------+
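
Since the original goal was filtering rather than just flagging rows, the computed column can be used directly in a filter; a minimal sketch reusing df2 from above:

# keep only the rows whose code array intersects query_lst
df2.filter(F.col("code_found")).show()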

That said, if you have Spark >= 2.4, you can also use the Spark SQL function arrays_overlap for better performance:

df2 = df.withColumn(
    "code_found", 
    F.arrays_overlap(
        F.col('code'),
        F.array(*[F.lit(i) for i in query_lst])
    )
)
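
arrays_overlap runs natively in the JVM, so rows never have to be serialized to a Python worker the way they do for a UDF, which is where the performance gain comes from. It can also be used directly in a filter; a sketch, assuming the same query_lst as above:

# native filter: no Python UDF, so no per-row Python round trip
df.filter(
    F.arrays_overlap(
        F.col('code'),
        F.array(*[F.lit(i) for i in query_lst])
    )
).show()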