This is just a guess, but are you by any chance processing each input image of the batch (or, alternatively, post-processing the detections) separately inside a for-loop? If so, the behaviour you are seeing may be due to how torch traces the model during ONNX export, and you will need to modify your forward pass. Alternatively, you can use torch.jit.script.
Where the forward pass can go wrong
Check your model for anything that defines a dimension of a tensor that gets interpreted as a plain Python integer during export. Setting dynamic axes tells the exporter to use variable shapes for the corresponding tensors, but this is overridden wherever a dimension has been baked in as an explicit constant.
```python
# WRONG - WILL EXPORT WITH STATIC BATCH SIZE
def forward(self, batch):
    bs, c, h, w = batch.shape
    # bs is saved as a constant integer during export
    for i in range(bs):
        do_something()
```
```python
# WRONG - WILL EXPORT WITH STATIC BATCH SIZE
def forward(self, batch):
    # iterating over a tensor is not supported for dynamic batch sizes:
    # the ONNX model will always iterate as many times as the batch size
    # seen during export
    for i in batch:
        do_something()
```
Potential fixes
Use `tensor.size()` instead of `tensor.shape`
```python
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
def forward(self, batch):
    # size() is a function call rather than an attribute access,
    # so the resulting variable stays dynamic
    bs = batch.size(0)
    for i in range(bs):
        do_something()
```
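As a minimal runnable sketch (module and shapes are illustrative), the classic case is a reshape: unpacking `batch.shape` bakes the traced batch size into the graph as a constant, while `batch.size(0)` is recorded as an operation:

```python
import torch

class FlattenPerSample(torch.nn.Module):
    def forward(self, batch):
        bs = batch.size(0)  # recorded as an op in the traced graph
        return batch.view(bs, -1)

model = FlattenPerSample().eval()
# In eager mode both variants behave the same; the difference only
# shows up in the exported graph.
print(model(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 12])
print(model(torch.randn(5, 3, 4)).shape)  # torch.Size([5, 12])
```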
Script parts of the model to preserve control flow and varying input sizes
```python
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
# Script parts of the forward pass, e.g. single functions

# (newer torch versions expose this as torch.jit.script_if_tracing)
@torch.jit._script_if_tracing
def do_something(batch):
    for i in batch:
        do_something_else()

def forward(self, batch):
    # the function is scripted during tracing, so dynamic shapes are preserved
    do_something(batch)
```
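A runnable sketch of the same idea using the public decorator name (`torch.jit.script_if_tracing`, available in recent torch versions; the function below is illustrative). Outside of tracing the function runs as plain Python; during tracing it is scripted, so the loop over the tensor stays real control flow:

```python
import torch

@torch.jit.script_if_tracing
def row_sums(batch):
    # Loop over the first dimension; preserved as control flow
    # when this function is scripted during tracing.
    out = []
    for row in batch:
        out.append(row.sum())
    return torch.stack(out)

print(row_sums(torch.ones(3, 4)))  # tensor([4., 4., 4.])
```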
Export the whole module as a `ScriptModule`, preserving all control flow and input sizes
```python
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
script_module = torch.jit.script(model)
torch.onnx.export(
    script_module,
    ...
)
```