0

I'm trying to use Spark to train a multiclass logistic regression model on a windowed text file. First I create windows and explode them into $"word_winds". Then I move the center word of each window into $"word". To fit the LogisticRegression model, I convert each distinct word into a class ($"label"), which the model then learns to predict. I count the samples per label in order to prune labels that have minF or fewer samples.

The problem is that some part of the code is very, very slow, even for small input files (you can use any README file to test the code). From searching online, I found that other users have experienced slowness when using explode. They suggest modifications that give roughly a 2x speedup. However, I don't think that would be sufficient for a 100 MB input file. Please suggest a different approach, ideally one that avoids the actions that slow the code down. I'm using Spark 2.4.0 and sbt 1.2.8 on a 24-core machine.

import org.apache.spark.sql.functions._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.types._



object SimpleApp {

  /** Sampling fraction per label for the stratified sample, favouring rare labels.
    *
    * Labels are ranked by descending count, with ties broken by label so the result is
    * deterministic. The label at rank i is paired with the count at rank i of the ascending
    * order: the most frequent label receives the smallest count and vice versa. Each paired
    * count is divided by `total`, min-max normalised with the original 0.01 offsets, and clamped
    * to at most 1.0.
    *
    * NOTE(review): the 0.01 offset is absolute while the weights are count/total. On large
    * inputs it dominates (max - min) and pushes most fractions to the 1.0 clamp. Confirm this is
    * intended.
    *
    * @param labelCounts (label, row count) for every label kept after the minF filter
    * @param total       total number of rows covered by those labels
    * @return a fraction in (0, 1] for each label; 1.0 for every label when all weights are equal
    */
  private def inverseFrequencyFractions(labelCounts: Seq[(Double, Long)],
                                        total: Long): Map[Double, Double] = {
    val byDescCount = labelCounts.sortBy { case (label, count) => (-count, label) }
    val ascCounts = labelCounts.map(_._2).sorted
    val weights = byDescCount.zip(ascCounts).map { case ((label, _), ascc) =>
      label -> ascc.toDouble / total
    }
    if (weights.isEmpty) Map.empty
    else {
      val minn = weights.map(_._2).min - 0.01
      val maxx = weights.map(_._2).max - 0.01
      val span = maxx - minn
      weights.map { case (label, w) =>
        // When all weights are equal the span is 0. Keep every label rather than dividing by
        // zero, which would give a null and therefore a 0.0 fraction (an empty sample).
        val normalised = if (span > 0) (w - minn) / span else 1.0
        label -> math.min(normalised, 1.0)
      }.toMap
    }
  }

  def main(args: Array[String]): Unit = {
    import org.apache.spark.ml.feature.IndexToString

    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    spark.sparkContext.setCheckpointDir("checked_dfs")

    val in_file = "sample.txt"
    val stratified = true
    val wsize = 7
    val ngram = 3
    val minF = 2
    // HashingTF's default size. The multinomial LogisticRegression coefficient matrix is
    // numClasses x numFeatures, so with many distinct labels this is the main knob for
    // training time and memory.
    val numFeatures = 1 << 18

    val windUdf = udf { s: String => s.sliding(ngram).toList.sliding(wsize).toList }
    val get_mid = udf { s: Seq[String] => s(s.size / 2) }
    // The window without its centre element: the context the centre word is predicted from.
    val drop_mid = udf { s: Seq[String] => s.patch(s.size / 2, Seq.empty[String], 1) }
    val rm_punct = udf { s: String => s.replaceAll("""([\p{Punct}|¿|\?|¡|!]|\p{C}|\b\p{IsLetter}{1,2}\b)\s*""", "") }

    // Read each line, remove punctuation, build the windows and explode them. Then split each
    // window into its centre word ($"word") and its context ($"context").
    // Persisted because the UDFs + explode are the expensive part. Otherwise every later job
    // (both model fits, the label counts, the sample, the checkpoint) would recompute them.
    val windows = spark.read.text(in_file)
      .withColumn("value", rm_punct($"value"))
      .withColumn("char_nGrams", windUdf($"value"))
      .withColumn("word_winds", explode($"char_nGrams"))
      .withColumn("word", get_mid($"word_winds"))
      .withColumn("context", drop_mid($"word_winds"))
      .persist(StorageLevel.MEMORY_AND_DISK)

    val indexerModel = new StringIndexer().setInputCol("word")
                                          .setOutputCol("label")
                                          .fit(windows)
    val indexed = indexerModel.transform(windows)

    // Hash the context, not the whole window. The window contains the centre word, so hashing
    // it would put the label into its own features.
    val hashed = new HashingTF().setInputCol("context")
                                .setOutputCol("freqFeatures")
                                .setNumFeatures(numFeatures)
                                .transform(indexed)
    val featurized = new IDF().setInputCol("freqFeatures")
                              .setOutputCol("features")
                              .fit(hashed)
                              .transform(hashed)

    // Count rows per label once, and keep only labels that occur more than minF times.
    val keptCounts = featurized.groupBy("label").count.filter(col("count") > minF)
    val kept = featurized.join(keptCounts.select("label"), Seq("label"), "leftsemi")

    val dfs = if (stratified) {
      // One row per label: small enough to weigh on the driver, and the fractions map has to be
      // collected for sampleBy anyway.
      val labelCounts = keptCounts.collect().toSeq
        .map(r => (r.getAs[Double]("label"), r.getAs[Long]("count")))
      require(labelCounts.nonEmpty, s"no label occurs more than $minF times in $in_file")
      val fractions = inverseFrequencyFractions(labelCounts, labelCounts.map(_._2).sum)
      kept.stat.sampleBy("label", fractions, 36L).select("features", "word_winds", "word", "label")
    } else {
      kept
    }
    // Truncate the lineage. After this the persisted windows are no longer needed.
    val sampled = dfs.checkpoint()
    windows.unpersist()

    val Array(tr, ts) = sampled.randomSplit(Array(0.7, 0.3), seed = 12345)
    val training = tr.select("word_winds", "features", "label", "word")
    val test = ts.select("word_winds", "features", "label", "word")

    val model = new LogisticRegression().setRegParam(0.01).fit(training)

    // Map predicted indices back to words with the indexer's own label list (index i <-> labels(i)).
    // That list covers every label, not only those present in the training split.
    val toWord = new IndexToString().setInputCol("prediction")
                                    .setOutputCol("pred_word")
                                    .setLabels(indexerModel.labels)
    // The CSV source cannot write array or vector columns (word_winds, features, rawPrediction,
    // probability). Write scalar columns only, with the window flattened to a string.
    toWord.transform(model.transform(test))
      .select(concat_ws(" ", $"word_winds").as("word_winds"),
              $"word", $"label", $"prediction", $"pred_word")
      .write.format("csv").save("spark_predictions")

    spark.stop()
  }
}
Nacho
  • 792
  • 1
  • 5
  • 23

1 Answer

0

Since your data is fairly small, it might help to use coalesce before explode. Having too many partitions can be inefficient, especially when there is a lot of shuffling in your code.

As you said, many people do seem to have issues with explode. I looked at the link you provided, but no one there mentioned trying flatMap instead of explode.

fractalnature
  • 145
  • 2
  • 8