
I have a Spark DataFrame where all fields are integer type. I need to count how many individual cells are greater than 0.

I am running locally and have a DataFrame with 17,000 rows and 450 columns.

I have tried two methods, both yielding slow results:

Version 1:

(for (c <- df.columns) yield df.where(s"$c > 0").count).sum

Version 2:

df.columns.map(c => df.filter(df(c) > 0).count)

This calculation takes 80 seconds of wall clock time. With Python Pandas, it takes a fraction of a second. I am aware that for small data sets and local operation, Python may perform better, but this seems extreme.

Trying to make a Spark-to-Spark comparison, I find that running MLlib's PCA algorithm on the same data (converted to a RowMatrix) takes less than 2 seconds!
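Roughly, that comparison looks like the following (the Row-to-Vector conversion and the choice of k = 10 components are just illustrative):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Convert the all-integer DataFrame into an RDD[Vector] and wrap it in a RowMatrix.
val rowMatrix = new RowMatrix(
  df.rdd.map(row => Vectors.dense(row.toSeq.map(_.asInstanceOf[Int].toDouble).toArray))
)

// Principal components over the 450 columns; k = 10 is an arbitrary choice here.
val pc = rowMatrix.computePrincipalComponents(10)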

Is there a more efficient implementation I should be using?

If not, how is the seemingly much more complex PCA calculation so much faster?

stefanobaghino
andrew

2 Answers


What to do

import org.apache.spark.sql.functions.{col, count, when}

// One count per column: when(col(c) > 0, 1) yields null for non-positive cells, so count(...) counts only the positives.
df.select(df.columns map (c => count(when(col(c) > 0, 1)) as c): _*)

Why

Both of your attempts create a number of jobs proportional to the number of columns. Computing the execution plan and scheduling a job is expensive on its own and adds significant overhead depending on the amount of data.

Furthermore, the data might be loaded from disk and/or parsed each time a job is executed, unless the data is fully cached with a significant memory safety margin that ensures the cached data will not be evicted.

This means that, in the worst-case scenario, the nested-loop-like structure you use is roughly quadratic in terms of the number of columns.

The code shown above handles all columns at the same time, requiring only a single data scan.
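For example, a minimal self-contained sketch of this approach (the toy DataFrame, its column names, and the final summation into a single total are illustrative additions, not part of the snippet above):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, when}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Toy stand-in for the 17,000 x 450 all-integer DataFrame.
val df = Seq((1, 0, 3), (0, 2, 0), (5, 0, 7)).toDF("a", "b", "c")

// One aggregate per column, all computed in a single job / single data scan.
val perColumn = df.select(df.columns.map(c => count(when(col(c) > 0, 1)).as(c)): _*)

// The result is a single row of per-column counts; summing it gives the total number of cells > 0.
val total = perColumn.first().toSeq.map(_.asInstanceOf[Long]).sum
println(total) // 5 for the toy data above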

zero323
    Although I selected Raphael's answer, this solution also works. Both approaches take about 2.5 seconds wall clock – andrew Jul 16 '18 at 16:22

The problem with your approach is that the file is scanned for every column (unless you have cached it in memory). The fastest way, with a single FileScan, should be:

import org.apache.spark.sql.functions.{array, explode}

// Flatten every column into a single "cell" column, then filter and count in one pass.
val cnt: Long = df
  .select(
    explode(
      array(df.columns.head, df.columns.tail: _*)
    ).as("cell")
  )
  .where($"cell" > 0)
  .count

Still, I think it will be slower than with Pandas, as Spark has a certain overhead due to its parallelization engine.

Raphael Roth