The return value of argmax along an axis can't simply be used as an index. That only works in the 1d case.
In [124]: u = np.arange(12).reshape(3,4,1)
In [125]: e = u.argmax(axis=2)
In [126]: u.shape
Out[126]: (3, 4, 1)
In [127]: e.shape
Out[127]: (3, 4)
e is (3,4), but its values only index the last dimension of u. Used alone, as u[e], they index the first dimension instead:
In [128]: u[e].shape
Out[128]: (3, 4, 4, 1)
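The resulting shape follows e.shape + u.shape[1:]. The integer array takes the place of axis 0, and the remaining axes of u come along unchanged. Here is a short check that uses only the u and e from above:

```python
import numpy as np

u = np.arange(12).reshape(3, 4, 1)
e = u.argmax(axis=2)  # shape (3, 4); every entry is 0 because the last axis has size 1

# An integer array in the first index slot selects along axis 0 only;
# the remaining axes of u are carried along unchanged.
wrong = u[e]
print(wrong.shape)                # (3, 4, 4, 1)
print(e.shape + u.shape[1:])      # (3, 4, 4, 1)

# Since e is all 0s, each wrong[i, j] is simply a copy of u[0].
print(np.array_equal(wrong[1, 2], u[0]))  # True
```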
Instead we have to construct indices for the other 2 dimensions, ones which broadcast with e. For example:
In [129]: I,J=np.ix_(range(3),range(4))
In [130]: I
Out[130]:
array([[0],
       [1],
       [2]])
In [131]: J
Out[131]: array([[0, 1, 2, 3]])
Those are (3,1) and (1,4), which broadcast against the (3,4) e and give the desired output:
In [132]: u[I,J,e]
Out[132]:
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
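Because the last dimension of u has size 1, the output above is just u with that axis dropped. That makes it hard to see whether the indexing really picks the argmax. The sketch below uses a hypothetical array v whose last axis has length 5 (seeded random integers, purely illustrative). It checks that v[I, J, e] matches v.max(axis=2):

```python
import numpy as np

rng = np.random.default_rng(0)
v = rng.integers(0, 100, size=(3, 4, 5))  # hypothetical data; last axis has length 5

e = v.argmax(axis=2)                      # (3, 4): position of the max along axis 2
I, J = np.ix_(range(3), range(4))         # open mesh: shapes (3, 1) and (1, 4)

picked = v[I, J, e]                       # I, J and e broadcast together to (3, 4)
print(picked.shape)                       # (3, 4)
print(np.array_equal(picked, v.max(axis=2)))  # True
```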
This kind of question has been asked before, so it should probably be marked as a duplicate. The fact that your last dimension has size 1, and hence e is all 0s, distracts readers from the underlying issue (using a multidimensional argmax as an index).
numpy: how to get a max from an argmax result
Get indices of numpy.argmax elements over an axis
Assuming you've taken the argmax along the last dimension, this works for any number of leading dimensions:
In [156]: ij = np.indices(u.shape[:-1])
In [157]: u[(*ij,e)]
Out[157]:
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
or:
ij = np.ix_(*[range(i) for i in u.shape[:-1]])
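Either form of ij can be dropped into a small helper for arbitrary-rank input. The sketch below wraps the np.ix_ version. The helper name take_argmax_last and the random 4d test array are illustrative assumptions, not part of the original answer:

```python
import numpy as np

def take_argmax_last(a):
    """Hypothetical helper: values of a at a.argmax(axis=-1), i.e. a.max(axis=-1)."""
    e = a.argmax(axis=-1)
    # Open-mesh index arrays for every leading dimension: (n0, 1, ...), (1, n1, ...), ...
    ij = np.ix_(*[range(n) for n in a.shape[:-1]])
    return a[(*ij, e)]

rng = np.random.default_rng(1)
a = rng.random((2, 3, 4, 5))  # illustrative 4d input

res = take_argmax_last(a)
print(res.shape)                            # (2, 3, 4)
print(np.array_equal(res, a.max(axis=-1)))  # True
```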
If the axis is in the middle, it'll take a bit more tuple fiddling to arrange the ij elements and e.
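Here is one way to do that tuple fiddling, sketched as a hypothetical helper take_argmax(a, axis). It builds an open mesh over every axis of a, then swaps in e, with the reduced axis re-inserted as length 1, at position axis. NumPy 1.15 and later also provides np.take_along_axis, which handles this bookkeeping itself. The sketch checks both against a.max(axis):

```python
import numpy as np

def take_argmax(a, axis):
    """Hypothetical helper: values of a at a.argmax(axis), keeping that axis as length 1."""
    e = np.expand_dims(a.argmax(axis=axis), axis)     # e.g. (3, 1, 5) for axis=1
    idx = list(np.ix_(*[range(n) for n in a.shape]))  # open mesh over every axis
    idx[axis] = e                                     # replace the mesh entry for that axis
    return a[tuple(idx)]                              # all entries broadcast to (3, 1, 5)

rng = np.random.default_rng(2)
a = rng.random((3, 4, 5))  # illustrative input; argmax taken along the middle axis

mine = take_argmax(a, axis=1)
builtin = np.take_along_axis(a, np.expand_dims(a.argmax(axis=1), 1), axis=1)

print(mine.shape)                                    # (3, 1, 5)
print(np.array_equal(mine, builtin))                 # True
print(np.array_equal(mine[:, 0, :], a.max(axis=1)))  # True
```

Keeping the reduced axis as length 1 mirrors what np.take_along_axis returns. Index with [:, 0, :] (or squeeze) if the (3, 5) shape is wanted instead.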