Suppose I have two tensors S and T defined as:
S = torch.rand((3,2,1))
T = torch.ones((3,2,1))
We can think of these as containing batches of tensors with shapes (2, 1). In this case, the batch size is 3.
I want to concatenate every possible pairing of a batch element from S with a batch element from T. A single concatenation of two batch elements produces a tensor of shape (4, 1). There are 3*3 such combinations, so ultimately the resulting tensor C must have a shape of (3, 3, 4, 1).
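To make the shape arithmetic concrete, here is a minimal check of a single pairing. The name pair and the indices 0 and 1 are arbitrary, chosen only for illustration.

```python
import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

# torch.cat joins along dim 0 by default, so (2, 1) + (2, 1) -> (4, 1)
pair = torch.cat((S[0], T[1]))
print(pair.shape)  # torch.Size([4, 1])
```

Doing this for each of the 3*3 index pairs (i, j) gives the (3, 3, 4, 1) result.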
One solution is to do the following:
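# Preallocate the (3, 3, 4, 1) result before filling it in
C = torch.empty((S.shape[0], T.shape[0], S.shape[1] + T.shape[1], S.shape[2]))
for i in range(S.shape[0]):
    for j in range(T.shape[0]):
        C[i, j, :, :] = torch.cat((S[i, :, :], T[j, :, :]))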
But the for loop doesn't scale well to large batch sizes. Is there a PyTorch command to do this?
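For comparison, one possible loop-free sketch is to broadcast both tensors to a common (3, 3, 2, 1) shape and concatenate once. This assumes that unsqueeze and expand followed by a single torch.cat is acceptable here. The names n, m, S_exp and T_exp are introduced only for illustration.

```python
import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

n, m = S.shape[0], T.shape[0]

# S_exp[i, j] is S[i] for every j; T_exp[i, j] is T[j] for every i.
# expand creates broadcast views without copying the underlying data.
S_exp = S.unsqueeze(1).expand(n, m, *S.shape[1:])  # (3, 3, 2, 1)
T_exp = T.unsqueeze(0).expand(n, m, *T.shape[1:])  # (3, 3, 2, 1)

# Concatenate along dim 2, which corresponds to dim 0 of each batch element
C = torch.cat((S_exp, T_exp), dim=2)
print(C.shape)  # torch.Size([3, 3, 4, 1])
```

Each C[i, j] here should match torch.cat((S[i], T[j])) from the loop version. Note that torch.cat still materializes the full result, so memory grows with the product of the two batch sizes.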