I am trying to create a user-defined aggregate function (UDAF) in Java using Apache Spark SQL that returns multiple arrays on completion. I have searched online and cannot find any examples or suggestions on how to do this.
I am able to return a single array, but cannot figure out how to get the data into the correct format in the evaluate() method to return multiple arrays.
The UDAF itself works, since I can print out the arrays in the evaluate() method; I just can't figure out how to return those arrays to the calling code, which is shown below for reference:
UserDefinedAggregateFunction customUDAF = new CustomUDAF();
// Alias the aggregated column (not the whole DataFrame) as "processed_data"
DataFrame resultingDataFrame = dataFrame.groupBy().agg(
    customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col")).as("processed_data"));
I have included the whole custom UDAF class below, but the key methods are dataType() and evaluate(), which are shown first.
Any help or advice would be greatly appreciated. Thank you.
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class CustomUDAF extends UserDefinedAggregateFunction {
    @Override
    public DataType dataType() {
        // TODO: Is this the correct way to return 2 arrays?
        return new StructType()
            .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public Object evaluate(Row buffer) {
        // Data conversion
        List<Long> longList = new ArrayList<Long>(buffer.getList(0));
        List<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        // Processing of data (omitted)
        // TODO: How to get data into format needed to return 2 arrays?
        return dataList;
    }

    // Input: one (long, double) pair per row
    @Override
    public StructType inputSchema() {
        return new StructType().add("long", DataTypes.LongType).add("data", DataTypes.DoubleType);
    }

    // Buffer: two parallel arrays that are always appended to together
    @Override
    public StructType bufferSchema() {
        return new StructType()
            .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, new ArrayList<Long>());
        buffer.update(1, new ArrayList<Double>());
    }

    // Append the current row's values to both buffer arrays
    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer.getList(0));
        longList.add(row.getLong(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        dataList.add(row.getDouble(1));
        buffer.update(0, longList);
        buffer.update(1, dataList);
    }

    // Concatenate a partial buffer from another partition onto buffer1
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer1.getList(0));
        longList.addAll(buffer2.getList(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer1.getList(1));
        dataList.addAll(buffer2.getList(1));
        buffer1.update(0, longList);
        buffer1.update(1, dataList);
    }

    @Override
    public boolean deterministic() {
        return true;
    }
}
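Spark calls evaluate() once per group and expects a single value whose shape matches dataType(). For the two-field StructType declared above, that value is a row-like object holding both arrays in declaration order. In the Spark Java API this is typically a Row built with RowFactory.create(longList, dataList); the Tuple2 in the update below presumably works because Spark's converters also accept Scala Product values for struct types. The plain-Java sketch below models the buffer logic without Spark so it can run on its own. The class ArrayPairBuffer and its nested Result type are hypothetical stand-ins, not Spark API. It also shows that update() and merge() keep the two arrays index-aligned, because every operation appends to both lists together.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical plain-Java model of the CustomUDAF buffer (no Spark involved).
public class ArrayPairBuffer {
    // Mirrors bufferSchema(): field 0 = longArray, field 1 = dataArray.
    private final List<Long> longs = new ArrayList<>();
    private final List<Double> data = new ArrayList<>();

    // Hypothetical result holder mirroring the two-field StructType in dataType().
    public static final class Result {
        public final List<Long> longArray;
        public final List<Double> dataArray;

        Result(List<Long> longArray, List<Double> dataArray) {
            // Defensive, read-only copies of both arrays
            this.longArray = Collections.unmodifiableList(new ArrayList<>(longArray));
            this.dataArray = Collections.unmodifiableList(new ArrayList<>(dataArray));
        }
    }

    // Like update(): append one input row's values to both arrays together.
    public void update(long l, double d) {
        longs.add(l);
        data.add(d);
    }

    // Like merge(): concatenate another partial buffer onto this one, field by field.
    public void merge(ArrayPairBuffer other) {
        longs.addAll(other.longs);
        data.addAll(other.data);
    }

    // Like evaluate(): return ONE value that carries both arrays.
    public Result evaluate() {
        return new Result(longs, data);
    }

    public static void main(String[] args) {
        ArrayPairBuffer partition1 = new ArrayPairBuffer();
        partition1.update(1L, 0.5);
        partition1.update(2L, 1.5);
        ArrayPairBuffer partition2 = new ArrayPairBuffer();
        partition2.update(3L, 2.5);
        partition1.merge(partition2);
        Result r = partition1.evaluate();
        System.out.println(r.longArray + " " + r.dataArray); // prints [1, 2, 3] [0.5, 1.5, 2.5]
    }
}
```

The same "one object, two fields" idea is what the StructType return type asks for; only the concrete container (Row or Tuple2 instead of Result) differs in Spark.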
Update: Based on the answer by zero323, I was able to return two arrays using:
return new Tuple2<>(longArray, dataArray);
Getting the data back out was a bit of a struggle; it involved converting the DataFrame to Java Lists and then building a DataFrame from them again.
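Because update() and merge() always append to both buffer arrays together, element i of longArray belongs with element i of dataArray. Once the two lists have been pulled out of the result row (for example with the Row accessors getStruct and getList), they can be zipped back into one record per original input row before rebuilding a DataFrame. A plain-Java sketch of that step follows; ZipArrays, zip and LongDoublePair are hypothetical names, and the sketch assumes both arrays have the same length.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: turns the two parallel arrays returned by the UDAF
// back into one (long, double) pair per original input row.
public class ZipArrays {
    // Hypothetical value type for one reconstructed record.
    public static final class LongDoublePair {
        public final long key;
        public final double value;

        public LongDoublePair(long key, double value) {
            this.key = key;
            this.value = value;
        }

        @Override
        public String toString() {
            return "(" + key + ", " + value + ")";
        }
    }

    public static List<LongDoublePair> zip(List<Long> longs, List<Double> data) {
        // The UDAF appends to both arrays in lockstep, so their lengths must match.
        if (longs.size() != data.size()) {
            throw new IllegalArgumentException(
                "Array lengths differ: " + longs.size() + " vs " + data.size());
        }
        List<LongDoublePair> out = new ArrayList<>(longs.size());
        for (int i = 0; i < longs.size(); i++) {
            out.add(new LongDoublePair(longs.get(i), data.get(i)));
        }
        return out;
    }

    public static void main(String[] args) {
        List<Long> longs = Arrays.asList(1L, 2L, 3L);
        List<Double> data = Arrays.asList(0.5, 1.5, 2.5);
        System.out.println(zip(longs, data)); // prints [(1, 0.5), (2, 1.5), (3, 2.5)]
    }
}
```

Each pair could then be turned into a Row (for example with RowFactory.create) and passed to one of the createDataFrame overloads, which corresponds to the "build it back" step described above.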