I'm having trouble writing a custom collate_fn function for the PyTorch DataLoader class. I need the custom function because my inputs have different dimensions.
I'm currently trying to write the baseline implementation of the Stanford MURA paper. The dataset consists of labeled studies, and a study may contain more than one image. I created a custom Dataset class that stacks the images of a study using torch.stack.
The stacked tensor is then provided as input to the model, and the list of outputs is averaged to obtain a single output. This implementation works fine with DataLoader when batch_size=1. However, when I set batch_size to 8, as in the original paper, the DataLoader fails: it uses torch.stack to collate the batch, and the inputs in my batch have variable dimensions (since each study can contain a different number of images).
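To make the failure concrete, here is a minimal sketch (the shapes are illustrative, not from my actual dataset) showing that torch.stack, which the default collation relies on, rejects tensors whose first dimension differs:

```python
import torch

# Two "studies" with a different number of images each.
# The default collate effectively calls torch.stack on these,
# which requires all tensors to have identical shapes.
study_a = torch.zeros(2, 3, 224, 224)  # study with 2 images
study_b = torch.zeros(5, 3, 224, 224)  # study with 5 images

try:
    torch.stack([study_a, study_b])
except RuntimeError as e:
    print("stack failed:", e)
```

This raises a RuntimeError about mismatched sizes, which is exactly the error the DataLoader surfaces with batch_size > 1.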
To fix this, I tried to implement my own collate_fn function:
def collate_fn(batch):
imgs = [item['images'] for item in batch]
targets = [item['label'] for item in batch]
targets = torch.LongTensor(targets)
return imgs, targets
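For context, this is how I wire the collate_fn into the DataLoader. The toy dataset below (ToyStudyDataset, with made-up shapes) is just a stand-in for my real Dataset class so the snippet is self-contained:

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Illustrative stand-in for my study dataset: study i holds a
# variable number of images, so stacked tensors differ in size.
class ToyStudyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        n_images = idx % 3 + 1  # 1, 2, or 3 images per study
        return {'images': torch.zeros(n_images, 3, 224, 224),
                'label': idx % 2}

def collate_fn(batch):
    imgs = [item['images'] for item in batch]
    targets = torch.LongTensor([item['label'] for item in batch])
    return imgs, targets

loader = DataLoader(ToyStudyDataset(), batch_size=4, collate_fn=collate_fn)
imgs, targets = next(iter(loader))
# imgs is a plain Python list of 4 tensors with differing first dims;
# targets is a single LongTensor of shape (4,)
print(len(imgs), targets.shape)
```

Since the collate_fn returns a list instead of one stacked tensor, each study keeps its own image count.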
Then in my training epoch loop, I loop through each batch like this:
for image, label in zip(*batch):
    label = label.type(torch.FloatTensor)
    # wrap them in Variable
    image = Variable(image).cuda()
    label = Variable(label).cuda()
    # forward
    output = model(image)
    output = torch.mean(output)
    loss = criterion(output, label, phase)
However, this does not improve the epoch time at all: an epoch still takes as long as it did with a batch size of 1. I've also tried setting the batch size to 32, and that does not improve the timings either.
Am I doing something wrong? Is there a better approach to this?