I'm trying to implement a gaussian-like blurring of a 3D volume in pytorch. I can do a 2D blur of a 2D image by convolving with a 2D gaussian kernel easy enough, and the same approach seems to work for 3D with a 3D gaussian kernel. However, it is very slow in 3D (especially with larger sigmas/kernel sizes). I understand this can also be done instead by convolving 3 times with the 2D kernel which should be much faster, but I can't get this to work. My test case is below.
import torch
import torch.nn.functional as F
VOL_SIZE = 21
def make_gaussian_kernel(sigma):
ks = int(sigma * 5)
if ks % 2 == 0:
ks += 1
ts = torch.linspace(-ks // 2, ks // 2 + 1, ks)
gauss = torch.exp((-(ts / sigma)**2 / 2))
kernel = gauss / gauss.sum()
return kernel
def test_3d_gaussian_blur(blur_sigma=2):
# Make a test volume
vol = torch.zeros([VOL_SIZE] * 3)
vol[VOL_SIZE // 2, VOL_SIZE // 2, VOL_SIZE // 2] = 1
# 3D convolution
vol_in = vol.reshape(1, 1, *vol.shape)
k = make_gaussian_kernel(blur_sigma)
k3d = torch.einsum('i,j,k->ijk', k, k, k)
k3d = k3d / k3d.sum()
vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)
# Separable 2D convolution
vol_in = vol.reshape(1, *vol.shape)
k2d = torch.einsum('i,j->ij', k, k)
k2d = k2d / k2d.sum()
k2d = k2d.expand(VOL_SIZE, 1, *k2d.shape)
for i in range(3):
vol_in = vol_in.permute(0, 3, 1, 2)
vol_in = F.conv2d(vol_in, k2d, stride=1, padding=len(k) // 2, groups=VOL_SIZE)
vol_3d_sep = vol_in
torch.allclose(vol_3d, vol_3d_sep) # --> False
Any help would be very much appreciated!