
I have a CNN model that was trained in PyTorch using the data format N (batch) x C (channel) x H (height) x W (width). I saved the pre-trained model as model.pth. Afterward, I converted the pre-trained model from model.pth -> model.onnx using the existing function:

torch.onnx.export(model, dummy_input, "model.onnx")

Then I converted model.onnx -> model.pb with the code below:

import onnx
from onnx_tf.backend import prepare 

model_onnx = onnx.load('model.onnx')
tf_rep = prepare(model_onnx)
tf_rep.export_graph('model.pb')

The problem is: I want to use this model.pb on a CPU device, which needs the NHWC data format. However, my model is based on the NCHW data format. Is there any method that can convert the data format of this model.pb from NCHW to NHWC?
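For reference, switching a tensor between the two layouts is just an axis permutation. A minimal pure-Python sketch that derives that permutation from the layout strings, in the form `numpy.transpose` and `tf.transpose` expect (output axis i takes input axis perm[i]). The helper name `layout_perm` is hypothetical, not part of PyTorch, ONNX or TensorFlow:

```python
def layout_perm(src: str, dst: str) -> tuple:
    """Hypothetical helper: return the axis permutation that turns a tensor
    laid out as `src` (e.g. "NCHW") into one laid out as `dst` (e.g. "NHWC")."""
    if sorted(src) != sorted(dst):
        raise ValueError(f"{src!r} and {dst!r} do not name the same axes")
    # Output axis i is taken from input axis src.index(dst[i]).
    return tuple(src.index(axis) for axis in dst)


print(layout_perm("NCHW", "NHWC"))  # (0, 2, 3, 1)
print(layout_perm("NHWC", "NCHW"))  # (0, 3, 1, 2)
print(layout_perm("OIHW", "HWIO"))  # (2, 3, 1, 0), e.g. for conv weights
```

The same lookup covers weights as well as activations, which matters because convolution filters use different layout letters than the data they convolve.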

MY C

2 Answers


Short answer: you are in a tough spot.

Long answer: it's difficult, but possible. What makes your problem difficult is that your graph is already trained. It is inefficient, but easier, to convert NCHW -> NHWC while you create the training graph. See similar answers here and here.

Now to your question: you'll have to replace the Conv2D operator with a custom convolution operator. Here is some pseudocode to get started.

   tensor Conv2D(X, W, B) {
     int perm[] = {0, 3, 1, 2};
     X = transposeTensor(X, perm);
     W = transposeTensor(W, perm);
     Y = Conv2D_orig(X, W, B, ...) ;
     perm = {0, 2, 3, 1};
     return transposeTensor(Y, perm);
   }
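A runnable NumPy sketch of the same transpose-around-the-convolution idea, under stated assumptions: `conv2d_nhwc` is a hypothetical, naive stand-in for an NHWC-only CPU kernel (stride 1, no padding), and the weights are assumed to be in PyTorch's Conv2d layout (out_channels, in_channels, kH, kW), i.e. OIHW. This sketch runs in the direction a CPU-only backend would need: the surrounding graph keeps passing NCHW tensors and only the convolution itself sees NHWC. The weights get their own permutation (OIHW -> HWIO, the filter layout `tf.nn.conv2d` uses), not the activation permutation.

```python
import numpy as np


def conv2d_nhwc(x, w_hwio, b):
    """Hypothetical NHWC-only kernel: naive stride-1, no-padding convolution
    (cross-correlation, as in deep-learning frameworks).
    x: (N, H, W, C_in), w_hwio: (KH, KW, C_in, C_out), b: (C_out,)."""
    n, h, w, _ = x.shape
    kh, kw, _, c_out = w_hwio.shape
    out = np.empty((n, h - kh + 1, w - kw + 1, c_out), dtype=x.dtype)
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, KH, KW, C_in)
            # Contract over KH, KW, C_in -> (N, C_out)
            out[:, i, j, :] = np.tensordot(patch, w_hwio, axes=([1, 2, 3], [0, 1, 2]))
    return out + b  # bias broadcasts over the trailing channel axis


def conv2d_nchw_via_nhwc(x_nchw, w_oihw, b):
    """Accept NCHW activations and OIHW weights, compute with the NHWC kernel."""
    x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))  # NCHW -> NHWC
    w_hwio = np.transpose(w_oihw, (2, 3, 1, 0))  # OIHW -> HWIO
    y_nhwc = conv2d_nhwc(x_nhwc, w_hwio, b)
    return np.transpose(y_nhwc, (0, 3, 1, 2))    # NHWC -> NCHW
```

Each wrapped convolution pays for two activation transposes per call, which is part of why patching an already-trained graph this way is less efficient than building it in NHWC from the start.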

Can you just transpose the input tensors? E.g. input.transpose(1, 2).transpose(2, 3)?

>>> torch.randn( (3,3,3,3), names=['n','c','h','w']).transpose(1,2).transpose(2,3).names
('n', 'h', 'w', 'c')
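The chained transposes above amount to a single permutation, (0, 2, 3, 1). A small sketch of that equivalence, plus the inverse permutation for going back; NumPy is used here only to illustrate the axis bookkeeping (in PyTorch the one-call equivalent is `input.permute(0, 2, 3, 1)`):

```python
import numpy as np

x_nchw = np.random.randn(2, 3, 4, 5)  # N=2, C=3, H=4, W=5

# Mirrors input.transpose(1, 2).transpose(2, 3) from the answer.
chained = x_nchw.swapaxes(1, 2).swapaxes(2, 3)

# A single permutation does the same job: NCHW -> NHWC.
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

# The inverse permutation restores the original layout: NHWC -> NCHW.
x_back = np.transpose(x_nhwc, (0, 3, 1, 2))

print(chained.shape)  # (2, 4, 5, 3)
```

Like PyTorch's `transpose`/`permute`, these return views with permuted strides rather than copies; `np.ascontiguousarray` (or `.contiguous()` in PyTorch) produces a packed copy if a consumer needs one.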
nairbv
  • Hi, do you mean before training in PyTorch? Thanks for the answer, any advice helps! – MY C Feb 04 '20 at 16:14
  • Yes, on any inputs, whether training data or evaluation data: any inputs that have incompatible dimensions. – nairbv Feb 05 '20 at 14:57
  • Doesn't torch.nn only use channel first as data format i.e. NCHW in this case? Please correct me if I'm wrong. Thanks. – MY C Feb 05 '20 at 20:42
  • Yes, that's right. Your input data can be converted, but I'm not sure what model you're trying to run. You can follow the work to support channels-last here: https://github.com/pytorch/pytorch/issues/28619 – nairbv Feb 09 '20 at 13:38
  • I'm using a customized model. I'll try to follow the GitHub issue you mentioned. Thanks again for helping. – MY C Feb 10 '20 at 23:07