I have 4D PyTorch tensors holding YUV image data:
a = torch.rand(N, C, H, W)
Here C is 3, one channel each for Y, U, and V. I want to pass each channel separately to a loss function that requires a 4D tensor as input. How do I go about this? My approach:
y = a[:,0,:,:]
u = a[:,1,:,:]
v = a[:,2,:,:]
But indexing with an integer drops the channel dimension, so each result is a 3D tensor of shape (N, H, W), and the loss function raises an error when given it. How do I retain the dimensionality of the tensors? Reshaping back works:
y = torch.reshape(y, (N, 1, H, W))
u = torch.reshape(u, (N, 1, H, W))
v = torch.reshape(v, (N, 1, H, W))
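The same reshape can be done without spelling out N, H, and W. Below is a minimal sketch, assuming small example sizes for N, C, H, W. After integer indexing has dropped the channel axis, unsqueeze(1) re-inserts a size-1 axis at position 1. Indexing with None does the same thing.

```python
import torch

N, C, H, W = 2, 3, 4, 5  # example sizes, assumed for illustration
a = torch.rand(N, C, H, W)

# Integer indexing drops the channel axis: shape (N, H, W)
y = a[:, 0, :, :]

# unsqueeze(1) inserts a size-1 axis at position 1: shape (N, 1, H, W)
y4 = y.unsqueeze(1)

# Indexing with None at position 1 is equivalent to unsqueeze(1)
y4_alt = y[:, None]
```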
torch.reshape works, but it involves a lot of hard-coding. Is there a better way to get this done?
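There are two ways to avoid the reshape step altogether. This sketch assumes the same small example sizes. First, slicing with a range such as 0:1 instead of an integer keeps the channel axis. Second, torch.split(a, 1, dim=1) cuts the tensor into size-1 chunks along that axis.

```python
import torch

N, C, H, W = 2, 3, 4, 5  # example sizes, assumed for illustration
a = torch.rand(N, C, H, W)

# Option 1: a range slice keeps the sliced axis, so each result is (N, 1, H, W)
y = a[:, 0:1]
u = a[:, 1:2]
v = a[:, 2:3]

# Option 2: split along dim=1 into chunks of size 1; returns a tuple of tensors,
# each of shape (N, 1, H, W)
y2, u2, v2 = torch.split(a, 1, dim=1)
```

Both options return views of a, so in-place changes to y, u, or v also modify a. Call .clone() if independent copies are needed. By contrast, torch.unbind(a, dim=1) would drop the channel axis, just like integer indexing.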