I am trying to develop a user-defined aggregate function (UDAF) that computes a linear regression over a series of numbers. I have already written a working UDAF that calculates confidence intervals of means (with a lot of trial and error, and SO!).
Here's what actually runs for me already:
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StructType, StructField, DoubleType, LongType, DataType, ArrayType}

case class RegressionData(intercept: Double, slope: Double)

class Regression {
  import org.apache.commons.math3.stat.regression.SimpleRegression

  def roundAt(p: Int)(n: Double): Double = {
    val s = math.pow(10, p)
    math.round(n * s) / s
  }

  def getRegression(data: List[Long]): RegressionData = {
    val regression: SimpleRegression = new SimpleRegression()
    data.view.zipWithIndex.foreach { case (value, index) =>
      regression.addData(index.toDouble, value.toDouble)
    }
    RegressionData(roundAt(3)(regression.getIntercept()), roundAt(3)(regression.getSlope()))
  }
}

class UDAFRegression extends UserDefinedAggregateFunction {
  import java.util.ArrayList

  def deterministic = true

  def inputSchema: StructType =
    new StructType().add("units", LongType)

  def bufferSchema: StructType =
    new StructType().add("buff", ArrayType(LongType))

  def dataType: DataType =
    new StructType()
      .add("intercept", DoubleType)
      .add("slope", DoubleType)

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, new ArrayList[Long]())
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    val longList: ArrayList[Long] = new ArrayList[Long](buffer.getList(0))
    longList.add(input.getLong(0))
    buffer.update(0, longList)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    val longList: ArrayList[Long] = new ArrayList[Long](buffer1.getList(0))
    longList.addAll(buffer2.getList(0))
    buffer1.update(0, longList)
  }

  def evaluate(buffer: Row) = {
    import scala.collection.JavaConverters._
    val list = buffer.getList(0).asScala.toList
    val regression = new Regression
    regression.getRegression(list)
  }
}
However, the rows do not arrive in order, which obviously matters a great deal for a regression. So instead of regression($"longValue") I need a second parameter: regression($"longValue", $"created_day"). created_day is a sql.types.DateType.
I am pretty confused by DataTypes, StructTypes and the like, and given the lack of examples on the web, I got stuck with my trial-and-error attempts here.
What would my bufferSchema look like?
Are those StructTypes overkill in my case? Wouldn't a (mutable) Map just do? Is MapType actually immutable, and if so, isn't it rather pointless as a buffer type?
What would my inputSchema look like?
Does this have to match the type I retrieve in update(), in my case via input.getLong(0)?
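To make the two schema questions concrete, here is a minimal, untested sketch of one shape the schemas might take once the date column is added. The wrapper object RegressionSchemas, the field names units, created_day and days, and the choice of two parallel arrays for the buffer are all assumptions made for illustration; whether a MapType or an array of structs would be the better buffer is exactly what is being asked.

```scala
import org.apache.spark.sql.types.{ArrayType, DateType, LongType, StructType}

// Hypothetical holder object; in the UDAF these would be the
// inputSchema and bufferSchema members.
object RegressionSchemas {
  // Input: one row per observation, the value first, then its date.
  val inputSchema: StructType =
    new StructType()
      .add("units", LongType)
      .add("created_day", DateType)

  // Buffer (assumed layout): two parallel arrays, so the i-th value
  // and the i-th date belong to the same observation.
  val bufferSchema: StructType =
    new StructType()
      .add("units", ArrayType(LongType))
      .add("days", ArrayType(DateType))
}
```

With a layout like this, update() would presumably read input.getLong(0) and input.getDate(1), following the order in which the input fields are declared.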
Is there a standard way to reset the buffer in initialize()?
I have seen buffer.update(0, 0.0) (when it contains Doubles, obviously), buffer(0) = new WhatEver() and, I think, even buffer = Nil. Do any of these make a difference?
How to update data?
The example above seems overcomplicated. I was expecting to be able to do something like buffer += input.getLong(0) -> input.getDate(1).
Can I expect to access the input this way?
How to merge data?
Can I just leave the function block empty, like def merge(…) = {}?
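To illustrate what is at stake with an empty merge, here is a small Spark-free sketch. Spark calls merge to combine partial buffers that were built separately (for example on different partitions), so the sketch models a buffer as a plain Vector of (epochDay, units) pairs. The object MergeSketch, the Buffer alias and the helper names are hypothetical stand-ins, not Spark API.

```scala
object MergeSketch {
  // Stand-in for one partial aggregation buffer: (epochDay, units) pairs.
  type Buffer = Vector[(Long, Long)]

  // Conceptually what update() does: append one observation.
  def update(buffer: Buffer, epochDay: Long, units: Long): Buffer =
    buffer :+ (epochDay -> units)

  // Conceptually what merge() does: fold the second partial buffer into the first.
  def merge(buffer1: Buffer, buffer2: Buffer): Buffer =
    buffer1 ++ buffer2

  // A merge that does nothing, as an empty function block would:
  // everything collected in buffer2 is dropped.
  def emptyMerge(buffer1: Buffer, buffer2: Buffer): Buffer =
    buffer1
}
```

In this model, if one partial buffer holds two observations and another holds one, merge yields all three while emptyMerge keeps only the first two, which suggests an empty merge would silently lose data whenever a group spans more than one partial buffer.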
The challenge of sorting that buffer in evaluate() is something I should be able to figure out, although I am still interested in the most elegant way you guys would do this (in a fraction of the time).
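For the sorting step, the following is a rough, Spark-free sketch of what evaluate() would need to do: order the (date, value) pairs by date, then regress the values on their position, as the current code does with zipWithIndex. It swaps SimpleRegression for a plain closed-form least-squares fit to stay self-contained, and it uses java.time.LocalDate in place of the java.sql.Date a Row would hand back. The names SortedRegressionSketch, Fit and fit are hypothetical.

```scala
import java.time.LocalDate

object SortedRegressionSketch {
  final case class Fit(intercept: Double, slope: Double)

  // Sort observations by date, then regress value on position (0, 1, 2, ...).
  def fit(observations: Seq[(LocalDate, Long)]): Fit = {
    val ys = observations.sortBy(_._1.toEpochDay).map(_._2.toDouble)
    val n = ys.size
    require(n >= 2, "need at least two points to fit a line")
    val xs = ys.indices.map(_.toDouble)
    val meanX = xs.sum / n
    val meanY = ys.sum / n
    // Ordinary least squares: slope = Sxy / Sxx, intercept = meanY - slope * meanX.
    val sxy = xs.zip(ys).map { case (x, y) => (x - meanX) * (y - meanY) }.sum
    val sxx = xs.map(x => (x - meanX) * (x - meanX)).sum
    val slope = sxy / sxx
    Fit(meanY - slope * meanX, slope)
  }
}
```

Like the original code, this treats the observations as evenly spaced once sorted; if gaps between dates should count, the epoch day itself could serve as x instead of the index.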
Bonus question: What role does dataType play?
I return a case class, not the StructType defined in dataType, which does not seem to be an issue. Or does it only work because the StructType happens to match my case class?