Suppose I have a 3D array:
>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])
I can get its argmax along axis=1 using
>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])
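To pin down what m holds: m[i, k] is the position j along axis 1 at which a[i, j, k] is largest. The plain-loop sketch below (written only for illustration, not as an efficient approach) reads the maxima back out element by element using that definition. It assumes nothing beyond the array from above.

```python
import numpy as np

# The example array from the question.
a = np.array([[[7, 0],
               [3, 6]],
              [[2, 4],
               [5, 1]]])
m = np.argmax(a, axis=1)  # shape (2, 2)

# m[i, k] is the index j along axis 1 where a[i, j, k] is largest,
# so indexing a[i, m[i, k], k] for every (i, k) recovers the maxima.
out = np.empty(m.shape, dtype=a.dtype)
for i in range(a.shape[0]):
    for k in range(a.shape[2]):
        out[i, k] = a[i, m[i, k], k]

print(out)
print(np.array_equal(out, a.max(axis=1)))
```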
How can I use m as an index into a, so that the results are equivalent to simply using max?
>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])
(This is useful when m is applied to other arrays of the same shape.)
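A vectorized version of that indexing can be sketched in two ways. This is a minimal sketch, assuming NumPy 1.15 or later for `np.take_along_axis`. The second option instead pairs m with open index grids from `np.ogrid` for axes 0 and 2. The array `b` is a hypothetical second array of the same shape, used only to show that the same index can be reused.

```python
import numpy as np

# The example array from the question.
a = np.array([[[7, 0],
               [3, 6]],
              [[2, 4],
               [5, 1]]])
m = np.argmax(a, axis=1)  # shape (2, 2)

# Option 1 (NumPy >= 1.15): take_along_axis needs an index array with the
# same ndim as `a`, so re-insert axis 1 with length 1, then squeeze it out.
picked = np.take_along_axis(a, np.expand_dims(m, axis=1), axis=1).squeeze(axis=1)

# Option 2: advanced indexing. ogrid yields index arrays of shapes (2, 1)
# and (1, 2), which broadcast against m (2, 2) to select a[i, m[i, k], k].
i, k = np.ogrid[:a.shape[0], :a.shape[2]]
picked2 = a[i, m, k]

# The same index arrays apply to any other array with a's shape.
b = a * 10  # hypothetical second array, for illustration only
from_b = b[i, m, k]

print(np.array_equal(picked, a.max(axis=1)))
print(np.array_equal(picked2, a.max(axis=1)))
print(from_b)
```

The `ogrid` form keeps the index reusable as a tuple `(i, m, k)`, which is convenient when the same selection is applied to several arrays.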