I have a PySpark DataFrame as follows:
import pyspark.sql.functions as F
import pyspark.sql.types as T

schema = T.StructType([
    T.StructField("id", T.StringType(), True),
    T.StructField("code", T.ArrayType(T.StringType()), True),
])

df = spark.createDataFrame(
    [{"id": "1", "code": ["a1", "a2", "a3", "a4"]},
     {"id": "2", "code": ["b1", "b2"]},
     {"id": "3", "code": ["c1", "c2", "c3"]},
     {"id": "4", "code": ["d1", "b3"]}],
    schema=schema,
)
which gives the following output:

df.show()
+---+----------------+
| id|            code|
+---+----------------+
|  1|[a1, a2, a3, a4]|
|  2|        [b1, b2]|
|  3|    [c1, c2, c3]|
|  4|        [d1, b3]|
+---+----------------+
I would like to be able to filter rows by supplying a column and a list to a function that returns true if there is any intersection (using isdisjoint, as suggested here, since most rows will not match):
def lst_intersect(data_lst, query_lst):
    return not set(data_lst).isdisjoint(query_lst)

lst_intersect_udf = F.udf(lambda x, y: lst_intersect(x, y), T.BooleanType())
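As a quick sanity check (plain Python, outside Spark; the values are just illustrative), the helper behaves as expected:

# Plain-Python check of the intersection logic:
print(lst_intersect(["a1", "a2"], ["a1", "b3"]))  # True  -- shares "a1"
print(lst_intersect(["c1", "c2"], ["a1", "b3"]))  # False -- no overlap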
When I try to apply this:

query_lst = ['a1', 'b3']
df = df.withColumn("code_found", lst_intersect_udf(F.col('code'), F.lit(query_lst)))
I get the following error:

Unsupported literal type class java.util.ArrayList [a1, b3]
I can solve it by changing the function, etc. (one such workaround is sketched below), but I am wondering: is there something fundamental that I am doing wrong with F.lit(query_lst)?
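For reference, here is a minimal sketch of the kind of workaround I mean, assuming a Spark version where F.lit does not accept Python lists (as the error indicates): build the literal array column element by element with F.array, so the UDF itself stays unchanged.

# Workaround sketch: wrap each element in F.lit and combine them with
# F.array, instead of passing the Python list to F.lit directly.
query_col = F.array(*[F.lit(x) for x in query_lst])

df = df.withColumn("code_found", lst_intersect_udf(F.col("code"), query_col))
df.show()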