
I have a trained PyTorch model that I would now like to export to Caffe2 using ONNX. This part seems fairly simple and well documented. However, I now want to "load" that model into a Java program in order to perform predictions within my program (a Flink streaming application). What is the best way to do this? I haven't been able to find any documentation on the website describing how to do this.

igodfried
  • Hello! Do you have any updates or a solution for this problem? – Gilaztdinov Rustam Jan 17 '18 at 17:21
  • Unfortunately I don't. Right now I'm just using AWS Lambda in Python and calling it from the program like I would any other API. Would love to hear how to load the model in Java. – igodfried Jan 23 '18 at 04:40
  • You could give the experimental Java bindings of CNTK a shot. As far as I've been able to test, they work well for non-recurrent models. Recurrent models are a problem because the current bindings make it really hard to build sequences of observations. – Willem Meints Feb 14 '19 at 19:30
  • It's now possible to use the C++ API of PyTorch directly from Java using the JavaCPP Presets for Python, so you might not even have to worry about ONNX: https://github.com/bytedeco/javacpp-presets/tree/master/pytorch – Samuel Audet Apr 16 '21 at 01:43

1 Answer


Currently it's a bit tricky, but there is a way. You will need to use JavaCPP, specifically its presets for ONNX and nGraph:
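With Maven that means pulling in the Bytedeco preset artifacts for ONNX and nGraph, something like the following (the version numbers here are illustrative only; check Maven Central for the ones matching your JavaCPP version):

    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>onnx-platform</artifactId>
        <version>1.6.0-1.5.2</version> <!-- illustrative version -->
    </dependency>
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>ngraph-platform</artifactId>
        <version>0.26.0-1.5.2</version> <!-- illustrative version -->
    </dependency>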

I will use single_relu.onnx as an example:

    // These imports go at the top of your source file. They assume the
    // JavaCPP Presets for ONNX and nGraph; exact package names may vary
    // slightly between preset versions.
    import java.nio.file.Files;
    import java.nio.file.Paths;
    import org.bytedeco.javacpp.BytePointer;
    import org.bytedeco.javacpp.FloatPointer;
    import org.bytedeco.onnx.ModelProto;
    import org.bytedeco.onnx.StringVector;
    import org.bytedeco.ngraph.*;
    import static org.bytedeco.onnx.global.onnx.*;
    import static org.bytedeco.ngraph.global.ngraph.*;

    // read the ONNX file and parse it into a protobuf model
    byte[] bytes = Files.readAllBytes(Paths.get("single_relu.onnx"));
    ModelProto model = new ModelProto();
    ParseProtoFromBytes(model, new BytePointer(bytes), bytes.length); // parse ONNX -> protobuf model

    // preprocess the model in any way you like (you can skip this step)
    check_model(model);
    InferShapes(model);
    StringVector passes = new StringVector(
            "eliminate_nop_transpose", "eliminate_nop_pad",
            "fuse_consecutive_transposes", "fuse_transpose_into_gemm");
    Optimize(model, passes);
    check_model(model);
    ConvertVersion(model, 8); // convert the model to opset version 8
    BytePointer serialized = model.SerializeAsString();
    System.out.println("model=" + serialized.getString());

    // prepare the nGraph backend and the input/output tensors
    Backend backend = Backend.create("CPU");
    Shape shape = new Shape(new SizeTVector(1, 2));
    Tensor input = backend.create_tensor(f32(), shape);
    Tensor output = backend.create_tensor(f32(), shape);

    // fill the input tensor before running the model (I'm assuming
    // Tensor.write mirrors the read call below; without this step the
    // model would run on uninitialized memory)
    float[] in = {-1f, 1f};
    input.write(new FloatPointer(in), 0, in.length * 4);

    // convert ONNX -> nGraph, compile, and execute
    Function ng_function = import_onnx_model(serialized);
    Executable exec = backend.compile(ng_function);
    exec.call(new NgraphTensorVector(output), new NgraphTensorVector(input));

    // copy the result back into a Java array
    float[] r = new float[2];
    FloatPointer p = new FloatPointer(r);
    output.read(p, 0, r.length * 4);
    p.get(r);

    // print the result
    System.out.println("[");
    for (int i = 0; i < shape.get(0); i++) {
        System.out.print(" [");
        for (int j = 0; j < shape.get(1); j++) {
            System.out.print(r[i * (int) shape.get(1) + j] + " ");
        }
        System.out.println("]");
    }
    System.out.println("]");
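Since the question mentions a Flink streaming application: one way to embed this is to compile the model once in a rich function's open() and reuse the compiled Executable for every record. This is only a minimal sketch under the same assumptions as the snippet above; OnnxScorer, the hard-coded file name, and the fixed 1x2 shape are placeholders of mine:

    import org.apache.flink.api.common.functions.RichMapFunction;
    import org.apache.flink.configuration.Configuration;
    // plus the java.nio.file, JavaCPP, ONNX and nGraph imports shown above

    public class OnnxScorer extends RichMapFunction<float[], float[]> {
        private transient Backend backend;
        private transient Shape shape;
        private transient Executable exec;

        @Override
        public void open(Configuration parameters) throws Exception {
            // parse and compile the model once per task, as in the snippet above
            byte[] bytes = Files.readAllBytes(Paths.get("single_relu.onnx"));
            ModelProto model = new ModelProto();
            ParseProtoFromBytes(model, new BytePointer(bytes), bytes.length);
            backend = Backend.create("CPU");
            shape = new Shape(new SizeTVector(1, 2));
            exec = backend.compile(import_onnx_model(model.SerializeAsString()));
        }

        @Override
        public float[] map(float[] values) {
            // per record: write the input, run the model, read the output back
            Tensor input = backend.create_tensor(f32(), shape);
            Tensor output = backend.create_tensor(f32(), shape);
            input.write(new FloatPointer(values), 0, values.length * 4);
            exec.call(new NgraphTensorVector(output), new NgraphTensorVector(input));
            float[] result = new float[values.length];
            FloatPointer p = new FloatPointer(result);
            output.read(p, 0, result.length * 4);
            p.get(result);
            return result;
        }
    }

Compiling once per task avoids re-parsing the model for every element; if allocation shows up in profiling, the per-record tensors can be hoisted into open() as well.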
alagris
  • Is this answer (by @alagris) still the best option for how to serve the model in Java? Seems like a very convoluted and low-level approach, not exactly what I would expect from a framework designed for cross-platform model sharing. – martin_wun Oct 07 '21 at 06:46