
In the accepted answer to this question, for the single-channel input and kernel matrices (i_grey and k in the code below), the output of tf.nn.conv2d is

[[14 6] [6 12]]

which makes sense. However, when I give the input and kernel matrices 3 channels each (by repeating each original matrix) and run the same code:

import numpy as np
import tensorflow as tf

# the previous input
i_grey = np.array([
    [4, 3, 1, 0],
    [2, 1, 0, 1],
    [1, 2, 4, 1],
    [3, 1, 0, 2]
])

# stack 3 copies along a new leading axis: shape (3, 4, 4)
i_rgb = np.repeat(np.expand_dims(i_grey, axis=0), 3, axis=0)

# convert to tensor
i_rgb = tf.constant(i_rgb, dtype=tf.float32)

# make kernel depth match input; same process as input
k = np.array([
    [1, 0, 1],
    [2, 1, 0],
    [0, 0, 1]
])

k_rgb = np.repeat(np.expand_dims(k, axis=0), 3, axis=0)  # shape (3, 3, 3)

# convert to tensor
k_rgb = tf.constant(k_rgb, dtype=tf.float32)

At this point both arrays are channels-first: i_rgb has shape (3, 4, 4) and k_rgb has shape (3, 3, 3). I then reshape them to the layouts tf.nn.conv2d expects:

# reshape input to format: [batch, in_height, in_width, in_channels]
image_rgb = tf.reshape(i_rgb, [1, 4, 4, 3])

# reshape kernel to format: [filter_height, filter_width, in_channels, out_channels]
kernel_rgb = tf.reshape(k_rgb, [3, 3, 3, 1])

conv_rgb = tf.squeeze(tf.nn.conv2d(image_rgb, kernel_rgb, [1, 1, 1, 1], "VALID"))

# TensorFlow 1.x graph mode; under TF 2.x eager execution, print(conv_rgb.numpy()) instead
with tf.Session() as sess:
    conv_result = sess.run(conv_rgb)
    print(conv_result)

I get the final output:

[[35. 15.] [35. 26.]]

But I was expecting the original output multiplied by 3:

[[42. 18.] [18. 36.]]

because, as I understand it, each channel of the kernel is convolved with the corresponding channel of the input, and the resulting matrices are summed to get the final output.
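As a sanity check of this reasoning, here is a minimal NumPy sketch (not TensorFlow's implementation; the helper name correlate_nhwc is hypothetical). It assumes, as tf.nn.conv2d does, that "convolution" means cross-correlation with no kernel flip, a stride of 1 and "VALID" padding. It reproduces the single-channel result and shows that correlating three identical channels and summing gives three times that.

```python
import numpy as np

def correlate_nhwc(image, kernel):
    """Stride-1 'VALID' cross-correlation of one image with one filter.

    image:  [in_height, in_width, in_channels]
    kernel: [filter_height, filter_width, in_channels]
    Each window is multiplied elementwise by the kernel over all channels
    and summed into a single output value.
    """
    fh, fw, _ = kernel.shape
    out_h = image.shape[0] - fh + 1
    out_w = image.shape[1] - fw + 1
    out = np.zeros((out_h, out_w))
    for r in range(out_h):
        for c in range(out_w):
            out[r, c] = np.sum(image[r:r + fh, c:c + fw, :] * kernel)
    return out

i_grey = np.array([[4, 3, 1, 0],
                   [2, 1, 0, 1],
                   [1, 2, 4, 1],
                   [3, 1, 0, 2]])
k = np.array([[1, 0, 1],
              [2, 1, 0],
              [0, 0, 1]])

# Single channel: add a trailing channel axis of size 1.
grey_out = correlate_nhwc(i_grey[:, :, np.newaxis], k[:, :, np.newaxis])

# Three identical channels, stacked channels-last.
image3 = np.stack([i_grey] * 3, axis=-1)  # (4, 4, 3)
kernel3 = np.stack([k] * 3, axis=-1)      # (3, 3, 3)
rgb_out = correlate_nhwc(image3, kernel3)

# The same computation written as "correlate each channel, then sum".
per_channel_sum = sum(
    correlate_nhwc(image3[:, :, ch:ch + 1], kernel3[:, :, ch:ch + 1])
    for ch in range(3)
)

print(grey_out)         # values [[14, 6], [6, 12]]
print(rgb_out)          # values [[42, 18], [18, 36]]
print(per_channel_sum)  # same values as rgb_out
```

The explicit loops are slow but keep the pairing of each window with the kernel easy to follow.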

Am I missing something from this process or the tensorflow implementation?

kym

1 Answer


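Reshape is a tricky function. It gives you the shape you ask for, but it does not move axes: it keeps the elements in their flat row-major order and only regroups them, so values from different positions can easily get mixed together. In cases like yours, avoid reshape altogether (if you really need to move the channel axis, that is a transpose, not a reshape).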
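A small NumPy illustration of the difference (NumPy's reshape uses the same row-major order as tf.reshape; the variable names are only for illustration): reshaping the question's channels-first (3, 4, 4) array to (4, 4, 3) packs three consecutive values of the image into one pixel's channels, whereas np.transpose actually moves the channel axis to the end.

```python
import numpy as np

i_grey = np.array([[4, 3, 1, 0],
                   [2, 1, 0, 1],
                   [1, 2, 4, 1],
                   [3, 1, 0, 2]])

# Channels-first copy, built as in the question: shape (3, 4, 4).
i_chw = np.repeat(np.expand_dims(i_grey, axis=0), 3, axis=0)

# reshape keeps the flat row-major element order and only regroups it.
reshaped = i_chw.reshape(4, 4, 3)

# transpose moves axis 0 (channels) to the end: shape (4, 4, 3).
transposed = np.transpose(i_chw, (1, 2, 0))

print(reshaped[0, 0])    # [4 3 1]: three consecutive values of i_grey
print(transposed[0, 0])  # [4 4 4]: pixel (0, 0) repeated in every channel
```

The same mixing happens to the kernel, which is why the question's result is not simply three times the single-channel output.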

In this particular case, it is better to add the new axis where it belongs and duplicate the array along it. With the [batch, in_height, in_width, in_channels] layout, channels is the last dimension, so that is the axis to pass to repeat(). The following code reflects the logic more directly:

i_grey = np.expand_dims(i_grey, axis=0)  # add batch dim: (1, 4, 4)
i_grey = np.expand_dims(i_grey, axis=3)  # add channel dim: (1, 4, 4, 1)
i_rgb = np.repeat(i_grey, 3, axis=3)     # duplicate along channels dim: (1, 4, 4, 3)
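A quick check of what these three lines produce (a sketch; x and i_rgb_stacked are illustrative names, and the original i_grey is kept intact for comparison): the result is channels-last, with every channel equal to the original matrix. np.stack is shown as an equivalent one-liner.

```python
import numpy as np

i_grey = np.array([[4, 3, 1, 0],
                   [2, 1, 0, 1],
                   [1, 2, 4, 1],
                   [3, 1, 0, 2]])

x = np.expand_dims(i_grey, axis=0)  # (1, 4, 4): batch dim
x = np.expand_dims(x, axis=3)       # (1, 4, 4, 1): channel dim
i_rgb = np.repeat(x, 3, axis=3)     # (1, 4, 4, 3): duplicated along channels

# Equivalent construction: stack three copies channels-last, then add the batch dim.
i_rgb_stacked = np.stack([i_grey] * 3, axis=-1)[np.newaxis]

print(i_rgb.shape)                           # (1, 4, 4, 3)
print(np.array_equal(i_rgb, i_rgb_stacked))  # True
print(all(np.array_equal(i_rgb[0, :, :, ch], i_grey) for ch in range(3)))  # True
```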

And likewise with the filter, whose layout is [filter_height, filter_width, in_channels, out_channels]:

k = np.expand_dims(k, axis=2)    # input channels dim: (3, 3, 1)
k = np.expand_dims(k, axis=3)    # output channels dim: (3, 3, 1, 1)
k_rgb = np.repeat(k, 3, axis=2)  # duplicate along the input channels dim: (3, 3, 3, 1)
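To check the whole thing without a TensorFlow session, here is a NumPy sketch of a stride-1 "VALID" conv2d on the NHWC input and [filter_height, filter_width, in_channels, out_channels] filter layouts. conv2d_valid is a hypothetical helper written for this illustration, not TensorFlow's implementation. It compares the channels-last layout built above with the question's reshape-based one.

```python
import numpy as np

def conv2d_valid(images, filters):
    """Stride-1 'VALID' cross-correlation (no kernel flip).

    images:  [batch, in_height, in_width, in_channels]
    filters: [filter_height, filter_width, in_channels, out_channels]
    returns: [batch, out_height, out_width, out_channels]
    """
    n, h, w, _ = images.shape
    fh, fw, _, out_c = filters.shape
    out = np.zeros((n, h - fh + 1, w - fw + 1, out_c))
    for r in range(h - fh + 1):
        for c in range(w - fw + 1):
            window = images[:, r:r + fh, c:c + fw, :]  # [batch, fh, fw, in_c]
            # Contract height, width and channel axes against the filter.
            out[:, r, c, :] = np.tensordot(window, filters, axes=([1, 2, 3], [0, 1, 2]))
    return out

i_grey = np.array([[4, 3, 1, 0],
                   [2, 1, 0, 1],
                   [1, 2, 4, 1],
                   [3, 1, 0, 2]])
k = np.array([[1, 0, 1],
              [2, 1, 0],
              [0, 0, 1]])

# Layout from this answer: channels last, duplicated with repeat.
image_ok = np.repeat(i_grey[np.newaxis, :, :, np.newaxis], 3, axis=3)  # (1, 4, 4, 3)
kernel_ok = np.repeat(k[:, :, np.newaxis, np.newaxis], 3, axis=2)      # (3, 3, 3, 1)

# Layout from the question: channels first, then reshaped (elements get regrouped).
image_bad = np.repeat(i_grey[np.newaxis], 3, axis=0).reshape(1, 4, 4, 3)
kernel_bad = np.repeat(k[np.newaxis], 3, axis=0).reshape(3, 3, 3, 1)

print(np.squeeze(conv2d_valid(image_ok, kernel_ok)))    # values [[42, 18], [18, 36]]
print(np.squeeze(conv2d_valid(image_bad, kernel_bad)))  # values [[35, 15], [35, 26]]
```

The second result matches the output reported in the question, so the reshape alone accounts for the difference. Converting the channels-last arrays with tf.constant(..., dtype=tf.float32) and passing them to tf.nn.conv2d should likewise give the expected 3x result.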
y.selivonchyk