16

What is the most elegant way to index an n-dimensional array with an (n-1)-dimensional array along a given dimension, as in this dummy example?

a = np.random.random_sample((3,4,4))
b = np.random.random_sample((3,4,4))
idx = np.argmax(a, axis=0)

How can I now index a with idx to get the maxima in a, as if I had used a.max(axis=0)? And how can I retrieve the values in b at the positions specified by idx?

I thought about using np.meshgrid, but I think that is overkill. Note that axis can be any valid axis (0, 1, 2) and is not known in advance. Is there an elegant way to do this?

2006pmach

3 Answers

15

Make use of advanced-indexing -

m,n = a.shape[1:]
I,J = np.ogrid[:m,:n]
a_max_values = a[idx, I, J]
b_max_values = b[idx, I, J]
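
To see why this works for axis 0, it can help to look at the shapes involved: `np.ogrid` yields "open" index arrays that broadcast against `idx`. The following is a minimal sketch using the shapes from the question; the seeded generator and the name `b_at_argmax` are assumptions added for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator, an assumption for reproducibility
a = rng.random((3, 4, 4))
b = rng.random((3, 4, 4))
idx = np.argmax(a, axis=0)      # shape (4, 4)

m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]         # I has shape (4, 1), J has shape (1, 4)

# idx, I and J broadcast to a common (4, 4) shape, so element [i, j] of the
# result is a[idx[i, j], i, j].
a_max_values = a[idx, I, J]
b_at_argmax = b[idx, I, J]      # values of b at the argmax positions of a

print(I.shape, J.shape)                             # (4, 1) (1, 4)
print(np.array_equal(a_max_values, a.max(axis=0)))  # True
```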

For the general case:

def argmax_to_max(arr, argmax, axis):
    """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
    new_shape = list(arr.shape)
    del new_shape[axis]

    # list(): newer NumPy versions may return a tuple from ogrid, which has no insert()
    grid = list(np.ogrid[tuple(map(slice, new_shape))])
    grid.insert(axis, argmax)

    return arr[tuple(grid)]

Quite a bit more awkward than such a natural operation should be, unfortunately.
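
As a quick sanity check, the helper can be compared against `arr.max(axis)` for every non-negative axis. This is only a sketch: the helper is repeated so the block runs on its own, the `list()` wrapper is an addition in case `np.ogrid` returns a tuple in your NumPy version, and the test array and the `results` dict are made up for illustration.

```python
import numpy as np

def argmax_to_max(arr, argmax, axis):
    """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
    new_shape = list(arr.shape)
    del new_shape[axis]

    # list() is a defensive addition: insert() needs a list, not a tuple
    grid = list(np.ogrid[tuple(map(slice, new_shape))])
    grid.insert(axis, argmax)

    return arr[tuple(grid)]

rng = np.random.default_rng(1)   # hypothetical test data
arr = rng.random((3, 4, 5))

# Compare against the built-in reduction for each non-negative axis.
results = {
    axis: np.array_equal(argmax_to_max(arr, arr.argmax(axis), axis), arr.max(axis))
    for axis in range(arr.ndim)
}
print(results)
```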

For indexing an n-dim array with an (n-1)-dim array, we can simplify this a bit by building the grid of indices for all axes from the shape of the index array itself, like so -

def all_idx(idx, axis):
    # list(): newer NumPy versions may return a tuple from ogrid, which has no insert()
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)

Hence, use it to index into input arrays -

axis = 0
a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]
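
One limitation worth noting: `list.insert` with a negative position does not place the index array where a negative `axis` would require (for example, `insert(-1, idx)` puts it before the last grid entry instead of at the end). A possible workaround is to normalize the axis first. The sketch below uses a hypothetical name, `all_idx_any_axis`, and assumes `idx` is a NumPy array with one dimension fewer than the array being indexed.

```python
import numpy as np

def all_idx_any_axis(idx, axis):
    """Hypothetical variant of all_idx that also accepts negative axes."""
    ndim = idx.ndim + 1                 # dimensionality of the array to be indexed
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    axis %= ndim                        # e.g. maps -1 to ndim - 1

    # list() so that insert() works even if ogrid returns a tuple
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)

rng = np.random.default_rng(2)          # hypothetical test data
a = rng.random((2, 3, 4))

# Check every valid axis, negative and non-negative.
checks = [
    np.array_equal(a[all_idx_any_axis(a.argmax(axis), axis)], a.max(axis))
    for axis in range(-a.ndim, a.ndim)
]
print(all(checks))                      # True
```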
Divakar
  • As I said, I don't know in advance that `axis` is 0; it can be any other value. This would change the order of `[idx, I, J]`, so this will not work. And it is basically the idea I mentioned with meshgrid... – 2006pmach Sep 07 '17 at 18:49
  • @user2357112 Thanks for the help! Works nicely as a generic solution. – Divakar Sep 07 '17 at 19:08
  • `all_idx` is really elegant. Love your posts, @Divakar! – unutbu Sep 07 '17 at 20:04
  • Why did you choose to make `all_idx` return a tuple? Is there a case where returning `grid` (as a list) would not suffice? – unutbu Sep 07 '17 at 20:05
  • @unutbu Well `all_idx` is mostly based on the edits by user2357112 and that had `tuple`. At my end with Python2.x, it worked fine without the tuple, but I wasn't sure if this would have worked on Python3.x. Need to test that out. Any insights? – Divakar Sep 07 '17 at 20:07
  • Per [the docs](https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing), advanced indexing is triggered by any non-tuple sequence object, so I think letting `grid` be a list should work. – unutbu Sep 07 '17 at 20:14
  • @unutbu Ah, I see now! So, if the input `a` is a 1D array, then with a tuple we would get a scalar, which replicates the `.max()` behavior. But without the tuple we would get an array with one element. So maybe that's one reason to keep it as a tuple, I think. – Divakar Sep 07 '17 at 20:22
  • `all_idx` is nice. I didn't realize using the argmax output's shape instead of the original array's shape would simplify things. As for the tuple instead of a list, the advanced indexing semantics are cleaner with a tuple. The list happens to behave the same in this case due to a (not-quite-correctly-documented) bit of backward compatibility handling that converts a list to a tuple under certain conditions. It can be surprising when a list is treated like a tuple and when a list is treated like an array in NumPy indexing, so I prefer to create the tuple explicitly. – user2357112 Sep 07 '17 at 20:24
  • For reference, the backward compatibility handling I'm referring to is [here](https://github.com/numpy/numpy/blob/7ccf0e08917d27bc0eba34013c1822b00a66ca6d/numpy/core/src/multiarray/mapping.c#L200). The case where `idx` is a scalar is something I didn't even think about, where the backward compatibility handling isn't triggered and the wrong result would occur if we kept a list instead of a tuple. As this demonstrates, the tuple's behavior is more consistent and easier to predict. – user2357112 Sep 07 '17 at 21:52
  • Subtle thing I realized later when trying to adapt this technique to a different problem: `np.ogrid[:5]` is `array([0, 1, 2, 3, 4])`, but `np.ogrid[:5,]`, with the tuple, is `[array([0, 1, 2, 3, 4])]`. Fortunately, things work out the way we need them to for this answer. – user2357112 Oct 31 '17 at 21:54
  • `take_along_axis` has been added to make this easier. And a `put` – hpaulj Jan 17 '19 at 19:47
  • 1) This is amazing. 2) This doesn't work if axis is negative – DilithiumMatrix Feb 25 '19 at 04:50
  • @hpaulj Thanks. `np.squeeze(np.take_along_axis(a, np.expand_dims(idx, axis=axis), axis), axis=axis)` gives the same result as `a[all_idx(idx, axis)]` – bartolo-otrit Aug 29 '20 at 16:21
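
The tuple-versus-list point from the comments can be illustrated on a small array. This is a minimal sketch with a made-up 2D array: a tuple index is read as one index per axis, while a list of integers is read as an index array along the first axis.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

as_tuple = a[(0, 1)]   # same as a[0, 1]: a single element
as_list = a[[0, 1]]    # advanced indexing along axis 0: rows 0 and 1

print(as_tuple)        # 1
print(as_list.shape)   # (2, 4)
```

Building the index explicitly as a tuple, as `all_idx` does, avoids depending on how a particular NumPy version interprets a list.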
-1

I suggest the following:

a = np.array([[1, 3], [2, -2], [1, -1]])
a
>array([[ 1,  3],
       [ 2, -2],
       [ 1, -1]])

idx = a.argmax(axis=1)
idx
> array([1, 0, 0], dtype=int64)
    
np.take_along_axis(a, idx[:, None], axis=1).squeeze()
>array([3, 2, 1])

a.max(axis=1)
>array([3, 2, 1])
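
The same idea can be extended to an arbitrary axis by inserting the length-1 dimension with `np.expand_dims` instead of `idx[:, None]`. The sketch below wraps this in a hypothetical helper, `take_at`, and assumes a NumPy version that provides `np.take_along_axis` (1.15 or later).

```python
import numpy as np

def take_at(arr, idx, axis):
    """Hypothetical helper: values of arr at positions idx along axis."""
    expanded = np.expand_dims(idx, axis=axis)              # restore the reduced axis
    picked = np.take_along_axis(arr, expanded, axis=axis)  # length 1 along axis
    return np.squeeze(picked, axis=axis)                   # drop it again

rng = np.random.default_rng(4)   # hypothetical test data
a = rng.random((3, 4, 5))
b = rng.random((3, 4, 5))

axis = -2                        # negative axes are accepted as well
idx = a.argmax(axis=axis)
a_max_values = take_at(a, idx, axis)
b_at_argmax = take_at(b, idx, axis)

print(np.array_equal(a_max_values, a.max(axis=axis)))  # True
```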
bpfrd
-1

Using advanced indexing in NumPy: https://docs.scipy.org/doc/numpy-1.10.1/reference/arrays.indexing.html#advanced-indexing

a = np.array([[1, 2], [3, 4], [5, 6]])
a
> array([[1, 2],
         [3, 4],
         [5, 6]])
idx = a.argmax(axis=1)
idx
> array([1, 1, 1], dtype=int64)

Since you want all rows but only the columns given by idx, you can use [0, 1, 2] or np.arange(a.shape[0]) for the row indexes:

rows = np.arange(a.shape[0])
a[rows, idx]
> array([2, 4, 6])

which is the same as a.max(axis=1):

a.max(axis=1)
> array([2, 4, 6])

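If you have 3 dimensions (a of shape (p, q, r) and idx = a.argmax(axis=1) of shape (p, r)), you add the indexes of the 3rd dimension as well. The index arrays must broadcast against idx, so the row indexes need to be a column vector:

rows = np.arange(a.shape[0])[:, None]   # shape (p, 1), broadcasts against idx
index2 = np.arange(a.shape[2])          # shape (r,)
a[rows, idx, index2]                    # same as a.max(axis=1)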
bpfrd