
I want to do stratified sampling on my DataFrame in Scala. My DataFrame has only one column, and I want to build a fraction map from it. I am able to do this in PySpark, but the equivalent code gives me an error in Scala. Here is what I tried in Scala:

import org.apache.spark.sql.functions.{lit}
val fractions = pqdf.select("vin").distinct().withColumn("fraction", lit(0.001)).rdd.collect().toMap

It errors out saying:

Error:(25, 100) Cannot prove that org.apache.spark.sql.Row <:< (T, U). val fractions = pqdf.select("vin").distinct().withColumn("fraction", lit(0.001)).rdd.collect().toMap

How do I resolve it? I want to pass the fraction map created above to the .sampleBy method as one of its parameters:

val sampled_df = pqdf.stat.sampleBy("vin", fractions, 10L)

This is what I tried in PySpark, which works:

from pyspark.sql.functions import lit
fractions = df.select("VIN").distinct().withColumn("fraction", lit(0.001)).rdd.collectAsMap()
# fractions
sampled_df = df.stat.sampleBy("VIN", fractions, 10)

I am not sure how to achieve the same thing in Scala.

CodeHunter

1 Answer


It gives you an error because DataFrame.rdd returns RDD[Row], and toMap only works on a collection of (key, value) pairs. To make it work, you need something that is convertible to a Map, for example (if you want a Map[String, Double]):

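import org.apache.spark.sql.functions.lit
import spark.implicits._ // assumes a SparkSession named spark; enables $"..." and .as[...]

val fractions = df.select("VIN").distinct()
  .select($"VIN".cast("string"))      // make the key column a String
  .withColumn("fraction", lit(0.001)) // same sampling fraction for every VIN
  .as[(String, Double)]               // typed (key, fraction) pairs instead of Row
  .rdd
  .collectAsMap()                     // returns scala.collection.Map[String, Double]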

or

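import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.lit

val fractions = df.select("VIN").distinct()
  .withColumn("fraction", lit(0.001))
  .rdd
  // destructure each Row into a (String, Double) pair
  .map { case Row(vin, fraction: Double) => (vin.toString, fraction) }
  .collectAsMap()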
  • makes sense. Let me try it out at my end. Thanks much! – CodeHunter Jun 06 '18 at 22:00
  • It gives a type mismatch error: `Error:(32, 32) overloaded method value sampleBy with alternatives: [T](col: String, fractions: java.util.Map[T,java.lang.Double], seed: Long)org.apache.spark.sql.DataFrame [T](col: String, fractions: scala.collection.immutable.Map[T,scala.Double], seed: Long)org.apache.spark.sql.DataFrame cannot be applied to (String, scala.collection.Map[String,scala.Double], Long) val sampled_df = pqdf.stat.sampleBy("vin", fractions, 10L)` – CodeHunter Jun 06 '18 at 22:03
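The error message itself points at the cause: the Scala overload of sampleBy accepts a scala.collection.immutable.Map[T, Double], while collectAsMap() returns the more general scala.collection.Map. A minimal sketch of one way around it, assuming Spark 2.x or later and a local SparkSession, is to collect the typed Dataset to an array and call .toMap on it. The object SampleByFractions, the helper fractionsFor, and the sample data are hypothetical names used only for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, lit}

object SampleByFractions {
  // Hypothetical helper: gives every distinct value of keyCol the same fraction.
  // Returns an immutable Map, which matches the Scala overload of sampleBy
  // (collectAsMap() would return scala.collection.Map instead).
  def fractionsFor(df: DataFrame, keyCol: String, fraction: Double): Map[String, Double] = {
    import df.sparkSession.implicits._
    df.select(col(keyCol).cast("string"))
      .distinct()
      .withColumn("fraction", lit(fraction))
      .as[(String, Double)] // tuple encoder maps columns by position
      .collect()            // Array[(String, Double)] on the driver
      .toMap                // immutable Map[String, Double]
  }

  def main(args: Array[String]): Unit = {
    // Local session for trying the sketch out; a real job would configure its own.
    val spark = SparkSession.builder().master("local[*]").appName("sampleBy-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical stand-in for pqdf with a single "vin" column.
    val pqdf = Seq("V1", "V1", "V2", "V3", "V3", "V3").toDF("vin")

    val fractions = fractionsFor(pqdf, "vin", 0.001)
    val sampled = pqdf.stat.sampleBy("vin", fractions, 10L)
    sampled.show()

    spark.stop()
  }
}
```

Appending .toMap to the collectAsMap() result in either snippet from the answer should have the same effect, since it copies the scala.collection.Map into an immutable one.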