I want to index into the last axis of a tensor of arbitrary shape. The only fixed dimension is the last one, which always has size 2.

For example, let x have shape (1,2,2). I index into the last axis with:

x_0 = x[:, :, 0]    # x_0, x_1 shapes are (1,2)
x_1 = x[:, :, 1]

Likewise, let x have shape (1,2,3,4,2). I index into the last axis with:

x_0 = x[:, :, :, :, 0]   # x_0, x_1 shapes are (1,2,3,4)
x_1 = x[:, :, :, :, 1]

I haven't been able to find a TensorFlow function or slicing syntax that handles a tensor of arbitrary shape.

I need a general way to index that always reaches the last axis, whatever the shape of the tensor.

kt-kbr

1 Answer

TensorFlow's slicing syntax is very similar to NumPy's. In this case you can use an ellipsis (...):

From the NumPy indexing documentation: "Ellipsis expands to the number of : objects needed for the selection tuple to index all dimensions. In most cases, this means that the length of the expanded selection tuple is x.ndim. There may only be a single ellipsis present."
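To make that expansion concrete, here is a small NumPy sketch. The array and its values are illustrative only (not from the question). It compares x[..., 0] with the explicit x[:, :, :, :, 0] and with the same index tuple built programmatically from x.ndim, which is what the ellipsis stands in for.

```python
import numpy as np

# Illustrative array with the question's second shape: (1, 2, 3, 4, 2)
x = np.arange(1 * 2 * 3 * 4 * 2).reshape(1, 2, 3, 4, 2)

# "..." stands in for as many ":" as needed to cover the leading axes.
with_ellipsis = x[..., 0]

# Explicit form: one ":" per leading axis, written out by hand.
explicit = x[:, :, :, :, 0]

# Same selection built from x.ndim:
# (slice(None), slice(None), slice(None), slice(None), 0)
built = x[(slice(None),) * (x.ndim - 1) + (0,)]

print(with_ellipsis.shape)                      # (1, 2, 3, 4)
print(np.array_equal(with_ellipsis, explicit))  # True
print(np.array_equal(with_ellipsis, built))     # True
```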

In your case,

x_0 = x[..., 0]

selects index 0 along the last axis of a tensor of any shape, and x[..., 1] selects index 1.
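A minimal sketch, assuming TensorFlow 2 with eager execution. split_last_axis is a hypothetical helper name, not a TensorFlow API. It applies the same x[..., 0] / x[..., 1] indexing to both shapes from the question. tf.unstack(x, axis=-1) is shown as an alternative that returns the slices along the last axis as a list; it needs the size of that axis to be known statically, which holds here because it is always 2.

```python
import tensorflow as tf


def split_last_axis(x):
    """Hypothetical helper: return the two slices along the last axis.

    Assumes the last dimension of x has size 2; the leading dimensions
    can be anything, because "..." expands to as many ":" as needed.
    """
    x_0 = x[..., 0]  # rank 3: same as x[:, :, 0]; rank 5: x[:, :, :, :, 0]
    x_1 = x[..., 1]
    return x_0, x_1


for shape in [(1, 2, 2), (1, 2, 3, 4, 2)]:
    # Fill with distinct integers so the two slices are distinguishable.
    x = tf.reshape(tf.range(tf.reduce_prod(shape)), shape)
    x_0, x_1 = split_last_axis(x)
    print(shape, "->", x_0.shape, x_1.shape)

    # Alternative: tf.unstack splits along the given axis into a list.
    a, b = tf.unstack(x, axis=-1)
    print(bool(tf.reduce_all(tf.equal(a, x_0))),
          bool(tf.reduce_all(tf.equal(b, x_1))))
```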

See also the answers to the question "What is the difference between the slice (:) and the ellipsis (…) operators in numpy?"

Lescurel