
I'm trying to implement a Gaussian-like blurring of a 3D volume in pytorch. I can do a 2D blur of a 2D image by convolving with a 2D Gaussian kernel easily enough, and the same approach seems to work for 3D with a 3D Gaussian kernel. However, it is very slow in 3D (especially with larger sigmas/kernel sizes). I understand this can also be done by convolving 3 times with the 2D kernel instead, which should be much faster, but I can't get this to work. My test case is below.

import torch
import torch.nn.functional as F

VOL_SIZE = 21


def make_gaussian_kernel(sigma):
    ks = int(sigma * 5)
    if ks % 2 == 0:
        ks += 1
    ts = torch.linspace(-(ks // 2), ks // 2, ks)  # integer offsets -r..r with step 1
    gauss = torch.exp((-(ts / sigma)**2 / 2))
    kernel = gauss / gauss.sum()

    return kernel


def test_3d_gaussian_blur(blur_sigma=2):
    # Make a test volume
    vol = torch.zeros([VOL_SIZE] * 3)
    vol[VOL_SIZE // 2, VOL_SIZE // 2, VOL_SIZE // 2] = 1

    # 3D convolution
    vol_in = vol.reshape(1, 1, *vol.shape)
    k = make_gaussian_kernel(blur_sigma)
    k3d = torch.einsum('i,j,k->ijk', k, k, k)
    k3d = k3d / k3d.sum()
    vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)

    # Separable 2D convolution
    vol_in = vol.reshape(1, *vol.shape)
    k2d = torch.einsum('i,j->ij', k, k)
    k2d = k2d / k2d.sum()
    k2d = k2d.expand(VOL_SIZE, 1, *k2d.shape)
    for i in range(3):
        vol_in = vol_in.permute(0, 3, 1, 2)
        vol_in = F.conv2d(vol_in, k2d, stride=1, padding=len(k) // 2, groups=VOL_SIZE)
    vol_3d_sep = vol_in

    torch.allclose(vol_3d, vol_3d_sep)  # --> False

Any help would be very much appreciated!

lopsided
  • Actually you can separate a 3d (isotropic) Gaussian kernel into three 1d kernels. (Separating into 2d kernels works in theory, but there is no reason to do so!) – flawr May 21 '21 at 14:01
  • Yes, I guessed as much, but in the 2d case I can treat the 3rd dimension as channels and use grouped 2d convolutions (at least, I'm trying). In the 1d case I have 2 spatial dimensions left over, so I don't know how to implement it properly in pytorch. Some reshaping probably? – lopsided May 21 '21 at 14:32
  • In any case you can use the usual conv3d function, but with a `k x 1 x 1` kernel. – flawr May 21 '21 at 14:49

2 Answers


You can theoretically compute the 3d Gaussian convolution using three 2d convolutions, but then you have to use a narrower 2d kernel (standard deviation sigma/sqrt(2) instead of sigma): each axis gets convolved in two of the three passes, and convolving two Gaussians adds their variances.
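A minimal sketch of that idea, assuming a cubic float64 volume; the helper names `gaussian_1d` and `blur_three_2d` and the 5-sigma kernel radius are my own choices, not from the question. Each 2D pass uses sigma/sqrt(2), so after every axis has been blurred twice the total variance is 2 · sigma²/2 = sigma². The usage part blurs a single voxel and compares the result with the exact 3D impulse response k ⊗ k ⊗ k.

```python
import math

import torch
import torch.nn.functional as F


def gaussian_1d(sigma, radius):
    """Sampled Gaussian on integer offsets -radius..radius, normalized to sum to 1."""
    ts = torch.arange(-radius, radius + 1, dtype=torch.float64)
    g = torch.exp(-0.5 * (ts / sigma) ** 2)
    return g / g.sum()


def blur_three_2d(vol, sigma):
    """Approximate an isotropic 3D Gaussian blur of a cubic float64 (N, N, N) volume
    with three grouped 2D convolutions. Every axis is blurred in two of the three
    passes, so each pass uses sigma / sqrt(2): 2 * (sigma / sqrt(2))**2 == sigma**2."""
    s = sigma / math.sqrt(2)
    k = gaussian_1d(s, radius=math.ceil(5 * s))
    n, r = vol.shape[0], len(k) // 2
    k2d = torch.outer(k, k).expand(n, 1, len(k), len(k)).contiguous()
    x = vol[None]  # (1, N, N, N)
    for _ in range(3):
        x = x.permute(0, 3, 1, 2)  # rotate axes so each pass blurs a different pair
        x = F.conv2d(x, k2d, padding=r, groups=n)  # the channel axis is left untouched
    return x[0]  # three rotations restore the original axis order


# Usage: blur a single voxel and compare with the exact 3D impulse response k ⊗ k ⊗ k.
N, sigma = 35, 2.0
c = N // 2
vol = torch.zeros(N, N, N, dtype=torch.float64)
vol[c, c, c] = 1.0
out = blur_three_2d(vol, sigma)

k = gaussian_1d(sigma, radius=math.ceil(5 * sigma))
r = len(k) // 2
ref = torch.zeros_like(vol)
ref[c - r:c + r + 1, c - r:c + r + 1, c - r:c + r + 1] = torch.einsum("i,j,k->ijk", k, k, k)
print((out - ref).abs().max())
```

With the full sigma in each 2D pass (as in the question) the result comes out too blurry, which is one reason the question's `allclose` check fails.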

But it is computationally more efficient (and usually what you want) to separate the kernel into 1d kernels. I changed the second part of your function to do this. (I must say I really liked your permutation-based approach!) Since you're working with a 3d volume you can't easily use the conv2d or conv1d functions, so the simplest option is to keep using conv3d, even though each call only computes a 1d convolution (with a `k x 1 x 1` kernel).

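Note that `torch.allclose` defaults to `atol=1e-8` (plus `rtol=1e-5` relative to the second argument), so entries close to zero have to agree to about 1e-8. The two methods sum the products in a different order, and the resulting float32 rounding differences (around 1e-7) can exceed that.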

def test_3d_gaussian_blur(blur_sigma=2):
    # Make a test volume
    vol = torch.randn([VOL_SIZE] * 3) # using something other than zeros
    vol[VOL_SIZE // 2, VOL_SIZE // 2, VOL_SIZE // 2] = 1

    # 3D convolution
    vol_in = vol.reshape(1, 1, *vol.shape)
    k = make_gaussian_kernel(blur_sigma)
    k3d = torch.einsum('i,j,k->ijk', k, k, k)
    k3d = k3d / k3d.sum()
    vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)

    # Separable 1D convolution
    vol_in = vol[None, None, ...]
    # k2d = torch.einsum('i,j->ij', k, k)
    # k2d = k2d / k2d.sum() # not necessary if kernel already sums to one, check:
    # print(f'{k2d.sum()=}')
    k1d = k[None, None, :, None, None]
    for i in range(3):
        vol_in = vol_in.permute(0, 1, 4, 2, 3)
        vol_in = F.conv3d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0, 0))
    vol_3d_sep = vol_in
    print((vol_3d - vol_3d_sep).abs().max())  # something ~1e-7
    print(torch.allclose(vol_3d, vol_3d_sep))  # default atol=1e-8 can be too strict for float32
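If you later need this for batched, multi-channel input, one variation that avoids the permutes is to reshape the 1d kernel to `(k, 1, 1)`, `(1, k, 1)` and `(1, 1, k)` and let the `groups` argument handle the channels. This is a sketch under my own assumptions: `separable_blur3d` is a hypothetical name, the input is `(N, C, D, H, W)`, the kernel length is odd, and a binomial kernel stands in for a Gaussian. Running it in float64 also shows that the ~1e-7 gap above is float32 rounding rather than an error in the separation.

```python
import torch
import torch.nn.functional as F


def separable_blur3d(x, k):
    """Blur an (N, C, D, H, W) tensor with the separable 3D kernel k ⊗ k ⊗ k.

    The 1D kernel is reshaped so that each conv3d call filters along one axis;
    groups=C applies the same kernel to every channel independently.
    """
    c, ks = x.shape[1], len(k)
    r = ks // 2
    shapes = [(ks, 1, 1), (1, ks, 1), (1, 1, ks)]
    pads = [(r, 0, 0), (0, r, 0), (0, 0, r)]
    for shape, pad in zip(shapes, pads):
        w = k.reshape(1, 1, *shape).expand(c, 1, *shape).contiguous()  # (C, 1, *shape)
        x = F.conv3d(x, w, padding=pad, groups=c)
    return x


# Usage: compare against a single conv3d with the full outer-product kernel.
torch.manual_seed(0)
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64) / 16  # binomial stand-in for a Gaussian
x = torch.randn(2, 3, 9, 10, 11, dtype=torch.float64)
out = separable_blur3d(x, k)
w3d = torch.einsum("i,j,k->ijk", k, k, k).expand(3, 1, 5, 5, 5).contiguous()
ref = F.conv3d(x, w3d, padding=2, groups=3)
print((out - ref).abs().max())
```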

Addendum: If you really want to abuse conv2d to process the volumes you can try

# separate 3d kernel into 1d + 2d
vol_in = vol[None, None, ...]
k2d = torch.einsum('i,j->ij', k, k)
k2d = k2d.expand(VOL_SIZE, 1, len(k), len(k))
# k2d = k2d / k2d.sum() # not necessary if kernel already sums to one, check:
# print(f'{k2d.sum()=}')
k1d = k[None, None, :, None, None]
vol_in = F.conv3d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0, 0))
vol_in = vol_in[0, ...]
# abuse conv2d-groups argument for volume dimension, works only for 1 channel volumes
vol_in = F.conv2d(vol_in, k2d, stride=1, padding=(len(k) // 2, len(k) // 2), groups=VOL_SIZE)
vol_3d_sep = vol_in

Or using exclusively conv2d you could do:

# separate 3d kernel into 1d + 2d (assumes a cubic VOL_SIZE^3 volume)
vol_in = vol[None, ...]
# 1d kernel
k1d = k[None, None, :, None]
k1d = k1d.expand(VOL_SIZE, 1, len(k), 1)
# 2d kernel
k2d = torch.einsum('i,j->ij', k, k)
k2d = k2d.expand(VOL_SIZE, 1, len(k), len(k))
vol_in = vol_in.permute(0, 2, 1, 3)
vol_in = F.conv2d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0), groups=VOL_SIZE)
vol_in = vol_in.permute(0, 2, 1, 3)
vol_in = F.conv2d(vol_in, k2d, stride=1, padding=(len(k) // 2, len(k) // 2), groups=VOL_SIZE)
vol_3d_sep = vol_in
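Both snippets take the group count from `VOL_SIZE`, which ties them to the cubic test volume. A hypothetical function version of the conv2d-only idea that reads the group counts from the volume's own shape (so non-cubic volumes work) might look like this; it assumes a single-channel `(D, H, W)` tensor and an odd-length 1d kernel, and again uses a binomial stand-in kernel for the check.

```python
import torch
import torch.nn.functional as F


def blur3d_conv2d_only(vol, k):
    """Blur a single-channel (D, H, W) volume with k ⊗ k ⊗ k using only conv2d.

    Pass 1 blurs along D with a (ks, 1) kernel, using H as the grouped channel axis.
    Pass 2 blurs H and W together with the 2D kernel k ⊗ k, using D as channels.
    """
    d, h, w = vol.shape
    ks, r = len(k), len(k) // 2
    k1d = k.reshape(1, 1, ks, 1).expand(h, 1, ks, 1).contiguous()
    k2d = torch.outer(k, k).expand(d, 1, ks, ks).contiguous()
    x = vol.permute(1, 0, 2)[None]  # (1, H, D, W): H acts as channels
    x = F.conv2d(x, k1d, padding=(r, 0), groups=h)
    x = x[0].permute(1, 0, 2)[None]  # back to (1, D, H, W): D acts as channels
    x = F.conv2d(x, k2d, padding=r, groups=d)
    return x[0]


# Usage: compare against conv3d with the full outer-product kernel.
torch.manual_seed(0)
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64) / 16
vol = torch.randn(7, 12, 9, dtype=torch.float64)  # deliberately non-cubic
out = blur3d_conv2d_only(vol, k)
ref = F.conv3d(vol[None, None], torch.einsum("i,j,k->ijk", k, k, k)[None, None], padding=2)[0, 0]
print((out - ref).abs().max())
```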

These should still be faster than three consecutive 2d convolutions.

flawr
  • Looks great thanks! It even passes the assert using zeros+1 :) On my actual problem this version gives about 20-25% speedup vs the 3d conv kernel approach. Although I was hoping to see a little more than that given the speed of 2d convolutions on the gpu. Would still like to get the 2d conv version working for a full comparison, but this is certainly a big improvement. – lopsided May 21 '21 at 15:43
  • @lopsided Do you have any reason to believe that the 2d convolutions will be faster? I would expect them to take roughly len(k) times longer (at least on the CPU). In general, depending on the hardware and on the size of the kernel, one 3d convolution might still be faster (conversely, for very large kernels the 3x1d approach will definitely be a lot faster). Have you tested it with larger kernels/inputs? – flawr May 21 '21 at 15:47
  • I mean, what you could do is a 2d convolution followed by a 1d convolution, but three 2d convolutions just don't make much sense in my book. Or is there a particular reason you want to use three 2d convolutions? – flawr May 21 '21 at 15:54
  • Certainly the 1d approach should be the most efficient I agree. I haven't got any benchmarks but working with batches of 2d images (~64*200^2) flies through dozens of convolutions but here 1x100^3 with 3 "1d" convolutions seems to take much longer. Granted the kernel size is larger for the blur (~11) than I use for the cnns (~3x3) but still I feel there is more to be had. I'm probably wrong or losing time elsewhere though... – lopsided May 21 '21 at 16:00
  • The kernel size makes a HUGE difference in performance! Convolving an `n x n` image with a `k x k` kernel is an `O(n^2 * k^2)` operation. One thing you can try to speed up the computations is setting `torch.backends.cudnn.benchmark = True`. I added two snippets that "abuse" `conv2d` to perform the blurring on a volume, and I'm pretty sure they are better than using *three* 2d convolutions. But you have to keep in mind that these only work for 1-channel images. – flawr May 21 '21 at 16:15
  • Fantastic, a million thanks for these, great answer! – lopsided May 26 '21 at 16:30
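A back-of-the-envelope count of total multiply-adds makes the kernel-size point concrete for the 3D case. It ignores memory traffic, padding and how well each cuDNN path is tuned, so it bounds the arithmetic, not the wall-clock time. For an n × n × n volume and a kernel of width k:

$$
\underbrace{n^3 k^3}_{\text{one 3D conv}} \quad\text{vs.}\quad \underbrace{3\,n^3 k^2}_{\text{three 2D convs}} \quad\text{vs.}\quad \underbrace{3\,n^3 k}_{\text{three 1D convs}}, \qquad \frac{n^3 k^3}{3\,n^3 k} = \frac{k^2}{3}
$$

With the sizes from the comments (n = 100, k = 11) that is about 1.33 × 10^9, 3.6 × 10^8 and 3.3 × 10^7 multiply-adds, so the three 1d passes do roughly 121/3 ≈ 40 times less arithmetic than the direct 3D kernel. A measured speedup can be much smaller than that ratio on a GPU, as the 20-25% reported above shows.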

For anyone who finds this question: the previous best answer still uses F.conv3d to do the job. In my case it was faster to rewrite it with F.conv1d, so that the convolution is truly separated into 1d passes.

import torch
import torch.nn.functional as F

import cv2

VOL_SIZE = (10, 20, 30)


def test_3d_gaussian_blur(ks=5, blur_sigma=2):
    # Make a test volume
    vol = torch.randn(VOL_SIZE) # using something other than zeros

    # 3D convolution
    vol_in = vol.reshape(1, 1, *vol.shape)
    k = torch.from_numpy(cv2.getGaussianKernel(ks, blur_sigma)).squeeze().float()
    k3d = torch.einsum('i,j,k->ijk', k, k, k)
    k3d = k3d / k3d.sum()
    vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)

    # Separable 1D convolution: conv1d filters the last axis, then permute
    # rotates the axes so each of the three passes filters a different one
    k1d = k.view(1, 1, -1)
    for _ in range(3):
        vol = F.conv1d(vol.reshape(-1, 1, vol.size(2)), k1d, padding=ks // 2).view(*vol.shape)
        vol = vol.permute(2, 0, 1)
    print((vol_3d - vol).abs().max())  # something ~1e-7
    print(torch.allclose(vol_3d, vol, atol=1e-6))


test_3d_gaussian_blur()
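The loop above handles one unbatched volume and takes the kernel from OpenCV. Below is a sketch of the same conv1d idea that accepts any leading batch/channel dimensions and builds the kernel in torch. `gaussian_kernel_1d` and `blur3d_conv1d` are hypothetical names, and the kernel samples exp(-t²/(2σ²)) at integer offsets around the center and normalizes. I believe that matches the formula in the `cv2.getGaussianKernel` docs for a positive sigma, but treat that as an assumption.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_1d(ks, sigma, dtype=torch.float32):
    """Normalized sampled Gaussian of odd length ks (torch-only, no OpenCV)."""
    ts = torch.arange(ks, dtype=dtype) - (ks - 1) / 2
    g = torch.exp(-ts ** 2 / (2 * sigma ** 2))
    return g / g.sum()


def blur3d_conv1d(x, k):
    """Separable 3D blur of a (..., D, H, W) tensor using three conv1d passes.

    Leading dims (batch, channels) are folded into the conv1d batch dimension.
    Each pass filters the last axis and then rotates the axes; after three
    passes the original (D, H, W) order is restored.
    """
    r = len(k) // 2
    w = k.reshape(1, 1, -1)
    lead = x.shape[:-3]
    x = x.reshape(-1, *x.shape[-3:])  # (B, D, H, W)
    for _ in range(3):
        b, d, h, wd = x.shape
        x = F.conv1d(x.reshape(-1, 1, wd), w, padding=r).reshape(b, d, h, wd)
        x = x.permute(0, 3, 1, 2)  # move the next spatial axis to the end
    return x.reshape(*lead, *x.shape[-3:])


# Usage: a batched, multi-channel volume compared with a grouped conv3d.
torch.manual_seed(0)
k = gaussian_kernel_1d(5, 2.0, dtype=torch.float64)
x = torch.randn(2, 3, 10, 20, 30, dtype=torch.float64)  # (N, C, D, H, W)
out = blur3d_conv1d(x, k)
w3d = torch.einsum("i,j,k->ijk", k, k, k).expand(3, 1, 5, 5, 5).contiguous()
ref = F.conv3d(x, w3d, padding=2, groups=3)
print((out - ref).abs().max())
```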