I'm trying to find a way to do this without for loops.
Say I have a multi-dimensional tensor `t0`:
```python
import torch

bs = 4
seq = 10
v = 16
sample = 5
t0 = torch.rand((bs, seq, v))
```

This has shape `torch.Size([4, 10, 16])`.
I have another tensor `labels` that holds, for each batch element, `sample = 5` random indices into the `seq` dimension:
```python
labels = torch.randint(0, seq, size=[bs, sample])
```
So this has shape `torch.Size([4, 5])`. It is used to index the `seq` dimension of `t0`.
What I want to do is loop over the batch dimension, gathering rows of `t0` according to the `labels` tensor. My brute-force solution is this:
```python
t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]
```
This results in a tensor `t1` with shape `torch.Size([4, 5, 16])`.
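For what it's worth, I suspect `torch.gather` with an index tensor expanded along the last dimension might do this in one shot, though I'm not sure it's the idiomatic route:

```python
import torch

bs, seq, v, sample = 4, 10, 16, 5
t0 = torch.rand((bs, seq, v))
labels = torch.randint(0, seq, size=[bs, sample])

# Expand labels from (bs, sample) to (bs, sample, v) so that gather,
# applied along dim 1, copies the entire length-v vector per index.
idx = labels.unsqueeze(-1).expand(bs, sample, v)

# out[b, i, :] == t0[b, labels[b, i], :]
t1 = torch.gather(t0, 1, idx)  # shape (4, 5, 16)
```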
Is there a more idiomatic way of doing this in PyTorch?