
Suppose I have a 3D array:

>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])

I can get its argmax along axis=1 using

>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])

How can I use m as an index into a, so that the results are equivalent to simply using max?

>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])

(This is useful when m is applied to other arrays of the same shape)

MWB

2 Answers


You can do this with advanced indexing and numpy broadcasting:

m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]

#array([[7, 6],
#       [5, 4]])

Step by step, first compute the argmax:

m = np.argmax(a, axis=1)

Then build index arrays for the first, second and third dimensions:

ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])

The three arrays have different shapes, (2, 1), (2, 2) and (2,), so they broadcast to a common shape of (2, 2):

for x in np.broadcast_arrays(ind1, ind2, ind3):
    print(x, '\n')

#[[0 0]
# [1 1]] 

#[[0 1]
# [1 0]] 

#[[0 1]
# [0 1]] 

Because all three indices are integer arrays, this triggers advanced indexing: the elements at (0, 0, 0), (0, 1, 1), (1, 1, 0) and (1, 0, 1) are picked, each index tuple being formed by taking the element at the same position from each of the broadcast arrays.
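Since the question mentions applying m to other arrays of the same shape, here is a minimal sketch of reusing the same index triple on a second array. The array b and the names max_of_a and picked_from_b are made up for illustration and are not part of the question.

```python
import numpy as np

a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]])
# b is a hypothetical second array with the same shape as a
b = np.arange(8).reshape(2, 2, 2) * 10

m = np.argmax(a, axis=1)
ind1 = np.arange(a.shape[0])[:, None]  # shape (2, 1), indexes axis 0
ind3 = np.arange(a.shape[2])           # shape (2,), indexes axis 2

# The same index triple works on any array shaped like a
max_of_a = a[ind1, m, ind3]
picked_from_b = b[ind1, m, ind3]

print(max_of_a)       # same values as a.max(axis=1)
print(picked_from_b)  # elements of b at the positions where a is maximal
```

The index arrays only depend on the shape of a and on m, so they can be built once and reused for every array of that shape.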

Psidom

You can use np.ogrid to create an open grid over every axis of your array except the one you reduced. Then insert the argmax result at the position of that axis and index your array with the result:

>>> import numpy as np
>>> a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]])
>>> axis = 1

>>> # Create the grid
>>> idx = list(np.ogrid[[slice(a.shape[ax]) for ax in range(a.ndim) if ax != axis]])
>>> argmaxes = np.argmax(a, axis=axis)
>>> idx.insert(axis, argmaxes)

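>>> # Index the original array with the grid; convert to a tuple first,
>>> # because recent NumPy versions no longer treat a list as a tuple of indices
>>> a[tuple(idx)]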
array([[7, 6],
       [5, 4]])

Note that this doesn't work for axis=None, or if you reduced over multiple axes.
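As a different technique from the grid approach above, np.take_along_axis (available since NumPy 1.15) can do the same selection for any single axis. The sketch below is only one possible way to wrap it; the helper name pick_along_axis is hypothetical.

```python
import numpy as np

def pick_along_axis(arr, indices, axis):
    """Select the elements of arr at `indices` along `axis`, dropping that axis.

    `indices` is expected to have the shape of arr with `axis` removed,
    which is what np.argmax(arr, axis=axis) returns.
    """
    # take_along_axis needs indices with the same number of dimensions as arr
    expanded = np.expand_dims(indices, axis=axis)
    picked = np.take_along_axis(arr, expanded, axis=axis)
    # Remove the length-1 axis that was added back above
    return picked.squeeze(axis=axis)

a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]])
m = np.argmax(a, axis=1)
result = pick_along_axis(a, m, axis=1)
print(result)  # same values as a.max(axis=1)
```

Like the grid approach, this assumes a reduction over exactly one axis; it does not cover axis=None.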

MSeifert