I have two tensors, each holding a batch of N images at the same resolution. I would like to convolve the first image in tensor 1 with the first image in tensor 2, the second image in tensor 1 with the second image in tensor 2, and so on. I want the output to be a tensor with N images of the same size. I looked into using tf.nn.conv2d, but it seems that this op takes a batch of N images and convolves all of them with the same filter.

I looked into examples like "What does tf.nn.conv2d do in tensorflow?", but they do not cover the case of multiple images paired with multiple filters.

user6360
  • You can use a for loop with [`scipy.signal.convolve2d`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve2d.html) – Chris Mueller Aug 28 '19 at 12:23
  • Do images and filters have multiple channels? – jdehesa Aug 28 '19 at 13:23
  • @ChrisMueller I thought of this but wouldn't it be computationally expensive during training? I also thought of reducing batch to size 1 but there are downsides to that – user6360 Aug 28 '19 at 20:03
  • @jdehesa Right now I am trying with grayscale images, but ideally I would like to use RGB for the input images and a single channel for the filter. – user6360 Aug 28 '19 at 20:04

1 Answer


You can do something like that with tf.nn.separable_conv2d by treating the batch dimension as the depthwise channels and the actual input channels as the batch dimension. I am not sure it will perform very well, though, since it involves several transpositions (which are not free in TensorFlow) and a convolution over a large number of channels, which is not the use case the op is optimized for. Note that, like the other tf.nn convolution ops, this computes a cross-correlation (the kernel is not flipped), which is why the check uses SciPy's correlate2d rather than convolve2d. Here is how it could work:

import tensorflow as tf
import numpy as np
import scipy.signal

# Expects imgs with shape (B, H, W, C) and filters with shape (B, H, W, 1)
def batch_conv(imgs, filters, strides, padding, rate=None):
    imgs = tf.convert_to_tensor(imgs)
    filters = tf.convert_to_tensor(filters)
    b = tf.shape(imgs)[0]
    imgs_t = tf.transpose(imgs, [3, 1, 2, 0])
    filters_t = tf.transpose(filters, [1, 2, 0, 3])
    strides = [strides[3], strides[1], strides[2], strides[0]]
    # "do-nothing" pointwise filter
    pointwise = tf.eye(b, batch_shape=[1, 1])
    # The dilation rate is passed by keyword for clarity (TF 1.x signature)
    conv = tf.nn.separable_conv2d(imgs_t, filters_t, pointwise, strides, padding, rate=rate)
    return tf.transpose(conv, [3, 1, 2, 0])

# Slow, loop-based version using SciPy's correlate to check result
def batch_conv_np(imgs, filters, padding):
    # "filt" avoids shadowing the Python builtin "filter"
    return np.stack(
        [np.stack([scipy.signal.correlate2d(img[..., i], filt[..., 0], padding.lower())
                   for i in range(img.shape[-1])], axis=-1)
         for img, filt in zip(imgs, filters)], axis=0)

# Make random input
np.random.seed(0)
imgs = np.random.rand(5, 20, 30, 3).astype(np.float32)
filters = np.random.rand(5, 20, 30, 1).astype(np.float32)
padding = 'SAME'
# Test (TensorFlow 1.x graph mode; tf.Session is not available in TF 2.x)
res_np = batch_conv_np(imgs, filters, padding)
with tf.Graph().as_default(), tf.Session() as sess:
    res_tf = batch_conv(imgs, filters, [1, 1, 1, 1], padding)
    res_tf_val = sess.run(res_tf)
print(np.allclose(res_np, res_tf_val))
# True
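
The same channel-swapping idea can probably be written without the identity pointwise filter by calling tf.nn.depthwise_conv2d directly. What follows is only a minimal sketch, assuming TensorFlow 2.x with eager execution and a stride of 1; `batch_conv_depthwise` is a hypothetical name used for illustration, not part of the code above or of TensorFlow itself.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x, eager execution, stride fixed to 1.
# Expects imgs with shape (B, H, W, C) and filters with shape (B, Hf, Wf, 1).
def batch_conv_depthwise(imgs, filters, padding="SAME"):
    imgs = tf.convert_to_tensor(imgs)
    filters = tf.convert_to_tensor(filters)
    # Swap batch and channels: (B, H, W, C) -> (C, H, W, B)
    imgs_t = tf.transpose(imgs, [3, 1, 2, 0])
    # Depthwise filter layout: (Hf, Wf, in_channels=B, channel_multiplier=1)
    filters_t = tf.transpose(filters, [1, 2, 0, 3])
    # Each of the B "channels" is correlated with its own filter
    out = tf.nn.depthwise_conv2d(
        imgs_t, filters_t, strides=[1, 1, 1, 1], padding=padding)
    # Swap back: (C, H', W', B) -> (B, H', W', C)
    return tf.transpose(out, [3, 1, 2, 0])

# Usage with the same shapes as the test above
imgs = np.random.rand(5, 20, 30, 3).astype(np.float32)
filters = np.random.rand(5, 20, 30, 1).astype(np.float32)
out = batch_conv_depthwise(imgs, filters, padding="SAME")
print(out.shape)  # (5, 20, 30, 3)
```

With a channel multiplier of 1, the depthwise step already produces one output channel per input channel, so the "do-nothing" pointwise filter is not needed in this variant. The transpositions remain, so the performance caveat above still applies.
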
jdehesa