You can do something like that with tf.nn.separable_conv2d, using the batch dimension as the separable channels and the actual input channels as the batch dimension.
I am not sure it is going to perform very well, though, as it involves several transpositions (which are not free in TensorFlow) and a convolution through a large number of channels, which is not really the optimized use case. Here is how it could work:
import tensorflow as tf
import numpy as np
import scipy.signal
# Expects imgs with shape (B, H, W, C) and filters with shape (B, H, W, 1)
def batch_conv(imgs, filters, strides, padding, rate=None):
    """Filter every image of a batch with its own per-image 2-D filter.

    The batch axis is moved into the channel position so that the depthwise
    stage of ``tf.nn.separable_conv2d`` applies one filter per original batch
    element; the original channel axis temporarily plays the role of the
    batch axis.

    Args:
        imgs: Images with shape (B, H, W, C).
        filters: One filter per image, with shape (B, H, W, 1).
        strides: Four-element strides given in the layout of ``imgs``; they
            are reordered here to match the transposed layout.
        padding: Padding mode forwarded to ``tf.nn.separable_conv2d``.
        rate: Optional dilation rate forwarded to ``tf.nn.separable_conv2d``.

    Returns:
        The filtered images, transposed back so the batch axis is first and
        the channel axis is last.
    """
    imgs = tf.convert_to_tensor(imgs)
    filters = tf.convert_to_tensor(filters)
    b = tf.shape(imgs)[0]
    # Swap batch and channel axes: (B, H, W, C) -> (C, H, W, B).
    imgs_t = tf.transpose(imgs, [3, 1, 2, 0])
    # Depthwise filter layout: (B, H, W, 1) -> (H, W, B, 1).
    filters_t = tf.transpose(filters, [1, 2, 0, 3])
    # Strides follow the same batch/channel swap as the images.
    strides = [strides[3], strides[1], strides[2], strides[0]]
    # "Do-nothing" pointwise filter: an identity over the B depthwise outputs.
    # It is created in the input dtype because tf.eye defaults to float32,
    # which would not match non-float32 images.
    pointwise = tf.eye(b, batch_shape=[1, 1], dtype=imgs_t.dtype)
    conv = tf.nn.separable_conv2d(imgs_t, filters_t, pointwise, strides, padding, rate)
    # Undo the batch/channel swap.
    return tf.transpose(conv, [3, 1, 2, 0])
# Slow, loop-based version using SciPy's correlate to check result
def batch_conv_np(imgs, filters, padding):
    """Reference implementation: correlate each image with its own filter.

    Every channel of every image is cross-correlated with the single-channel
    filter that shares its batch index, using ``scipy.signal.correlate2d``.

    Args:
        imgs: Array of images indexed as (batch, rows, cols, channel).
        filters: Array of filters indexed as (batch, rows, cols, 1); only
            channel 0 of each filter is used.
        padding: Padding name such as 'SAME' or 'VALID'; it is lower-cased
            and passed as the ``mode`` argument of ``correlate2d``.

    Returns:
        A NumPy array with the per-channel results stacked on the last axis
        and the per-image results stacked on the first axis.
    """
    mode = padding.lower()
    per_image = []
    for img, kernel in zip(imgs, filters):
        # One 2-D correlation per channel, all against the image's own kernel.
        channels = [
            scipy.signal.correlate2d(img[..., c], kernel[..., 0], mode)
            for c in range(img.shape[-1])
        ]
        per_image.append(np.stack(channels, axis=-1))
    return np.stack(per_image, axis=0)
# Make random input: 5 images of 20x30 with 3 channels, and one
# single-channel 20x30 filter per image.
np.random.seed(0)
imgs = np.random.rand(5, 20, 30, 3).astype(np.float32)
filters = np.random.rand(5, 20, 30, 1).astype(np.float32)
padding = 'SAME'
# Test: compare the TensorFlow result against the SciPy reference.
res_np = batch_conv_np(imgs, filters, padding)
with tf.Graph().as_default(), tf.Session() as sess:
    res_tf = batch_conv(imgs, filters, [1, 1, 1, 1], padding)
    res_tf_val = sess.run(res_tf)
print(np.allclose(res_np, res_tf_val))
# True