I have a tensor `a` that I would like to first mask using `mask` and then compact, discarding the masked-out elements. To keep the output the same shape as `a`, the kept values in each row should be shifted to the start of that row and the remaining positions at the end filled with a padding value. I can assume there is only a single contiguous run of `True` values in each row of the mask.
For example:
```python
import torch

a = torch.arange(1, 17).reshape(4, 4)
# tensor([[ 1,  2,  3,  4],
#         [ 5,  6,  7,  8],
#         [ 9, 10, 11, 12],
#         [13, 14, 15, 16]])
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])
# desired output (assuming the padding value is 0):
# tensor([[ 2,  3,  0,  0],
#         [ 6,  7,  8,  0],
#         [ 9,  0,  0,  0],
#         [13, 14, 15, 16]])
```
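The single-contiguous-run assumption can be checked for all rows at once. A minimal sketch (the helper name `has_single_true_run` is hypothetical) counts, per row, the positions where a `True` either starts the row or follows a `False`; a row satisfies the assumption when there is at most one such start.

```python
import torch

def has_single_true_run(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for a 2-D bool mask, return one bool per row that is
    True when that row's True values form at most one contiguous run."""
    # Shift each row right by one column, filling the first column with False,
    # so prev[i, j] holds mask[i, j - 1].
    prev = torch.cat([torch.zeros_like(mask[:, :1]), mask[:, :-1]], dim=1)
    # A run starts wherever the current value is True and the previous one is False.
    starts = mask & ~prev
    return starts.sum(dim=1) <= 1
```

An all-`False` row has zero run starts, so it also passes the check.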
I can achieve the desired output by applying `torch.masked_select` followed by `torch.nn.functional.pad` to each row in a loop, but I am struggling to think of a way to do this more efficiently in batches.
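For reference, the per-row loop might look like this. It is a minimal sketch; the function name `compact_rows_loop` and the `pad_value` parameter are hypothetical.

```python
import torch
import torch.nn.functional as F

def compact_rows_loop(a: torch.Tensor, mask: torch.Tensor, pad_value=0) -> torch.Tensor:
    """Hypothetical helper: keep the masked values of each row, move them to the
    front, and right-pad the row back to its original width."""
    rows = []
    for row, row_mask in zip(a, mask):  # iterating a 2-D tensor yields its rows
        kept = torch.masked_select(row, row_mask)
        # Pad only on the right of the last dimension: (left, right) = (0, missing).
        missing = row.numel() - kept.numel()
        rows.append(F.pad(kept, (0, missing), value=pad_value))
    return torch.stack(rows)
```

This runs one `masked_select` and one `pad` per row, which is the per-row Python overhead a batched version would avoid.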
I have also looked into starting with `torch.roll` and zeroing everything after the appropriate index, but `torch.roll` applies the same shift to every row along a dimension, so it cannot roll each row by a different amount.
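A per-row roll can be emulated with `torch.gather` and a per-row shifted column index. The sketch below swaps in that technique (it is not something `torch.roll` itself offers) and relies on the single-contiguous-run assumption; the name `compact_rows_gather` is hypothetical.

```python
import torch

def compact_rows_gather(a: torch.Tensor, mask: torch.Tensor, pad_value=0) -> torch.Tensor:
    """Hypothetical batched helper, assuming each mask row holds at most one
    contiguous run of True values."""
    n_rows, width = a.shape
    m = mask.long()
    # Column of the first True per row: argmax returns the first maximal index.
    start = m.argmax(dim=1, keepdim=True)   # shape (N, 1)
    count = m.sum(dim=1, keepdim=True)      # number of kept values, shape (N, 1)
    cols = torch.arange(width, device=a.device).expand(n_rows, width)
    # Per-row "roll left by start": output column j reads input column (j + start) % width.
    rolled = torch.gather(a, 1, (cols + start) % width)
    # Everything past the first `count` columns of each row becomes padding.
    return rolled.masked_fill(cols >= count, pad_value)
```

A row with no `True` values gets `start = 0` and `count = 0`, so it comes out as padding only.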