
I am trying to create a user-defined aggregate function (UDAF) which I can call from Python. I tried to follow the answer to this question. I basically implemented the following (taken from here):

package com.blu.bla;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;

public class MySum extends UserDefinedAggregateFunction {
    private StructType _inputDataType;
    private StructType _bufferSchema;
    private DataType _returnDataType;

    public MySum() {
        List<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
        _inputDataType = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
        _bufferSchema = DataTypes.createStructType(bufferFields);

        _returnDataType = DataTypes.DoubleType;
    }

    @Override public StructType inputSchema() {
        return _inputDataType;
    }

    @Override public StructType bufferSchema() {
        return _bufferSchema;
    }

    @Override public DataType dataType() {
        return _returnDataType;
    }

    @Override public boolean deterministic() {
        return true;
    }

    @Override public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, null);
    }

    @Override public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            if (buffer.isNullAt(0)) {
                buffer.update(0, input.getDouble(0));
            } else {
                Double newValue = input.getDouble(0) + buffer.getDouble(0);
                buffer.update(0, newValue);
            }
        }
    }

    @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        if (!buffer2.isNullAt(0)) {
            if (buffer1.isNullAt(0)) {
                buffer1.update(0, buffer2.getDouble(0));
            } else {
                Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
                buffer1.update(0, newValue);
            }
        }
    }

    @Override public Object evaluate(Row buffer) {
        if (buffer.isNullAt(0)) {
            return null;
        } else {
            return buffer.getDouble(0);
        }
    }
}

I then compiled it with all dependencies and ran pyspark with --jars myjar.jar.

In pyspark I did:

df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql import Row

def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum.apply
    return Column(_f(_to_seq(sc,[col], _to_java_column)))
b = df.agg(myCol("A"))

I got the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-f45b2a367e67> in <module>()
----> 1 b = df.agg(myCol("A"))

<ipython-input-22-afcb8884e1db> in myCol(col)
      4 def myCol(col):
      5     _f = sc._jvm.com.blu.bla.MySum.apply
----> 6     return Column(_f(_to_seq(sc,[col], _to_java_column)))

TypeError: 'JavaPackage' object is not callable

I also tried adding --driver-class-path to the pyspark call but got the same result.

I also tried to access the Java class through java_import:

from py4j.java_gateway import java_import
jvm = sc._gateway.jvm
java_import(jvm, "com.bla.blu.MySum")
def myCol2(col):
    _f = jvm.bla.blu.MySum.apply
    return Column(_f(_to_seq(sc,[col], _to_java_column)))

I also tried to simply instantiate the class (as suggested here):

a = jvm.com.bla.blu.MySum()

All of these produce the same error message, and I can't seem to figure out what the problem is.

asked by Assaf Mendelson

1 Answer


So it seems the main issue was that all of the options to add the jar (--jars, --driver-class-path, SPARK_CLASSPATH) do not work properly when given a relative path. This is probably because of issues with the working directory inside ipython as opposed to where I ran pyspark.

Once I changed this to an absolute path, it worked (I haven't tested it on a cluster yet, but at least it works on a local installation).
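For concreteness, a sketch of the launch command with an absolute path, assuming the jar lives at /home/a/a.jar (the path is just an example, as in the comments below):

```shell
# Hypothetical absolute path; substitute the real location of your jar.
JAR=/home/a/a.jar

# --jars ships the jar to the executors, while --driver-class-path also puts
# it on the driver's classpath (--jars alone does not), which is what
# sc._jvm lookups from pyspark need.
pyspark --jars "$JAR" --driver-class-path "$JAR"
```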

Also, I am not sure whether this is also an issue in the answer here, as that answer uses a Scala implementation; in the Java implementation, however, I needed to do

def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum().apply
    return Column(_f(_to_seq(sc,[col], _to_java_column)))

This is probably not very efficient as it creates _f on every call; instead I should probably define _f outside the function (again, this would require testing on the cluster), but at least it now gives the correct result.
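A minimal sketch of hoisting that instantiation out of the function (untested here; it assumes a live SparkContext `sc` and the jar on both the driver and executor classpaths, and the helper name `my_sum_col` is my own):

```python
from pyspark.sql.column import Column, _to_seq, _to_java_column

# Instantiate the Java UDAF once and reuse its apply method, instead of
# creating a new JVM object on every call. Since MySum is a plain Java
# class (no Scala companion object), the parentheses are required.
_my_sum_apply = sc._jvm.com.blu.bla.MySum().apply

def my_sum_col(col):
    # Wrap the Java Column returned by the UDAF in a Python Column.
    return Column(_my_sum_apply(_to_seq(sc, [col], _to_java_column)))

# usage: df.agg(my_sum_col("A"))
```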

  • one last thing, for future reference this was tested on spark 1.6.0 local (single node) installation – Assaf Mendelson Mar 08 '16 at 17:00
  • 2
    Tested this on a cluster and it works. Used --jars AND --driver-class-path together (apparently --jars does not set the classpath to the driver) – Assaf Mendelson Mar 09 '16 at 12:38
  • Hi! I am trying to do a similar thing but I am getting the same error. I am not able to understand it; could you provide a bit of insight into how to run it? – Arnab Jun 24 '16 at 05:24
  • Let's assume your jar is located at /home/a/a.jar. What I would do is run it with --jars /home/a/a.jar --driver-class-path /home/a/a.jar, and it was fine. – Assaf Mendelson Jun 26 '16 at 07:27
  • Hi Assaf, I get the same error. I pass two jars (comma separated). Then I tried putting them in one directory, and I get the error that my supporting jar (one is the main jar which uses the other one) has no main class. How do I get around that? – Anuja Khemka Sep 12 '17 at 20:31