The background to the question turned out to be a bit lengthier than I intended. I just wanted to provide the full picture.
I'm trying to write a custom, generic quantizer/bucketizer that creates bucket boundaries by sampling the data rather than using pre-defined strategies (like hashing). This gives a better distribution of the data, and the entire data set ends up sorted, which makes for efficient reads (e.g. with ORC and Parquet).
Here's what I have so far:
import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// Need for Serializable:
// https://medium.com/swlh/spark-serialization-errors-e0eebcf0f6e6
// Note: when the quantizer is created with a simple type such as String, Serializable is not needed,
// but when case classes are used we get a "Task not serializable" error.
// The "extends Serializable" fixes that.
class Quantizer[T <% Ordered[T]](df: DataFrame, bucketCols: Seq[String], numBins: Int = 200)(implicit encoder: Encoder[T]) extends Serializable {

  private var bins: Seq[T] = generateBins

  // Sample the bucket columns and pick (numBins - 1) evenly spaced rows
  // (by row_number over the sorted sample) as the bucket boundaries.
  private def generateBins: Seq[T] = {
    val cols = bucketCols.map(col)
    val reccnt = df.count
    if (reccnt > 0) {
      // Aim for roughly 1000 sampled rows per bin, capped at the full data set.
      val sampratio = if ((numBins * 1000.0) > reccnt) 1.0 else (numBins * 1000.0) / reccnt
      val samp = df.select(cols: _*).sample(sampratio).cache
      // Guard against a zero stride when the sample has fewer than (numBins - 1) rows.
      val stride = math.max(samp.count / (numBins - 1), 1L)
      val prcnts = samp
        .withColumn("rn", row_number().over(Window.orderBy(cols: _*)))
        .filter((col("rn") % stride) === 0)
      prcnts.select(cols: _*).orderBy(cols: _*).as[T].collect
    } else {
      Seq[T]()
    }
  }

  def getBins: Seq[T] = bins

  // Binary search: index of the first bin boundary that is >= x.
  def lowerBound(x: T): Int = {
    var l = 0
    var h = bins.length
    while (l < h) {
      val m = (l + h) / 2
      if (x <= bins(m)) h = m else l = m + 1
    }
    l
  }

  // Alias for lowerBound, because the behavior is the same as the SQL
  // function width_bucket, except that it is not limited to Double columns.
  def widthBucket = lowerBound _
}
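To illustrate the width_bucket analogy, here's roughly how the bucket assignment behaves once the bins have been generated (the boundary values below are made up, and the standalone lowerBound just mirrors the method above):

// Hypothetical sketch: assume the sampled boundaries came out as below
val bins = Seq("c100", "c250", "c610", "c900")

def lowerBound(x: String): Int = {
  var l = 0
  var h = bins.length
  while (l < h) {
    val m = (l + h) / 2
    if (x <= bins(m)) h = m else l = m + 1
  }
  l
}

lowerBound("a050") // 0 -> everything below the first boundary
lowerBound("c250") // 1 -> values equal to a boundary fall into the lower bucket
lowerBound("c700") // 3
lowerBound("z999") // 4 -> everything above the last boundary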
To use this with a single column (say of String type), I would do:
// create a quantizer
val q = new Quantizer[String](df, Seq("customer_id"), 1000)
// create a UDF
val customerWidthBucketUDF = udf((x: String) => q.widthBucket(x))
spark.udf.register("customer_width_bucket", customerWidthBucketUDF)
// Now we can use the UDF to get the bucket id for each row, then partition and save accordingly, e.g.
// df.withColumn("bkt", customerWidthBucketUDF(col("customer_id")))
//   .repartition(col("bkt"))
//   .sortWithinPartitions(...)
//   .write
//   .save(...)
//
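Since the UDF is also registered by name, the same bucketing can be used from Spark SQL; a quick sketch (the temp view name orders is just for illustration):

df.createOrReplaceTempView("orders")
spark.sql("SELECT customer_id, customer_width_bucket(customer_id) AS bkt FROM orders").show(5)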
Now, if I want to use the same approach with a slightly more complex data type, I first need to define a case class and implement Ordered:
// https://stackoverflow.com/a/19348339/8098748
// A case class does not come with a default Ordering, unlike Int, String, etc.,
// hence the compare method needs to be implemented.
import java.math.BigDecimal

case class e(customer_id: String, order_id: String, tim: BigDecimal) extends Ordered[e] {
  import scala.math.Ordered.orderingToOrdered

  def compare(that: e): Int =
    (this.customer_id, this.order_id, this.tim) compare (that.customer_id, that.order_id, that.tim)
}
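A quick sanity check that compare gives the lexicographic (customer_id, order_id, tim) ordering (the sample values are made up):

val rows = Seq(
  e("c2", "o1", new BigDecimal("10.0")),
  e("c1", "o2", new BigDecimal("30.0")),
  e("c1", "o1", new BigDecimal("20.0"))
)
rows.sorted
// -> List(e(c1,o1,20.0), e(c1,o2,30.0), e(c2,o1,10.0))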
Using this is pretty similar to the above:
val q = new Quantizer[e](df, Seq("customer_id", "order_id", "tim"), 5000)
val myComplexWidthBucketUDF = udf((x: e) => q.widthBucket(x))
spark.udf.register("my_width_bucket", myComplexWidthBucketUDF)
All this works fine, but I don't really like the interface much, because:
- The case class is purely internal (it exists only to allow < comparisons in lowerBound)
- However, the user is forced to provide it
Now for the question: is it possible to do this without asking the user for a case class? We have all the details of the schema in the DataFrame anyway.
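To make the ask concrete, this is roughly the interface I would ideally like to end up with (purely hypothetical; it does not work today because there is nothing Ordered without the case class):

// Hypothetical: the ordering for (customer_id, order_id, tim) would be
// derived from df.schema instead of a user-supplied case class
val q = new Quantizer(df, Seq("customer_id", "order_id", "tim"), 5000)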
Thanks for reading!