I have a trained PyTorch model that I would now like to export to Caffe2 using ONNX. This part seems fairly simple and well documented. However, I now want to "load" that model into a Java program in order to perform predictions within my program (a Flink streaming application). What is the best way to do this? I haven't been able to find any documentation on the website describing how to do this.
- Hello! Do you have any updates or solutions for this problem? – Gilaztdinov Rustam Jan 17 '18 at 17:21
- Unfortunately I don't. Right now I'm just using AWS Lambda in Python and calling it like I would any other API from the program. Would love to hear how to load the model in Java. – igodfried Jan 23 '18 at 04:40
- You could give the experimental Java bindings of CNTK a shot. As far as I've been able to test, they work well for non-recurrent models. Recurrent models are a problem because the current bindings make it really hard to build sequences of observations. – Willem Meints Feb 14 '19 at 19:30
- It's now possible to use the C++ API of PyTorch directly from Java using the JavaCPP Presets for PyTorch, so you might not even have to worry about ONNX: https://github.com/bytedeco/javacpp-presets/tree/master/pytorch – Samuel Audet Apr 16 '21 at 01:43
1 Answer
Currently it's a bit tricky, but there is a way. You will need the JavaCPP presets for nGraph and ONNX (a possible Maven setup is sketched after this list):
- nGraph: https://github.com/bytedeco/javacpp-presets/tree/master/ngraph
- ONNX: https://github.com/bytedeco/javacpp-presets/tree/master/onnx
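Both presets are published under the org.bytedeco group on Maven Central, so a dependency block along these lines should pull them in (the version numbers below are assumptions of mine; check the linked pages for the current ones):

<!-- illustrative versions; see the preset pages above for current ones -->
<dependency>
  <groupId>org.bytedeco</groupId>
  <artifactId>onnx-platform</artifactId>
  <version>1.6.0-1.5.2</version>
</dependency>
<dependency>
  <groupId>org.bytedeco</groupId>
  <artifactId>ngraph-platform</artifactId>
  <version>0.26.0-1.5.2</version>
</dependency>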
I will use single_relu.onnx as an example:
// imports assumed from JavaCPP and its onnx / ngraph presets
import java.nio.file.Files;
import java.nio.file.Paths;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.onnx.*;
import org.bytedeco.ngraph.*;
import static org.bytedeco.onnx.global.onnx.*;
import static org.bytedeco.ngraph.global.ngraph.*;

// read the ONNX file and parse it into a protobuf model
byte[] bytes = Files.readAllBytes(Paths.get("single_relu.onnx"));
ModelProto model = new ModelProto();
ParseProtoFromBytes(model, new BytePointer(bytes), bytes.length);

// preprocess the model in any way you like (you can skip this step)
check_model(model);
InferShapes(model);
StringVector passes = new StringVector("eliminate_nop_transpose", "eliminate_nop_pad",
        "fuse_consecutive_transposes", "fuse_transpose_into_gemm");
Optimize(model, passes);
check_model(model);
ConvertVersion(model, 8);
BytePointer serialized = model.SerializeAsString();
System.out.println("model=" + serialized.getString());

// prepare the nGraph backend and input/output tensors
Backend backend = Backend.create("CPU");
Shape shape = new Shape(new SizeTVector(1, 2));
Tensor input = backend.create_tensor(f32(), shape);
Tensor output = backend.create_tensor(f32(), shape);

// fill the input tensor (ReLU of [-1, 2] should give [0, 2]);
// write mirrors the read(pointer, offset, bytes) call used below
float[] inputData = {-1f, 2f};
input.write(new FloatPointer(inputData), 0, inputData.length * 4);

// convert ONNX -> nGraph, compile and execute
Function ng_function = import_onnx_model(serialized);
Executable exec = backend.compile(ng_function);
exec.call(new NgraphTensorVector(output), new NgraphTensorVector(input));

// copy the result back into a Java array (4 bytes per float)
float[] r = new float[2];
FloatPointer p = new FloatPointer(r);
output.read(p, 0, r.length * 4);
p.get(r);

// print the result as a 1x2 matrix
System.out.println("[");
for (int i = 0; i < shape.get(0); i++) {
    System.out.print(" [");
    for (int j = 0; j < shape.get(1); j++) {
        System.out.print(r[i * (int) shape.get(1) + j] + " ");
    }
    System.out.println("]");
}
System.out.println("]");

alagris
- Is this answer (by @alagris) still the best option for serving the model in Java? It seems like a very convoluted and low-level approach, not exactly what I would expect from a framework designed for cross-platform model sharing. – martin_wun Oct 07 '21 at 06:46