This is a very common problem in segmentation networks, where skip connections are often involved in the decoding process. Depending on the actual architecture, such networks usually require inputs whose side lengths are integer multiples of the largest stride (8, 16, 32, etc.). Otherwise, the upsampled decoder features and the encoder features they are concatenated with end up with different spatial sizes.
There are two main ways to handle this:
1. Resize the input to the nearest feasible size.
2. Pad the input to the next larger feasible size.
I prefer (2) because (1) slightly changes the value of every pixel through interpolation, which introduces unnecessary blurriness. Note that with both methods we usually need to recover the original shape afterward.
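For comparison, option (1) could be sketched roughly as follows. The helper names `resize_to` and `resize_back` are hypothetical, and rounding to the *nearest* multiple of the stride plus bilinear interpolation are assumptions; other rounding rules or interpolation modes work the same way:

```python
import torch
import torch.nn.functional as F

def resize_to(x, stride, mode="bilinear"):
    # Hypothetical helper for option (1): round each side to the nearest
    # multiple of `stride` (at least one stride) and resample the whole image.
    h, w = x.shape[-2:]
    new_h = max(stride, round(h / stride) * stride)
    new_w = max(stride, round(w / stride) * stride)
    out = F.interpolate(x, size=(new_h, new_w), mode=mode, align_corners=False)
    return out, (h, w)

def resize_back(x, size, mode="bilinear"):
    # Resample the network output back to the original (h, w).
    return F.interpolate(x, size=size, mode=mode, align_corners=False)

torch.manual_seed(0)
x = torch.rand(1, 1, 1080, 64)
y, size = resize_to(x, 16)   # 1080 -> 1088 rows, 64 columns unchanged
z = resize_back(y, size)     # back to 1080 rows
print(y.shape, z.shape, torch.equal(z, x))
```

The shape is recovered, but `torch.equal(z, x)` is `False`: every row has been resampled twice, which is exactly the pixel-level change that padding avoids.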
My favorite code snippet for this task (symmetric padding for height/width):
```python
import torch
import torch.nn.functional as F

def pad_to(x, stride):
    h, w = x.shape[-2:]

    # Round each side length up to the next multiple of `stride`.
    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w

    # Split the extra rows/columns as evenly as possible between both sides.
    lh, uh = (new_h - h) // 2, (new_h - h) - (new_h - h) // 2
    lw, uw = (new_w - w) // 2, (new_w - w) - (new_w - w) // 2
    pads = (lw, uw, lh, uh)  # F.pad order: (left, right, top, bottom)

    # zero-padding by default.
    # See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)
    return out, pads

def unpad(x, pad):
    # Crop the padded rows/columns back off. Explicit end indices avoid the
    # empty slice that x[a:-0] would produce when a trailing pad is 0.
    lw, uw, lh, uh = pad
    h, w = x.shape[-2:]
    return x[..., lh:h - uh, lw:w - uw]
```
A test snippet:
```python
x = torch.zeros(4, 3, 1080, 1920)  # Raw data
x_pad, pads = pad_to(x, 16)        # Padded data, feed this to your network
x_unpad = unpad(x_pad, pads)       # Un-pad the network output to recover the original shape
print('Original:', x.shape)
print('Padded:', x_pad.shape)
print('Recovered:', x_unpad.shape)
```
Output:
```
Original: torch.Size([4, 3, 1080, 1920])
Padded: torch.Size([4, 3, 1088, 1920])
Recovered: torch.Size([4, 3, 1080, 1920])
```
Reference: https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33
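If the artificial zero border is a concern, the same symmetric split works with the other `F.pad` modes mentioned in the code comment. The sketch below is a hypothetical compact variant (`pad_to_multiple` is not from the reference), assuming `mode="reflect"`, which mirrors the border pixels and requires each pad to be smaller than the corresponding input dimension. Cropping the same amounts back off recovers the input exactly:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, stride, mode="reflect"):
    # Hypothetical compact variant: same symmetric split as pad_to,
    # but with a configurable padding mode.
    h, w = x.shape[-2:]
    ph = (stride - h % stride) % stride  # extra rows needed (0 if already a multiple)
    pw = (stride - w % stride) % stride  # extra columns needed
    pads = (pw // 2, pw - pw // 2, ph // 2, ph - ph // 2)  # (left, right, top, bottom)
    return F.pad(x, pads, mode=mode), pads

x = torch.rand(1, 3, 1080, 1920)
x_pad, pads = pad_to_multiple(x, 16)
lw, uw, lh, uh = pads
# Crop with explicit end indices, as in unpad.
x_back = x_pad[..., lh:x_pad.shape[-2] - uh, lw:x_pad.shape[-1] - uw]
print(pads, torch.equal(x_back, x))  # (0, 0, 4, 4) True
```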