
I have a tensor a that I would like to mask with mask, discarding the frames where the mask is False and shifting the kept values to the start of each row. To keep the output tensor the same shape as a, the remaining positions at the end of each row should be filled with a padding value. I can assume each row of the mask contains a single contiguous run of True values.

e.g.

import torch

a = torch.arange(1, 17).reshape(4, 4)
# tensor([[ 1,  2,  3,  4],
#         [ 5,  6,  7,  8],
#         [ 9, 10, 11, 12],
#         [13, 14, 15, 16]])

mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])

# desired output (assuming padding value is 0):
# tensor([[ 2,  3,  0,  0],
#         [ 6,  7,  8,  0],
#         [ 9,  0,  0,  0],
#         [13, 14, 15, 16]])

I can achieve the desired output by applying torch.masked_select followed by torch.nn.functional.pad to each row in a loop, but I am struggling to find a way to do this efficiently in batches.
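For reference, here is a minimal sketch of that per-row loop, which is handy as a baseline to check a batched version against. The helper name compact_rows_loop and its pad argument are hypothetical, and the sketch assumes a is 2-D with one mask row per row of a.

```python
import torch
import torch.nn.functional as F

def compact_rows_loop(a, mask, pad=0):
    """Hypothetical per-row baseline: keep the masked values, then right-pad to the original width."""
    rows = []
    for row, row_mask in zip(a, mask):
        kept = torch.masked_select(row, row_mask)  # 1-D tensor of the selected values
        # Pad only on the right so the row is back to width a.size(1).
        rows.append(F.pad(kept, (0, a.size(1) - kept.numel()), value=pad))
    return torch.stack(rows)

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])
print(compact_rows_loop(a, mask))
```

This works, but it launches one masked_select and one pad per row, which is what makes it slow for large batches.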

I have also considered using torch.roll and then zeroing out the values past the appropriate index, but torch.roll shifts an entire dimension by a single amount and cannot apply a different shift to each row.
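Although torch.roll takes a single shift per dimension, a per-row roll can be emulated with torch.gather by building a shifted column index for each row. The sketch below is one possible approach, not something from the question itself; compact_rows_roll is a hypothetical name, and it relies on the stated assumption that each row has a single contiguous run of True values (the start of the run is found with argmax, which returns the first maximal index).

```python
import torch

def compact_rows_roll(a, mask, pad=0):
    """Hypothetical sketch: emulate a per-row torch.roll with gather.

    Assumes each row of `mask` has at most one contiguous run of True values.
    """
    n = a.size(1)
    cols = torch.arange(n, device=a.device)
    start = mask.int().argmax(dim=1)  # index of the first True in each row (0 if none)
    length = mask.sum(dim=1)          # number of True values per row
    # Shift row i left by start[i]: output column j reads input column (j + start[i]) % n.
    idx = (cols.unsqueeze(0) + start.unsqueeze(1)) % n
    rolled = a.gather(1, idx)
    # Only the first `length` columns of each rolled row hold selected values.
    keep = cols.unsqueeze(0) < length.unsqueeze(1)
    return torch.where(keep, rolled, torch.full_like(rolled, pad))

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])
print(compact_rows_roll(a, mask))
```

Unlike a sort-based approach, this gives wrong results if a row contains more than one run of True values.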

Arya McCarthy
Seraf Fej

1 Answer


By applying torch.sort to the mask itself you can achieve the desired result. If you sort the boolean values in descending order, the True values move to the beginning of each row and the False values move to the end.

Note that the order in which equal elements come out depends on the sorting algorithm, so an unstable sort could shuffle the True entries. As @Seraf Fej pointed out, you can pass stable=True to torch.sort so that the original order of equal elements is preserved.

Then use the indices returned by the sort to gather the corresponding values of a with torch.gather. Finally, mask the resulting matrix to replace the discarded values with the padding value.
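The three steps can be wrapped in a single batched function. This is a minimal sketch; compact_rows_sort and its pad argument are hypothetical names, and it assumes a PyTorch version whose torch.sort accepts stable=True. Because the stable sort keeps every True entry in its original order, the function does not actually need the single-run assumption: it behaves like a per-row masked_select followed by padding.

```python
import torch

def compact_rows_sort(a, mask, pad=0):
    """Hypothetical helper: move the selected values of each row to the front, pad the rest."""
    # A stable descending sort of the boolean mask puts True before False
    # and keeps the original column order within each group.
    values, indices = torch.sort(mask, dim=1, descending=True, stable=True)
    # Reorder every row of `a` with the same permutation.
    gathered = a.gather(1, indices)
    # Keep gathered values where the sorted mask is True, otherwise use `pad`.
    return torch.where(values, gathered, torch.full_like(gathered, pad))

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])
print(compact_rows_sort(a, mask))          # pad with 0
print(compact_rows_sort(a, mask, pad=-1))  # pad with -1
```

torch.full_like is used for the padding so that it has the same dtype and device as a.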

>>> a
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

>>> mask
tensor([[False,  True,  True, False],
        [False,  True,  True,  True],
        [ True, False, False, False],
        [ True,  True,  True,  True]])

Sort the mask:

>>> values, indices = mask.sort(1, descending=True, stable=True)

>>> values
tensor([[ True,  True, False, False],
        [ True,  True,  True, False],
        [ True, False, False, False],
        [ True,  True,  True,  True]])

>>> indices
tensor([[1, 2, 0, 3],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [0, 1, 2, 3]])

Gather with indices, then mask with values:

>>> a.gather(1, indices)*values
tensor([[ 2,  3,  0,  0],
        [ 6,  7,  8,  0],
        [ 9,  0,  0,  0],
        [13, 14, 15, 16]])

You can easily extend to any padding value using torch.where:

>>> torch.where(values, a.gather(1, indices), -1)
tensor([[ 2,  3, -1, -1],
        [ 6,  7,  8, -1],
        [ 9, -1, -1, -1],
        [13, 14, 15, 16]])

Or using the inverse mask ~values, weighted by the padding value:

>>> a.gather(1, indices) * values - 1 * ~values
tensor([[ 2,  3, -1, -1],
        [ 6,  7,  8, -1],
        [ 9, -1, -1, -1],
        [13, 14, 15, 16]])
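One caveat on the arithmetic form: it multiplies the padding value by the mask, so it only works for finite padding values. With a floating-point tensor and a padding value such as -inf, pad * 0 evaluates to nan and corrupts the kept entries, whereas torch.where selects values without multiplying. A small sketch with made-up values:

```python
import torch

# Made-up example: one already-gathered row whose first two entries are kept.
gathered = torch.tensor([[2.0, 3.0, 1.0, 4.0]])
values = torch.tensor([[True, True, False, False]])
pad = float("-inf")

# Arithmetic masking: -inf * 0 is nan, so the kept entries become nan.
arith = gathered * values + pad * ~values
# torch.where selects elementwise instead of multiplying, so any pad value is safe.
selected = torch.where(values, gathered, torch.full_like(gathered, pad))

print(arith)
print(selected)
```

For integer tensors and finite padding values both forms agree, so torch.where is mainly the safer default.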
Ivan
    Thanks, clever solution! I notice `torch.sort` takes the argument `stable=True` which preserves the order of equivalent elements. So I think using `values, indices = torch.sort(mask, dim=1, descending=True, stable=True)` should account for the potential shuffling issue you mentioned. – Seraf Fej Aug 02 '21 at 13:08