
I wonder if there is a more efficient way in Spark to find the most frequent value of a set of columns than using rank(), in order to use it as an imputation for missing values.

E.g. in spark-sql I could formulate something similar to how to select the most frequently appearing values?, per column. That solution works for a single column using rank. What I am looking for is a) a more efficient variant (as the first answer outlines) and b) something more optimal than applying a) in a for loop over multiple columns.

Do you see any possibility to optimize this in Spark?
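
For reference, this is roughly what the rank-based single-column approach looks like (a minimal sketch, assuming a DataFrame df with a column bar and import spark.implicits._ in scope); the unpartitioned window is what makes it expensive, since Spark moves all rows into a single partition:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// top-1 value of one column via rank over an unpartitioned window
// (Spark warns: "No Partition Defined for Window operation")
val topBar = df
  .groupBy($"bar")
  .count()
  .withColumn("rnk", rank().over(Window.orderBy($"count".desc)))
  .where($"rnk" === 1)
  .drop("rnk")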

edit

An example. Here is a small Dataset:

import org.apache.spark.sql.functions.col
import spark.implicits._ // toDF / as[] need the SparkSession implicits

case class FooBarGG(foo: Int, bar: String, baz: String, dropme: String)
val df = Seq((0, "first", "A", "dropme"), (1, "second", "A", "dropme2"),
    (0, "first", "B", "foo"),
    (1, "first", "C", "foo"))
    .toDF("foo", "bar", "baz", "dropme").as[FooBarGG]
val columnsFactor = Seq("bar", "baz")
val columnsToDrop = Seq("dropme")
val factorCol = (columnsFactor ++ columnsToDrop).map(c => col(c))

With the query from the answer

df.groupBy(factorCol: _*).count.agg(max(struct($"count" +: factorCol: _*)).alias("mostFrequent")).show
+--------------------+
|        mostFrequent|
+--------------------+
|[1,second,A,dropme2]|
+--------------------+
root
 |-- mostFrequent: struct (nullable = true)
 |    |-- count: long (nullable = false)
 |    |-- bar: string (nullable = true)
 |    |-- baz: string (nullable = true)
 |    |-- dropme: string (nullable = true)

This is the result, but the per-column top-1 most frequent values are bar -> first, baz -> A and dropme -> foo, which differ from the single row returned by this query.
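
For comparison, the for-loop baseline mentioned in b) would look roughly like this (a sketch, assuming import spark.implicits._ in scope), running one aggregation job per column:

import org.apache.spark.sql.functions._

// one max(struct(count, col)) aggregation per column - the baseline to beat
val perColumnMode = columnsFactor.map { c =>
  c -> df
    .groupBy(col(c))
    .count()
    .agg(max(struct($"count", col(c))).getField(c))
    .first()
    .get(0)
}.toMap
// expected here: Map(bar -> first, baz -> A)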

Georg Heiler

2 Answers


You can use a simple aggregation as long as your fields can be ordered and count is the leading one:

import org.apache.spark.sql.functions._
import spark.implicits._ // for the $-notation and toDF

val df = Seq("John", "Jane", "Eve", "Joe", "Eve").toDF("name")
val grouping = Seq($"name")

df.groupBy(grouping: _*).count.agg(max(struct($"count" +: grouping: _*)))
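
To get the winning value back out of the struct, one option (a sketch, not part of the original answer) is to select the nested fields:

// unpack the aggregated struct; with the example data this yields ("Eve", 2)
df.groupBy(grouping: _*)
  .count()
  .agg(max(struct($"count" +: grouping: _*)).alias("mostFrequent"))
  .select($"mostFrequent.name", $"mostFrequent.count")
  .show()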

It is also possible to use a statically typed Dataset:

import org.apache.spark.sql.catalyst.encoders.RowEncoder

// group by the whole row, count, and keep the (row, count) pair with the highest count
df.groupByKey(x => x)(RowEncoder(df.schema)).count.reduce(
  (x, y) => if (x._2 > y._2) x else y
)

You can adjust the grouping columns or the key function to handle more complex scenarios.
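
For example, restricting the typed variant to the single name column could look like this (a sketch, assuming import spark.implicits._ is in scope for the String encoder):

// mode of the "name" column with the typed API; returns a (value, count) pair
val mostFrequentName: (String, Long) =
  df.as[String].groupByKey(identity).count().reduce(
    (x, y) => if (x._2 > y._2) x else y
  )
// with the example data above: ("Eve", 2)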

zero323
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// find the most frequent value of valueColumnName within each group of groupByColumns
def mode(df: DataFrame, valueColumnName: String, groupByColumns: Seq[String]): DataFrame = {
  df.groupBy(valueColumnName, groupByColumns: _*).count()
    .withColumn(
      "rn",
      row_number().over(Window.partitionBy(groupByColumns.head, groupByColumns.tail: _*).orderBy(col("count").desc))
    )
    .where(col("rn") === 1)
    .select(valueColumnName, groupByColumns: _*)
}
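
Hypothetical usage with the example data from the question, i.e. the most frequent bar within each baz group (note this computes a mode per group rather than one global mode per column, and row_number breaks ties arbitrarily):

// most frequent value of "bar" within each "baz" group
val barModePerBaz = mode(df, "bar", Seq("baz"))
barModePerBaz.show()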
Vitamon