I am trying to implement a data cleaning step in a Machine Learning problem where I should group all the elements in the long tail into a common category named "Others". For example, I have a dataframe like this:
val df = sc.parallelize(Seq(
(1, "ABC"),
(2, "ABC"),
(3, "123"),
(4, "FPK"),
(5, "FPK"),
(6, "ABC"),
(7, "ABC"),
(8, "980"),
(9, "abc"),
(10, "FPK")
)).toDF("n", "s")
I want to keep the categories "ABC" and "FPK" since they appear several times, but I don't want a separate category for each of 123, 980 and abc, since they appear just once. So what I would like to have instead is:
+---+------+
| n| s|
+---+------+
| 1| ABC|
| 2| ABC|
| 3|Others|
| 4| FPK|
| 5| FPK|
| 6| ABC|
| 7| ABC|
| 8|Others|
| 9|Others|
| 10| FPK|
+---+------+
To achieve this, I tried the following:
val newDF = df.withColumn("s", when($"s".isin("123", "980", "abc"), "Others").otherwise('s))
This works fine.
But I would like to decide programmatically which categories belong to the long tail; in my case, the ones that appear just once in the original dataframe. So I wrote this to create a dataframe with the categories that appear only once:
val longTail = df.groupBy("s").agg(count("*").alias("cnt")).orderBy($"cnt".desc).filter($"cnt"<2)
+---+---+
| s|cnt|
+---+---+
|980| 1|
|abc| 1|
|123| 1|
+---+---+
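To check the rule itself without a cluster, the same logic can be sketched with plain Scala collections. LongTailModel, longTail, groupLongTail and minCount are hypothetical names; this is only a model of the Spark code above (count per value, keep values below a threshold, map them to "Others"), not a replacement for it.

```scala
object LongTailModel {
  // Sample data from the question: (n, s) pairs.
  val rows: Seq[(Int, String)] = Seq(
    (1, "ABC"), (2, "ABC"), (3, "123"), (4, "FPK"), (5, "FPK"),
    (6, "ABC"), (7, "ABC"), (8, "980"), (9, "abc"), (10, "FPK")
  )

  // Values occurring fewer than minCount times form the long tail
  // (minCount = 2 means "appears just once").
  def longTail(values: Seq[String], minCount: Int): Set[String] =
    values.groupBy(identity).collect { case (v, occ) if occ.size < minCount => v }.toSet

  // Replace every long-tail value with "Others", keeping the original row order.
  def groupLongTail(rows: Seq[(Int, String)], minCount: Int): Seq[(Int, String)] = {
    val tail = longTail(rows.map(_._2), minCount)
    rows.map { case (n, s) => (n, if (tail.contains(s)) "Others" else s) }
  }

  def main(args: Array[String]): Unit =
    groupLongTail(rows, minCount = 2).foreach(println)
}
```

With minCount = 2 the long tail is exactly {123, 980, abc}, which matches the longTail dataframe shown above.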
Now I am trying to convert the values of the column "s" in this longTail dataset into a List, to replace the one I hardcoded before. So I tried:
val ar = longTail.select("s").collect().map(_(0)).toList
ar: List[Any] = List(123, 980, abc)
But when I try to pass ar to isin:
val newDF = df.withColumn("s",when($"s".isin(ar),"Others").otherwise('s))
I get the following error:
java.lang.RuntimeException: Unsupported literal type class scala.collection.immutable.$colon$colon List(123, 980, abc)
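The class named in the error, scala.collection.immutable.$colon$colon, is the runtime class of a non-empty Scala List, which suggests the whole List reached Spark as one single literal. Column.isin is declared with a varargs parameter (isin(list: Any*)), and Scala passes a List given to a varargs parameter as one element unless it is expanded with : _*. The sketch below uses a hypothetical stand-in method, argClasses, with the same parameter shape, so it runs without Spark:

```scala
object VarargsDemo {
  // Hypothetical stand-in with the same parameter shape as Spark's Column.isin(list: Any*).
  // It reports the runtime class of each argument it receives.
  def argClasses(values: Any*): Seq[String] = values.map(_.getClass.getName)

  val ar: List[Any] = List("123", "980", "abc")

  def main(args: Array[String]): Unit = {
    // Without `: _*` the List is ONE argument, of class scala.collection.immutable.$colon$colon.
    println(argClasses(ar))
    // With `: _*` the List is expanded, so each element becomes its own argument.
    println(argClasses(ar: _*))
  }
}
```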
What am I missing?
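A minimal sketch of the whole pipeline in Spark, assuming Spark 2.x or 3.x on Scala 2.12/2.13. The object LongTailToOthers, the method groupLongTail and its minCount parameter are hypothetical names. It collects the rare values as a typed List[String] via .as[String] (instead of List[Any]) and expands them into isin with : _*:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count, when}

object LongTailToOthers {
  // Replaces every value of `column` that occurs fewer than `minCount` times with "Others".
  // minCount is a hypothetical parameter; "appears just once" corresponds to minCount = 2.
  def groupLongTail(df: DataFrame, column: String, minCount: Long): DataFrame = {
    val spark = df.sparkSession
    import spark.implicits._ // provides the Encoder[String] needed by .as[String]

    // Collect the rare values as List[String] rather than List[Any].
    val rare: List[String] = df
      .groupBy(column)
      .agg(count("*").alias("cnt"))
      .filter(col("cnt") < minCount)
      .select(column)
      .as[String]
      .collect()
      .toList

    // `rare: _*` expands the List into isin's varargs parameter, one literal per value.
    df.withColumn(column, when(col(column).isin(rare: _*), "Others").otherwise(col(column)))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("long-tail").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, "ABC"), (2, "ABC"), (3, "123"), (4, "FPK"), (5, "FPK"),
      (6, "ABC"), (7, "ABC"), (8, "980"), (9, "abc"), (10, "FPK")
    ).toDF("n", "s")

    groupLongTail(df, "s", minCount = 2).orderBy("n").show()
    spark.stop()
  }
}
```

On Spark 2.4 and later, col(column).isInCollection(rare) accepts the collection directly and avoids the : _* expansion. Either way the rare values are collected to the driver, which is reasonable only while the long tail is small enough to fit there.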