I have a DataFrame in the format given below.

movieId1 | genreList1              | genreList2
--------------------------------------------------
1        |[Adventure,Comedy]       |[Adventure]
2        |[Animation,Drama,War]    |[War,Drama]
3        |[Adventure,Drama]        |[Drama,War]

and I am trying to create another flag column that shows whether genreList2 is a subset of genreList1.

movieId1 | genreList1              | genreList2        | Flag
---------------------------------------------------------------
1        |[Adventure,Comedy]       | [Adventure]       |1
2        |[Animation,Drama,War]    | [War,Drama]       |1
3        |[Adventure,Drama]        | [Drama,War]       |0

I have tried this:

def intersect_check(a: Array[String], b: Array[String]): Int = {
  if (b.sameElements(a.intersect(b))) { return 1 }
  else { return 0 }
}

def intersect_check_udf =
  udf((colvalue1: Array[String], colvalue2: Array[String]) => intersect_check(colvalue1, colvalue2))

data = data.withColumn("Flag", intersect_check_udf(col("genreList1"), col("genreList2")))

But this throws an error:

org.apache.spark.SparkException: Failed to execute user defined function.

P.S.: The above function (intersect_check) works when called directly on plain Arrays.
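
For instance, calling it directly on plain Scala arrays outside Spark gives the expected flags (a quick check using the sample genres above):

intersect_check(Array("Adventure", "Comedy"), Array("Adventure"))   // 1: subset
intersect_check(Array("Adventure", "Drama"), Array("Drama", "War")) // 0: not a subset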

5 Answers

We can define a UDF that calculates the length of the intersection between the two array columns and checks whether it is equal to the length of the second one; if so, the second array is a subset of the first.

Also, the inputs of your UDF need to be of class WrappedArray[String], not Array[String]:

import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.{col, udf}

// Flag is 1 when the intersection of a and b has as many elements as b,
// i.e. when every element of b also occurs in a
val same_elements = udf { (a: WrappedArray[String],
                           b: WrappedArray[String]) =>
  if (a.intersect(b).length == b.length) 1 else 0
}

df.withColumn("test",same_elements(col("genreList1"),col("genreList2")))
  .show(truncate = false)
+--------+-----------------------+------------+----+
|movieId1|genreList1             |genreList2  |test|
+--------+-----------------------+------------+----+
|1       |[Adventure, Comedy]    |[Adventure] |1   |
|2       |[Animation, Drama, War]|[War, Drama]|1   |
|3       |[Adventure, Drama]     |[Drama, War]|0   |
+--------+-----------------------+------------+----+

Data

val df = List((1,Array("Adventure","Comedy"), Array("Adventure")),
              (2,Array("Animation","Drama","War"), Array("War","Drama")),
              (3,Array("Adventure","Drama"),Array("Drama","War"))).toDF("movieId1","genreList1","genreList2")

Here is a solution that converts both arrays to sets and uses subsetOf:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark =
  SparkSession.builder().master("local").appName("test").getOrCreate()

import spark.implicits._

val data = spark.sparkContext.parallelize(
  Seq(
    (1, Array("Adventure", "Comedy"), Array("Adventure")),
    (2, Array("Animation", "Drama", "War"), Array("War", "Drama")),
    (3, Array("Adventure", "Drama"), Array("Drama", "War"))
  )).toDF("movieId1", "genreList1", "genreList2")

// genreList2 is a subset of genreList1 iff every element of col2 occurs in col1
val subsetOf = udf((col1: Seq[String], col2: Seq[String]) => {
  if (col2.toSet.subsetOf(col1.toSet)) 1 else 0
})

data.withColumn("flag", subsetOf(data("genreList1"), data("genreList2"))).show()

Hope this helps!

Spark 3.0+ (forall)

forall returns true only if the predicate holds for every element of the array, so the expression below checks that each element of genreList2 appears in genreList1:

forall($"genreList2", x => array_contains($"genreList1", x)).cast("int")

Full example:

val df = Seq(
     (1, Seq("Adventure", "Comedy"), Seq("Adventure")),
     (2, Seq("Animation", "Drama","War"), Seq("War", "Drama")),
     (3, Seq("Adventure", "Drama"), Seq("Drama", "War"))
     ).toDF("movieId1", "genreList1", "genreList2")

val df2 = df.withColumn("Flag", forall($"genreList2", x => array_contains($"genreList1", x)).cast("int"))

df2.show()
// +--------+--------------------+------------+----+
// |movieId1|          genreList1|  genreList2|Flag|
// +--------+--------------------+------------+----+
// |       1| [Adventure, Comedy]| [Adventure]|   1|
// |       2|[Animation, Drama...|[War, Drama]|   1|
// |       3|  [Adventure, Drama]|[Drama, War]|   0|
// +--------+--------------------+------------+----+
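
If the typed API is inconvenient, the same higher-order function can also be used through a SQL expression (a sketch, assuming Spark 3.0+ as above):

import org.apache.spark.sql.functions.expr

// forall(...) is evaluated as a SQL lambda and cast from boolean to int
val df3 = df.withColumn(
  "Flag",
  expr("cast(forall(genreList2, x -> array_contains(genreList1, x)) as int)")
)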

One solution is to exploit Spark's built-in array functions (array_intersect requires Spark 2.4+): genreList2 is a subset of genreList1 if the intersection between the two is equal to genreList2. In the code below, a sort_array operation has been added to avoid a mismatch between two arrays that hold the same elements in a different order.

val spark = {
    SparkSession
    .builder()
    .master("local")
    .appName("test")
    .getOrCreate()
}

import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._

val df = Seq(
    (1, Array("Adventure","Comedy"), Array("Adventure")),
    (2, Array("Animation","Drama","War"), Array("War","Drama")),
    (3, Array("Adventure","Drama"), Array("Drama","War"))
).toDF("movieId1", "genreList1", "genreList2")

df
  .withColumn("flag",
    sort_array(array_intersect($"genreList1", $"genreList2"))
      .equalTo(sort_array($"genreList2"))
      .cast("integer")
  )
  .show()

The output is

+--------+--------------------+------------+----+
|movieId1|          genreList1|  genreList2|flag|
+--------+--------------------+------------+----+
|       1| [Adventure, Comedy]| [Adventure]|   1|
|       2|[Animation, Drama...|[War, Drama]|   1|
|       3|  [Adventure, Drama]|[Drama, War]|   0|
+--------+--------------------+------------+----+
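
A variant of the same idea (not part of the original answer) skips the sorts and compares sizes instead; note that it assumes genreList2 contains no duplicate entries, since array_intersect removes duplicates:

df
  .withColumn("flag",
    // b is a subset of a when the (deduplicated) intersection is as large as b
    (size(array_intersect($"genreList1", $"genreList2")) === size($"genreList2"))
      .cast("integer")
  )
  .show()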

This also works and does not use a UDF: array_except returns the elements of genreList2 that are not present in genreList1, so genreList2 is a subset exactly when that result is empty.

import spark.implicits._
import org.apache.spark.sql.functions._

val data = Seq(
  (1, Array("Adventure", "Comedy"), Array("Adventure")),
  (2, Array("Animation", "Drama", "War"), Array("War", "Drama")),
  (3, Array("Adventure", "Drama"), Array("Drama", "War"))
).toDF("movieId1", "genreList1", "genreList2")

data
  .withColumn("size", size(array_except($"genreList2", $"genreList1")))
  .withColumn("flag", when($"size" === lit(0), 1).otherwise(0))
  .show(false)
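
For the sample data this should print:

+--------+-----------------------+------------+----+----+
|movieId1|genreList1             |genreList2  |size|flag|
+--------+-----------------------+------------+----+----+
|1       |[Adventure, Comedy]    |[Adventure] |0   |1   |
|2       |[Animation, Drama, War]|[War, Drama]|0   |1   |
|3       |[Adventure, Drama]     |[Drama, War]|1   |0   |
+--------+-----------------------+------------+----+----+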