
Given a 2-dimensional tensor in numpy (or in pytorch), I can partially slice along all dimensions at once as follows:

>>> import numpy as np
>>> a = np.arange(3*4).reshape(3,4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1:,1:]
array([[ 5,  6,  7],
       [ 9, 10, 11]])

How can I apply the same slicing pattern to a tensor whose number of dimensions is not known when I write the code? That is, I want a[1:] if a has one dimension, a[1:,1:] if it has two, a[1:,1:,1:] if it has three, and so on.

It would be nice if I could do it in a single line of code like the following, but this is invalid:

a[(1:,) * len(a.shape)]  # SyntaxError: invalid syntax

I am specifically interested in a solution that works for pytorch tensors (just substitute torch for numpy above and the example is the same), but ideally the solution would work for both numpy and pytorch.

teichert

Answer: Making a tuple of slice objects does the trick:

a[(slice(1,None),) * len(a.shape)]
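
A quick check that the one-liner adapts to the number of dimensions, sketched with numpy only. The helper name drop_first_along_all_dims is hypothetical; it just wraps the expression above:

```python
import numpy as np

def drop_first_along_all_dims(a):
    # Hypothetical helper name: builds (slice(1, None),) * ndim,
    # which is the tuple form of a[1:, 1:, ..., 1:]
    return a[(slice(1, None),) * len(a.shape)]

a1 = np.arange(4)                    # shape (4,)
a2 = np.arange(12).reshape(3, 4)     # shape (3, 4)
a3 = np.arange(24).reshape(2, 3, 4)  # shape (2, 3, 4)

print(drop_first_along_all_dims(a1))                              # [1 2 3]
print(np.array_equal(drop_first_along_all_dims(a2), a2[1:, 1:]))  # True
print(drop_first_along_all_dims(a3).shape)                        # (1, 2, 3)
```

Since the helper relies only on indexing with a tuple and on .shape, the same code should apply unchanged to a torch tensor such as torch.arange(12).reshape(3, 4), although this sketch only exercises numpy.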

Explanation: slice is a built-in Python class (not tied to numpy or pytorch) that provides an alternative to the colon notation for describing slices. An answer to a different question suggests using it to store slice information in Python variables. The Python glossary points out that

The bracket (subscript) notation uses slice objects internally.
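
To illustrate storing slices in variables, here is a small pure-Python sketch (no numpy or pytorch involved): one slice object is kept in a variable and reused wherever the colon notation would go. The variable names are arbitrary:

```python
# A slice object equivalent to the subscript "1:" (start=1, stop=None, step=None)
from_second = slice(1, None)

letters = "abcdef"
numbers = [10, 20, 30, 40]

print(letters[from_second])  # bcdef  (same as letters[1:])
print(numbers[from_second])  # [20, 30, 40]  (same as numbers[1:])

# The attributes mirror the start:stop:step parts of the colon notation
print(from_second.start, from_second.stop, from_second.step)  # 1 None None

# slice.indices() resolves None and negative values against a concrete length
print(slice(1, None).indices(len(numbers)))   # (1, 4, 1)
print(slice(-2, None).indices(len(numbers)))  # (2, 4, 1)
```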

The __getitem__ methods of numpy ndarrays and pytorch tensors support multi-dimensional indexing with slices, and a subscript such as a[1:,1:] reaches __getitem__ as a tuple of slice objects. They must therefore also accept such a tuple when it is passed explicitly, so we can build a tuple of slice objects of the right length and index with it.

Btw, you can see how Python uses slice objects by creating a dummy class as follows and then slicing it:

class A(object):
    def __getitem__(self, ix):
        return ix

print(A()[5])  # 5
print(A()[1:])  # slice(1, None, None)
print(A()[1:,1:])  # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)])  #  (slice(1, None, None), slice(1, None, None))
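
The same dummy class can also confirm that the tuple built in the answer is exactly what the subscript syntax produces, for one to three dimensions. A minimal sketch in plain Python; note the trailing comma in A()[1:,], which is what makes Python pass a 1-tuple instead of a bare slice:

```python
class A(object):
    def __getitem__(self, ix):
        # Return whatever Python passes as the index, unchanged
        return ix

# Subscripts written out by hand for 1, 2 and 3 dimensions
written_out = [A()[1:,], A()[1:, 1:], A()[1:, 1:, 1:]]

for ndim, expected in enumerate(written_out, start=1):
    built = (slice(1, None),) * ndim  # the tuple from the answer
    print(ndim, built == expected, A()[built] == expected)
# 1 True True
# 2 True True
# 3 True True

# Without the trailing comma, a single "1:" arrives as a bare slice object
print(A()[1:])  # slice(1, None, None)
```

For a one-dimensional numpy array, a[1:,] and a[1:] select the same elements, so the 1-tuple produced when len(a.shape) == 1 is harmless.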


teichert