
Following TensorFlow's best practices for performance, I am using the NCHW data format, but I am not sure which filter shape to use with tensorflow.nn.conv2d.

The docs say to use [filter_height, filter_width, in_channels, out_channels] for the NHWC format, but they are not clear about what to do for NCHW.

Should the same shape be used?

jul

1 Answer


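Yes, the same filter shape works: tf.nn.conv2d expects the filter as [filter_height, filter_width, in_channels, out_channels] regardless of data_format. Apart from data_format itself, the only argument that changes is strides, whose entries follow the layout of the input (the height and width strides sit at positions 2 and 3 for NCHW, instead of 1 and 2 for NHWC). As an example, say you want your architecture to work with both formats, which is also recommended (NCHW is usually faster on GPUs, while TensorFlow's CPU convolution kernels have historically supported only NHWC):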

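import tensorflow as tf  # TF 1.x API; in TF 2.x the argument is named `filters`

# input  -> 4-D tensor in NCHW format
# filter -> [filter_height, filter_width, in_channels, out_channels] in BOTH branches
if use_nchw:
    result = tf.nn.conv2d(
        input=input,
        filter=filter,
        strides=[1, 1, stride, stride],  # strides for dims N, C, H, W
        padding='SAME',                  # padding is required; 'SAME' is just an example
        data_format='NCHW')
else:
    input_t = tf.transpose(input, [0, 2, 3, 1])  # NCHW to NHWC

    result = tf.nn.conv2d(
        input=input_t,
        filter=filter,
        strides=[1, stride, stride, 1],  # strides for dims N, H, W, C
        padding='SAME',
        data_format='NHWC')              # the default, spelled out for symmetry

    result = tf.transpose(result, [0, 3, 1, 2])  # NHWC to NCHW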
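To see why the filter shape does not depend on the data format, here is a minimal plain-NumPy reference convolution. It is a sketch, not TensorFlow's kernel, and it assumes 'VALID' padding; `conv2d_ref` is a hypothetical helper name. The filter is always consumed in [filter_height, filter_width, in_channels, out_channels] layout; only the way the input is sliced, and where the height/width strides sit in the strides list, change with the format.

```python
import numpy as np


def conv2d_ref(x, w, strides, data_format):
    """Naive 'VALID' 2-D cross-correlation (hypothetical reference helper).

    x: input in 'NHWC' or 'NCHW' layout.
    w: filter in [filter_height, filter_width, in_channels, out_channels]
       layout, identical for both data formats.
    strides: 4-element list laid out like the input, as in tf.nn.conv2d.
    """
    kh, kw, cin, cout = w.shape
    if data_format == 'NHWC':
        n, h, wd, c = x.shape
        _, sh, sw, _ = strides          # strides = [1, sh, sw, 1]
    elif data_format == 'NCHW':
        n, c, h, wd = x.shape
        _, _, sh, sw = strides          # strides = [1, 1, sh, sw]
    else:
        raise ValueError(data_format)
    assert c == cin, "input channels must match the filter's in_channels"

    oh = (h - kh) // sh + 1
    ow = (wd - kw) // sw + 1
    dtype = np.result_type(x, w)
    if data_format == 'NHWC':
        out = np.zeros((n, oh, ow, cout), dtype=dtype)
    else:
        out = np.zeros((n, cout, oh, ow), dtype=dtype)

    for i in range(oh):
        for j in range(ow):
            r, s = i * sh, j * sw
            if data_format == 'NHWC':
                patch = x[:, r:r + kh, s:s + kw, :]   # (n, kh, kw, cin)
                # pair patch axes (kh, kw, cin) with filter axes (0, 1, 2)
                out[:, i, j, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
            else:
                patch = x[:, :, r:r + kh, s:s + kw]   # (n, cin, kh, kw)
                # same filter, only the patch axes are ordered differently
                out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [2, 0, 1]))
    return out


# Usage: the same filter `w` gives matching results in both layouts.
rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 3, 7, 6))
w = rng.standard_normal((3, 2, 3, 4))            # [fh, fw, in_ch, out_ch]
y_nchw = conv2d_ref(x_nchw, w, [1, 1, 2, 2], 'NCHW')
y_nhwc = conv2d_ref(x_nchw.transpose(0, 2, 3, 1), w, [1, 2, 2, 1], 'NHWC')
print(np.allclose(y_nchw, y_nhwc.transpose(0, 3, 1, 2)))
```

Slicing each layout directly, rather than transposing first, makes the point explicit: the filter tensor is never reshaped or transposed, so the shape the docs give for NHWC is also the one to pass with NCHW.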
Era