I think there is an error on line 53 of the following code:
It says:
return_images = torch.cat(return_images, 0) # collect all the images and return
What do you think it should say instead? Having gone through the code, I am unfortunately still not sure what this specific line is supposed to do, although I think I understand everything up to line 52.
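For reference, here is a minimal sketch of what `torch.cat(some_list, 0)` does in isolation. It assumes that `return_images` is a Python list of tensors, each shaped `(1, C, H, W)`. This is a common pattern in image-pool code, but it is an assumption here because the full code is not shown. The loop below is a hypothetical stand-in for whatever fills that list before line 53.

```python
import torch

# Hypothetical stand-in for the list built up before line 53:
# each entry is a single image with a leading batch dimension of size 1.
return_images = []
for _ in range(4):
    image = torch.randn(3, 8, 8)                # one image: (C, H, W)
    return_images.append(image.unsqueeze(0))    # add batch dim: (1, C, H, W)

# Concatenate the list along dimension 0, turning the separate
# (1, C, H, W) tensors into a single (N, C, H, W) batch tensor.
batch = torch.cat(return_images, 0)
print(batch.shape)  # torch.Size([4, 3, 8, 8])
```

Under that assumption, the line simply collects the per-image tensors back into one batch with the same layout as a normal input batch. Whether that matches the intended behavior depends on how the list is filled in the lines above it.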