The background to the question turned out to be a bit lengthier than I intended; I just wanted to give the full picture.

I'm trying to write a custom, generic quantizer/bucketizer that derives bucket boundaries by sampling the data rather than using a pre-defined strategy (such as hashing). This gives a more even distribution of the data across buckets, and because the data set ends up fully sorted, it can be read efficiently from formats like ORC and Parquet.

Here's what I have so far:

// Need for Serializable:
// https://medium.com/swlh/spark-serialization-errors-e0eebcf0f6e6

// Note: when the type parameter is a plain String, Serializable is not needed,
// but with a case class we get a "Task not serializable" error;
// the "extends Serializable" fixes that.

import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

class Quantizer[T <% Ordered[T]](df: DataFrame, bucketCols: Seq[String], numBins: Int = 200)(implicit encoder: Encoder[T]) extends Serializable {

  private val bins: Seq[T] = generateBins

  // Sample the bucketing columns, sort the sample, and pick every stride-th
  // row as a bucket boundary.
  private def generateBins: Seq[T] = {
    val cols = bucketCols.map(col)
    val reccnt = df.count
    if (reccnt > 0) {
      // aim for roughly 1000 sampled rows per bin, capped at the full data set
      val sampratio = if (numBins * 1000.0 > reccnt) 1.0 else (numBins * 1000.0) / reccnt
      val samp = df.select(cols: _*).sample(sampratio).cache
      // guard against a zero stride when the sample has fewer rows than bins
      val stride = math.max(samp.count / (numBins - 1), 1L)
      val prcnts = samp
        .withColumn("rn", row_number().over(Window.orderBy(cols: _*)))
        .filter(col("rn") % stride === 0)
      prcnts.select(cols: _*).orderBy(cols: _*).as[T].collect
    } else {
      Seq[T]()
    }
  }

  def getBins: Seq[T] = bins

  // binary search: index of the first bin boundary that is >= x
  def lowerBound(x: T): Int = {
    var l = 0
    var h = bins.length
    while (l < h) {
      val m = (l + h) / 2
      if (x <= bins(m)) h = m else l = m + 1
    }
    l
  }

  // Alias for lowerBound, because the behavior is the same as the SQL
  // function width_bucket, except that it is not limited to Double.
  def widthBucket = lowerBound _

}
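
To make the bucket semantics concrete, here is a minimal, Spark-free sketch of the same lower-bound binary search over a few hypothetical String boundaries (the values are made up purely for illustration):

// hypothetical bin boundaries, as generateBins might return them for a String column
val bins = Seq("dog", "lion", "tiger")

// same lower-bound search as in Quantizer.lowerBound
def lowerBound(x: String): Int = {
  var l = 0
  var h = bins.length
  while (l < h) {
    val m = (l + h) / 2
    if (x <= bins(m)) h = m else l = m + 1
  }
  l
}

lowerBound("cat")    // 0 -> first bucket (before "dog")
lowerBound("horse")  // 1 -> between "dog" and "lion"
lowerBound("mouse")  // 2 -> between "lion" and "tiger"
lowerBound("zebra")  // 3 -> last bucket (after "tiger")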

To use this with a single column (say, of String type), I would do:

// create a quantizer
val q = new Quantizer[String](df, Seq("customer_id"), 1000)

// create a UDF
val customerWidthBucketUDF = udf((x: String) => q.widthBucket(x))
spark.udf.register("customer_width_bucket", customerWidthBucketUDF)

// Now we can use the UDF to get the bucket id for each row, then repartition
// and save the data appropriately, e.g.
// df.withColumn("bkt", customerWidthBucketUDF(col("customer_id")))
//   .repartition(col("bkt"))
//   .sortWithinPartitions(...)
//   .write
//   .save(...)
//

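Since the UDF is registered via spark.udf.register, the same bucketing can also be used from Spark SQL. A minimal sketch, assuming the DataFrame is exposed as a temp view named orders (a hypothetical name):

// expose the DataFrame to SQL under a hypothetical view name
df.createOrReplaceTempView("orders")

spark.sql("""
  SELECT customer_id, customer_width_bucket(customer_id) AS bkt
  FROM orders
""").show(5)
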
Now, if I want to use the same approach with a slightly more complex data type, I first need to define a case class and implement Ordered:

// https://stackoverflow.com/a/19348339/8098748

// A case class does not have a default Ordering, unlike Int, String, etc.
// Hence, the compare method needs to be implemented.

import java.math.BigDecimal

case class e(customer_id: String, order_id: String, tim: BigDecimal) extends Ordered[e] {
  import scala.math.Ordered.orderingToOrdered
  
  def compare(that: e): Int = (this.customer_id, this.order_id, this.tim) compare (that.customer_id, that.order_id, that.tim)
}
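
As a quick sanity check (with made-up values), the tuple-based compare gives the lexicographic ordering that lowerBound relies on:

val a = e("c1", "o1", new java.math.BigDecimal("1.0"))
val b = e("c1", "o2", new java.math.BigDecimal("0.5"))

a < b             // true: same customer_id, and "o1" < "o2"
Seq(b, a).sorted  // List(e(c1,o1,1.0), e(c1,o2,0.5)) -- a sorts before b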

Using this is pretty similar to the above:

val q = new Quantizer[e](df, Seq("customer_id", "order_id", "tim"), 5000)

val myComplexWidthBucketUDF = udf((x: e) => q.widthBucket(x))
spark.udf.register("my_width_bucket", myComplexWidthBucketUDF)

All this works fine, but I don't really like the interface, because:

  1. The case class is purely internal (it exists only so that lowerBound can compare values with <)
  2. Yet the user is forced to define and supply that case class

Now for the question: is it possible to do this without asking the user for a case class? We already have all the details of the schema in the DataFrame anyway.
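
In other words, the interface I'm after would look roughly like this (hypothetical, not working code), with the ordering of the bucket columns derived internally from df.schema:

// hypothetical: no type parameter and no user-supplied case class
val q = new Quantizer(df, Seq("customer_id", "order_id", "tim"), 5000)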

Thanks for reading!
