
I am trying to create a column in my Spark DataFrame that flags whether each row's value in a column appears in a separate DataFrame.

This is my main Spark DataFrame (df_main):

+--------+
|main    |
+--------+
|28asA017|
|03G12331|
|1567L044|
|02TGasd8|
|1asd3436|
|A1234567|
|B1234567|
+--------+

This is my reference (df_ref). There are hundreds of rows in this reference, so I obviously can't hard-code them like in this solution or this one:

+--------+
|mask_vl |
+--------+
|A1234567|
|B1234567|
...
+--------+

Normally, what I'd do with a pandas DataFrame is this:

import numpy as np
df_main['is_inref'] = np.where(df_main['main'].isin(df_ref.mask_vl.values), "YES", "NO")

So that I would get this:

+--------+--------+
|main    |is_inref|
+--------+--------+
|28asA017|NO      |
|03G12331|NO      |
|1567L044|NO      |
|02TGasd8|NO      |
|1asd3436|NO      |
|A1234567|YES     |
|B1234567|YES     |
+--------+--------+

I have tried the following code, but I don't get what the error in the picture means.

df_main = df_main.withColumn('is_inref', "YES" if F.col('main').isin(df_ref) else "NO")
df_main.show(20, False)

[Screenshot: error of the mentioned code]

user2552108

2 Answers


You are close. I think the additional step you need is to explicitly create the list that will contain the values from df_ref.

Please see the illustration below:

# Create your DataFrames
df_main = spark.createDataFrame(["28asA017", "03G12331", "1567L044", "02TGasd8", "1asd3436", "A1234567", "B1234567"], "string").toDF("main")
df_ref = spark.createDataFrame(["A1234567", "B1234567"], "string").toDF("mask_vl")

Then, you can create a list and use isin, almost as you have it:

# Imports
from pyspark.sql.functions import col, when

# Create a list with the values of your reference DF
mask_vl_list = df_ref.select("mask_vl").rdd.flatMap(lambda x: x).collect()

# Use isin to check whether the values in your column exist in the list
df_main = df_main.withColumn('is_inref', when(col('main').isin(mask_vl_list), 'YES').otherwise('NO'))

This will give you:

>>> df_main.show()

+--------+--------+
|    main|is_inref|
+--------+--------+
|28asA017|      NO|
|03G12331|      NO|
|1567L044|      NO|
|02TGasd8|      NO|
|1asd3436|      NO|
|A1234567|     YES|
|B1234567|     YES|
+--------+--------+
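
Equivalently, if you'd rather not drop down to the RDD API, the same list can be built straight from collect() (a small sketch; the distinct() is optional but keeps the driver-side list smaller):

# Same list, built from collected Row objects instead of the RDD API
mask_vl_list = [row["mask_vl"] for row in df_ref.select("mask_vl").distinct().collect()]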
sophocles
  • Thanks for the quick response @sophocles. Am I correct to understand that the first command (the rdd.flatMap(...).collect()) basically converts the DataFrame into a list on the main driver node? If so, won't I run into an out-of-memory exception if the reference becomes huge? – user2552108 Jul 12 '21 at 09:05
  • Welcome. Yes, you are right. I don't think that you will run into memory exception problems, as this is an efficient approach. You can check out more information [here](https://stackoverflow.com/questions/38610559/convert-spark-dataframe-column-to-python-list) about performance benchmarking for converting a column to a list. – sophocles Jul 12 '21 at 09:12
  • I guess that collect() is not the best solution. If the mask_vl DataFrame grows, it will be a problem. – Elisabetta Jul 12 '21 at 15:34

If you want to avoid collect, I advise you to do the following:

from pyspark.sql.functions import col, lit, when

# Rename the reference column to match the join key,
# and add a marker column that survives the join
df_ref = (df_ref
          .withColumnRenamed("mask_vl", "main")
          .withColumn("isPresent", lit("YES")))

# Left join: rows with no match in df_ref get a null marker
df_main = (df_main.join(df_ref, ["main"], "left_outer")
           .withColumn("is_inref",
                       when(col("isPresent").isNull(), lit("NO")).otherwise(lit("YES")))
           .drop("isPresent"))
Elisabetta