
I once saw a code segment that uses torch.reshape as follows:

for name, value in samples.items():
    value_flat = torch.reshape(value, (-1,) + value.shape[2:])

What does (-1,) + value.shape[2:] mean here? Here value is a torch.Tensor.
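For reference, here is a minimal self-contained version of the loop. It assumes samples is a dict mapping names to tensors with at least two dimensions; the keys "image" and "mask" and their shapes are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical input: a dict of tensors whose first two axes are
# (batch, channel). Keys and shapes are made up for illustration.
samples = {
    "image": torch.rand(2, 3, 64, 64),
    "mask": torch.rand(2, 3, 64),
}

flat_samples = {}
for name, value in samples.items():
    # Merge the first two axes into one; keep every remaining axis unchanged.
    value_flat = torch.reshape(value, (-1,) + value.shape[2:])
    flat_samples[name] = value_flat
    print(name, tuple(value.shape), "->", tuple(value_flat.shape))
```

Both tensors end up with a leading axis of size 6 (2 * 3), while their trailing axes are left alone, whatever their number.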

m0nhawk
user785099

1 Answer

  • (-1,) is a one-element tuple containing -1.

  • value.shape[2:] selects the elements of value.shape (the shape of tensor value) from the third one to the last, i.e. every dimension except the first two.
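The two points above can be checked in isolation. This small sketch shows that the trailing comma, not the parentheses, is what makes (-1,) a tuple, and what slicing a torch.Size with [2:] keeps; the variable names are only for illustration.

```python
import torch

# Parentheses alone do not make a tuple; the trailing comma does.
not_a_tuple = (-1)
one_tuple = (-1,)
print(type(not_a_tuple).__name__, type(one_tuple).__name__)  # int tuple

# Slicing a torch.Size with [2:] drops the first two dimensions and keeps the rest.
v = torch.zeros(2, 3, 4, 5, 6)
print(v.shape[2:])  # torch.Size([4, 5, 6])
```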

Put together, the tuple (-1,) is concatenated with the torch.Size object value.shape[2:] to build a new shape tuple. Take an example tensor:

>>> x = torch.rand(2, 3, 64, 64)

>>> x.shape[2:]
torch.Size([64, 64])

>>> (-1,) + x.shape[2:]
(-1, 64, 64)

Passing -1 as one of the sizes to torch.reshape tells it to infer that dimension from the total number of elements and the other sizes. Here that has the effect of merging the first and second axes (e.g. batch and channel) into a single axis.
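A short sketch of how the inferred size is computed: the -1 axis gets whatever size makes the element count match. The example also shows that at most one size may be -1; reshape raises a RuntimeError otherwise (the exact message may vary between PyTorch versions).

```python
import torch

x = torch.rand(2, 3, 64, 64)  # 2 * 3 * 64 * 64 = 24576 elements

# -1 is inferred from the element count: 24576 / (64 * 64) = 6
y = x.reshape(-1, 64, 64)
print(y.shape)  # torch.Size([6, 64, 64])

# The -1 can sit on any axis: 24576 / 2 = 12288
z = x.reshape(2, -1)
print(z.shape)  # torch.Size([2, 12288])

# Only one size can be inferred, so two -1s are rejected.
try:
    x.reshape(-1, -1)
except RuntimeError as err:
    print("error:", err)
```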

In our example, the shape goes from (2, 3, 64, 64) to (6, 64, 64). For a four-dimensional tensor, the operation is equivalent to

value.reshape(value.size(0)*value.size(1), value.size(2), value.size(3))

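but writing it out this way is clumsy, and it only works for tensors with exactly four dimensions, whereas (-1,) + value.shape[2:] works regardless of how many trailing dimensions there are.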
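The equivalence can be verified directly. In this sketch, merge_first_two is a hypothetical helper name wrapping the generic expression. It also compares against torch.flatten with start_dim=0 and end_dim=1, which is a different but equivalent way to merge the first two axes and is not what the original code uses.

```python
import torch

def merge_first_two(value: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: merge axes 0 and 1, keeping all trailing axes."""
    return torch.reshape(value, (-1,) + value.shape[2:])

x = torch.rand(2, 3, 64, 64)

# The explicit four-dimensional spelling from the answer.
explicit = x.reshape(x.size(0) * x.size(1), x.size(2), x.size(3))
generic = merge_first_two(x)
print(torch.equal(explicit, generic))  # True

# torch.flatten over dims 0..1 merges the same two axes.
print(torch.equal(x.flatten(0, 1), generic))  # True

# The generic form also handles other ranks, unlike the explicit one.
print(merge_first_two(torch.rand(2, 3)).shape)           # torch.Size([6])
print(merge_first_two(torch.rand(2, 3, 4, 5, 6)).shape)  # torch.Size([6, 4, 5, 6])
```

For a 2-D tensor, value.shape[2:] is empty, so the shape becomes just (-1,) and the result is fully flattened.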

Ivan