
I have an input tensor that has zero padding at the start, followed by a sequence of values. So something like:

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])

What I need is another tensor of the same shape that keeps 0s for the padding and holds an increasing sequence (1, 2, 3, ...) for the remaining positions. So my output tensor should be:

y = torch.tensor([[0, 1, 2, 3],
                  [0, 0, 1, 2]])
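One way to pin down the requested output (my own restatement, not from the question): index columns from 0 and let $s_i$ be the column of the first non-zero entry in row $i$. Then

```latex
y_{ij} = \max\bigl(0,\; j - s_i + 1\bigr)
       = \begin{cases} 0, & j < s_i \\ j - s_i + 1, & j \ge s_i \end{cases}
```

For the example, $s = (1, 2)$, which gives rows $(0, 1, 2, 3)$ and $(0, 0, 1, 2)$. Under this reading, a zero that appears after $s_i$ still gets a position; only the leading zeros count as padding.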

I tried something like:

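import numpy as np
import torch
device = torch.device("cpu")
x = torch.tensor([0, 2, 8, 12])  # 1-D case
MAX_SEQ = 4
start = int(np.nonzero(x.numpy())[0][0])  # index of the first non-zero element
pos_id = torch.cat((torch.from_numpy(np.zeros(start, dtype=np.int64)).to(device),
                    torch.arange(1, MAX_SEQ - start + 1).to(device)), 0)
print(pos_id)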

This works if the tensor is 1-D but needs additional logic to handle the 2-D shape. That could be done since np.nonzero returns a tuple: we could probably loop through it, updating a counter or something. However, I am sure there must be a simple tensor operation that does this in 1-2 lines of code, and probably more efficiently.
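The row-by-row loop alluded to above could look roughly like this. It is a sketch that applies the same 1-D logic to each row; `positions_loop` is a hypothetical helper name, it uses `torch.nonzero` in place of `np.nonzero` so it also works on GPU tensors, and it assumes an all-padding row should stay all zeros.

```python
import torch

def positions_loop(x: torch.Tensor) -> torch.Tensor:
    """Apply the question's 1-D logic to each row of a 2-D tensor (slow, but explicit)."""
    rows, max_seq = x.shape
    out = torch.zeros_like(x)
    for r in range(rows):
        nz = torch.nonzero(x[r]).flatten()  # column indices of the non-zero entries
        if nz.numel() == 0:                 # all-padding row: leave it as zeros (assumption)
            continue
        start = int(nz[0])                  # first non-zero column
        out[r, start:] = torch.arange(1, max_seq - start + 1, device=x.device)
    return out

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])
print(positions_loop(x))
```

It is handy as a reference to check a vectorized version against, but the Python loop makes it slow for large batches.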

Any help is appreciated.

Allohvk

2 Answers


A possible solution in three small steps:

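  1. Find the index of the first non-zero element in each row. This can be done with a trick explained here (adapted for non-binary tensors): multiply a 0/1 mask of the non-zero entries by decreasing weights, so the first non-zero position carries the largest value and argmax returns it.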

    > idx = torch.arange(x.shape[1], 0, -1)
    tensor([4, 3, 2, 1])
    
    > xbin = torch.where(x == 0, 0, 1)
    tensor([[0, 1, 1, 1],
            [0, 0, 1, 1]])
    
    > xbin*idx
    tensor([[0, 3, 2, 1],
            [0, 0, 2, 1]])
    
    > indices = torch.argmax(xbin*idx, dim=1, keepdim=True)
    tensor([[1],
            [2]])
    
  2. Create a 1..cols sequence for every row of the result (ignoring padding for now). This can be done by applying Tensor.repeat and Tensor.view to a torch.arange call:

    > rows, cols = x.shape
    > seq = torch.arange(1, cols+1).repeat(1, rows).view(-1, cols)
    tensor([[1, 2, 3, 4],
            [1, 2, 3, 4]])
    
  3. Lastly (here's the trick!), we subtract each row's first non-zero index from the sequence. The padding positions then hold values <= 0, so we mask them and replace them with zeros:

    > pos_id = seq - indices
    tensor([[ 0,  1,  2,  3],
            [-1,  0,  1,  2]])
    
    > mask = indices > seq - 1
    tensor([[ True, False, False, False],
            [ True,  True, False, False]])
    
    > pos_id[mask] = 0
    tensor([[0, 1, 2, 3],
            [0, 0, 1, 2]])
    
Ivan
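For reuse, the three steps can be wrapped in a function. This is a sketch: `positions_first_nonzero` is a hypothetical name, and `(x != 0).long()` stands in for `torch.where(x == 0, 0, 1)` (it produces the same 0/1 values).

```python
import torch

def positions_first_nonzero(x: torch.Tensor) -> torch.Tensor:
    """The three steps above, for a 2-D (rows, cols) tensor."""
    rows, cols = x.shape
    # Step 1: decreasing weights, so argmax lands on the first non-zero column
    idx = torch.arange(cols, 0, -1, device=x.device)
    xbin = (x != 0).long()
    indices = torch.argmax(xbin * idx, dim=1, keepdim=True)  # shape (rows, 1)
    # Step 2: 1..cols for every row
    seq = torch.arange(1, cols + 1, device=x.device).repeat(1, rows).view(-1, cols)
    # Step 3: shift by the start index, then zero out the leading padding
    pos_id = seq - indices
    pos_id[indices > seq - 1] = 0
    return pos_id

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])
print(positions_first_nonzero(x))
```

One edge case to be aware of: a row that is entirely zeros has no non-zero element, so `argmax` returns index 0 (it reports the first maximal value) and the row comes out numbered `1..cols` rather than all zeros. If such rows can occur, `pos_id[(x == 0).all(dim=1)] = 0` resets them.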

Expanding on Ivan's nice answer to handle a batch dimension, since my model has one. This 'seems' to work. I'm posting it just for reference, in case more than two dimensions need to be handled.

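import torch

# Shape (bs, rows, cols) = (3, 2, 10); leading zeros in each row are padding
x = torch.tensor([[[ 0,  0,  2,  3,  4,  5,  6,  7,  8,  9],
                   [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],

                  [[ 0,  0,  0,  0,  0,  0, 26, 27, 28, 29],
                   [ 0, 31, 32, 33, 34, 35, 36, 37, 38, 39]],

                  [[ 0,  0, 42, 43, 44, 45, 46, 47, 48, 49],
                   [ 0,  0,  0, 53,  0, 55, 56, 57, 58, 59]]])

bs, rows, cols = x.shape
# 1..cols, repeated for every row of every batch element
seq = torch.arange(1, cols+1).repeat(1, rows).repeat(1, bs).view(bs, rows, cols)

# Index of the first non-zero element along the last dim, shape (bs, rows, 1)
idx = torch.arange(x.shape[-1], 0, -1)
xbin = torch.where(x == 0, 0, 1)
indices = torch.argmax(xbin*idx, dim=2, keepdim=True)

# Shift by the start index, then zero out the leading padding
pos_id = seq - indices
mask = indices > seq - 1
pos_id[mask] = 0
print(pos_id)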
Allohvk
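A sketch that generalizes the batch version to any number of leading dimensions by working along `dim=-1`. Two substitutions, both mine and not from the thread: the `1..cols` sequence stays a 1-D tensor and is broadcast against `indices` instead of being built with `repeat`/`view`, and `torch.clamp(..., min=0)` replaces the boolean mask, since `seq - indices` is `<= 0` exactly on the leading padding. `positions_nd` is a hypothetical name.

```python
import torch

def positions_nd(x: torch.Tensor) -> torch.Tensor:
    """Positions along the last dim: 0 for leading zeros, then 1, 2, 3, ..."""
    cols = x.shape[-1]
    # Decreasing weights so argmax picks the first non-zero column
    weights = torch.arange(cols, 0, -1, device=x.device)
    indices = torch.argmax((x != 0).long() * weights, dim=-1, keepdim=True)  # (..., 1)
    # Shape (cols,); broadcasts against (..., 1) without repeat/view
    seq = torch.arange(1, cols + 1, device=x.device)
    # seq - indices is <= 0 exactly where the leading padding is
    return torch.clamp(seq - indices, min=0)

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])
print(positions_nd(x))  # same result as the 2-D answer
```

The same call works unchanged on the `(bs, rows, cols)` tensor from the answer above, or on inputs with more leading dimensions.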
Nice! Glad I could help. You can pull it off by calling `repeat` only once, with `torch.arange(1, cols+1).repeat(1, bs*rows).view(bs, rows, cols)`, but that's a minor detail ;) – Ivan Dec 27 '20 at 14:02
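A quick check (mine, not from the thread) that the two-`repeat` construction in the batch answer and the single-`repeat` one from the comment build the same tensor. `Tensor.expand` is a swapped-in third option that returns a broadcast view of the 1-D sequence without copying any data.

```python
import torch

bs, rows, cols = 3, 2, 10
base = torch.arange(1, cols + 1)

seq_two_repeats = base.repeat(1, rows).repeat(1, bs).view(bs, rows, cols)  # batch answer
seq_one_repeat = base.repeat(1, bs * rows).view(bs, rows, cols)             # comment
seq_expanded = base.expand(bs, rows, cols)                                   # view, no copy

print(torch.equal(seq_two_repeats, seq_one_repeat),
      torch.equal(seq_one_repeat, seq_expanded))
```

Because the expanded view shares memory across rows, it should not be written to in place; in the answers that is not an issue, since `seq - indices` allocates a new tensor before the masked assignment.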