I've been pulling my hair out trying to optimize a Spark script and it's still unbearably slow (24 min for 600MB of data). The full code is here, but I'll try to summarize it in this question; please let me know if you see any ways to speed it up.
Hardware: single machine with 256GB memory & a 32-core CPU => need to use local as the master. For my project I need both local and local[*], but let's focus on local here.
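The master isn't hard-coded, it comes from spark-submit (shown further down); setting it in code would just be the usual builder call, included here only to make the local vs local[*] distinction concrete:

import org.apache.spark.sql.SparkSession

// illustrative only: in my runs the master comes from spark-submit --master
val spark = SparkSession.builder
  .appName("Spark Pipeline")
  .master("local")      // or "local[*]" / "local[32]" to use all 32 cores
  .getOrCreate()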
Data: 2 NetCDF files (columnar data); single machine => no HDFS
Parsing the data: read all columns as Arrays -> ss.parallelize + zip -> convert to DataFrame
Actions: show(), summary(min, max, mean, stddev), write, groupBy(), write
How I run: sbt assembly to create a fat jar that excludes only Spark itself, then:
spark-submit --master "local" --conf "spark.sql.shuffle.partitions=4" --driver-memory "10g" target/scala-2.11/spark-assembly-1.0.jar --partitions 4 --input ${input} --slice ${slice}
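For completeness, the relevant part of build.sbt is roughly the following sketch; the version numbers are placeholders, and the point is just that Spark is marked provided so sbt assembly keeps it out of the fat jar:

// build.sbt (sketch; version numbers are placeholders)
scalaVersion := "2.11.12"

libraryDependencies ++= Seq(
  // "provided" keeps Spark itself out of the fat jar built by sbt assembly
  "org.apache.spark" %% "spark-sql" % "2.4.0" % "provided"
  // plus the NetCDF (edu.ucar) dependency that provides NetcdfFile / Variable
)

// from sbt-assembly (project/plugins.sbt); names the jar passed to spark-submit above
assemblyJarName in assembly := "spark-assembly-1.0.jar"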
Optimizations I tried:
- parallelize the RDDs to the same number of partitions and set the default number of shuffle partitions (spark.sql.shuffle.partitions) to the same value, to minimize data movement => seemed to help
- different partition counts => 1 just seems to freeze, and anything above 4 seems to slow it down, even when following the usual rules of thumb (numPartitions ~ 4x number of cores, numPartitions ~ data size / 128MB)
- read all the data to the driver as Scala Arrays -> transpose -> a single RDD (as opposed to zipping RDDs) => slower
- repartition the just-read DataFrames on the same columns and numPartitions so the join doesn't trigger a shuffle (a stand-alone sketch of this follows right after this list)
- caching DataFrames that get re-used
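To make the repartition-before-join and caching points concrete, here is a stand-alone toy sketch of the pattern (toy rows, column names borrowed from my schema, not the real NetCDF reader); explain() is only there to check whether an extra Exchange shows up above the join:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object RepartitionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("repartition-sketch").master("local").getOrCreate()
    import spark.implicits._

    // same value I pass via --conf / --partitions
    spark.conf.set("spark.sql.shuffle.partitions", "4")
    // force a sort-merge join so this tiny example behaves like the real (large) join
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

    // two toy frames standing in for the two NetCDF-derived DataFrames
    val df1 = Seq((1.0f, 2.0f, 0.0f, 10.0f), (1.5f, 2.5f, 1.0f, 20.0f))
      .toDF("longitude", "latitude", "time", "tg")
      .repartition(4, col("longitude"), col("latitude"), col("time"))
    val df2 = Seq((1.0f, 2.0f, 0.0f, 5.0f), (1.5f, 2.5f, 1.0f, 6.0f))
      .toDF("longitude", "latitude", "time", "rr")
      .repartition(4, col("longitude"), col("latitude"), col("time"))

    // both sides are hash-partitioned on the join keys with the same numPartitions,
    // so the join itself should not need another full shuffle; cache because the
    // joined frame is reused by several actions later
    val joined = df1.join(df2, Seq("longitude", "latitude", "time"), "inner").cache()
    joined.explain()
    joined.show()

    spark.stop()
  }
}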
Code (a few renames, some comments removed):
private def readDataRDD(path: String, ss: SparkSession, dims: List[String], createIndex: Boolean, numPartitions: Int): DataFrame = {
  val file: NetcdfFile = NetcdfFile.open(path)
  val vars: util.List[Variable] = file.getVariables

  // split variables into dimensions and regular data
  val dimVars: Map[String, Variable] = vars.filter(v => dims.contains(v.getShortName)).map(v => v.getShortName -> v).toMap
  val colVars: Map[String, Variable] = vars.filter(v => !dims.contains(v.getShortName)).map(v => v.getShortName -> v).toMap

  val lon: Array[Float] = readVariable(dimVars(dims(0)))
  val lat: Array[Float] = readVariable(dimVars(dims(1)))
  val tim: Array[Float] = readVariable(dimVars(dims(2)))
  val dimsCartesian: Array[ListBuffer[_]] = cartesian(lon, lat, tim)

  // create the rdd with the dimensions (by transposing the cartesian product)
  var tempRDD: RDD[ListBuffer[_]] = ss.sparkContext.parallelize(dimsCartesian, numPartitions)

  // gather the names of the columns (in order)
  val names: ListBuffer[String] = ListBuffer(dims: _*)

  for (col <- colVars) {
    tempRDD = tempRDD.zip(ss.sparkContext.parallelize(readVariable(col._2), numPartitions)).map(t => t._1 :+ t._2)
    names.add(col._1)
  }

  if (createIndex) {
    tempRDD = tempRDD.zipWithIndex().map(t => t._1 :+ t._2.asInstanceOf[Float])
    names.add("index")
  }

  val finalRDD: RDD[Row] = tempRDD.map(Row.fromSeq(_))
  val df: DataFrame = ss.createDataFrame(finalRDD, StructType(names.map(StructField(_, FloatType, nullable = false))))

  val floatTimeToString = udf((time: Float) => {
    val udunits = String.valueOf(time.asInstanceOf[Int]) + " " + UNITS
    CalendarDate.parseUdunits(CALENDAR, udunits).toString.substring(0, 10)
  })

  df.withColumn("time", floatTimeToString(df("time")))
}
def main(args: Array[String]): Unit = {
  val spark: SparkSession = SparkSession.builder
    .appName("Spark Pipeline")
    .getOrCreate()

  val dimensions: List[String] = List("longitude", "latitude", "time")
  val numberPartitions = options('partitions).asInstanceOf[Int]

  val df1: DataFrame = readDataRDD(options('input) + "data1.nc", spark, dimensions, createIndex = true, numberPartitions)
    .repartition(numberPartitions, col("longitude"), col("latitude"), col("time"))
  val df2: DataFrame = readDataRDD(options('input) + "data2.nc", spark, dimensions, createIndex = false, numberPartitions)
    .repartition(numberPartitions, col("longitude"), col("latitude"), col("time"))

  var df: DataFrame = df1.join(df2, dimensions, "inner").cache()
  println(df.show())

  val slice: Array[String] = options('slice).asInstanceOf[String].split(":")
  df = df.filter(df("index") >= slice(0).toFloat && df("index") < slice(1).toFloat)
    .filter(df("tg") =!= -99.99f && df("pp") =!= -999.9f && df("rr") =!= -999.9f)
    .drop("pp_stderr", "rr_stderr", "index")
    .withColumn("abs_diff", abs(df("tx") - df("tn"))).cache()

  val df_agg = df.drop("longitude", "latitude", "time")
    .summary("min", "max", "mean", "stddev")
    .coalesce(1)
    .write
    .option("header", "true")
    .csv(options('output) + "agg")

  val computeYearMonth = udf((time: String) => {
    time.substring(0, 7).replace("-", "")
  })
  df = df.withColumn("year_month", computeYearMonth(df("time")))

  val columnsToAgg: Array[String] = Array("tg", "tn", "tx", "pp", "rr")
  val groupOn: Seq[String] = Seq("longitude", "latitude", "year_month")

  val grouped_df: DataFrame = df.groupBy(groupOn.head, groupOn.drop(1): _*)
    .agg(columnsToAgg.map(column => column -> "mean").toMap)
    .drop("longitude", "latitude", "year_month")

  val columnsToSum: Array[String] = Array("tg_mean", "tn_mean", "tx_mean", "rr_mean", "pp_mean")

  grouped_df
    .agg(columnsToSum.map(column => column -> "sum").toMap)
    .coalesce(1)
    .write
    .option("header", "true")
    .csv(options('output) + "grouped")

  spark.stop()
}
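In case the two helpers matter for the timing: readVariable and cartesian aren't shown above. They behave roughly like the simplified sketch below (not the exact code, that's in the linked repo; copyTo1DJavaArray and the loop order are simplifications). The relevant point is that cartesian materializes the full longitude x latitude x time grid as driver-side objects before anything is parallelized.

// simplified sketches of the helpers used in readDataRDD (members of the same object)
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import ucar.nc2.Variable

// read a whole NetCDF variable into memory as a flat Array[Float]
def readVariable(variable: Variable): Array[Float] =
  variable.read().copyTo1DJavaArray().asInstanceOf[Array[Float]]

// one ListBuffer per grid point, ordered to line up with the flattened data variables
def cartesian(lon: Array[Float], lat: Array[Float], tim: Array[Float]): Array[ListBuffer[_]] = {
  val rows = new ArrayBuffer[ListBuffer[_]](lon.length * lat.length * tim.length)
  for (t <- tim; la <- lat; lo <- lon)
    rows += ListBuffer[Any](lo, la, t)
  rows.toArray
}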
Any ideas how to speed it up further?
Notes:
- local takes 24 min; local[32] takes 5 min
- yes, Spark isn't built for a single machine, but the same operations (single-threaded) in Java or pandas take 10s and 40s respectively; huge difference
- I can't currently view the web interface to visualize the tasks
- the 600MB of data is a subset; the full dataset is ~50GB