The Approach
We are going to take advantage of the Numpy community and libraries, as well as the fact that Pytorch tensors and Numpy arrays can be converted to/from one another without copying or moving the underlying arrays in memory (so conversions are low cost). From the Pytorch documentation:
Converting a torch Tensor to a Numpy array and vice versa is a breeze. The torch Tensor and Numpy array will share their underlying memory locations, and changing one will change the other.
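For example (a quick illustration of that shared-memory behaviour, not taken from the documentation):
import numpy as np
import torch

t = torch.zeros(3)
a = t.numpy()                        # no copy: a views the tensor's memory

a[0] = 5.0                           # writing through the Numpy array...
print(t)                             # ...shows up in the tensor: tensor([5., 0., 0.])

t2 = torch.from_numpy(np.ones(3))    # the reverse conversion shares memory as well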
Solution One
We are first going to use the Numba library to write a function that will be just-in-time (JIT) compiled upon its first usage, meaning we can get C speeds without having to write C code ourselves. Of course, there are caveats to what can be JIT-ed, and one of those caveats is that the function must be written in terms of Numpy operations and plain Python constructs. But this isn't too bad because, remember, converting from our torch tensor to Numpy is low cost. The function we create is:
import numpy as np
from numba import njit

@njit(cache=True)
def indexFunc(array, item):
    # Return the index of the first element equal to item, scanning in C order
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
This function is from another Stackoverflow answer located here (this was the answer which introduced me to Numba). The function takes an N-dimensional Numpy array and looks for the first occurrence of a given item, immediately returning the index of the found item on a successful match. The @njit decorator is short for @jit(nopython=True) and tells the compiler that we want it to compile the function using no Python objects, and to throw an error if it is not able to do so (Numba is fastest when no Python objects are used, and speed is what we are after).
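As a quick illustration of what the function returns (a toy check, not part of the solutions below), the index of the first match comes back as a tuple:
arr = np.array([[3., 7., 1.],
                [9., 7., 2.]], dtype=np.float32)

# The first call triggers JIT compilation; later calls reuse the cached machine code.
print(indexFunc(arr, 7.0))   # (0, 1): index of the first occurrence of 7.0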
With this speedy function backing us, we can get the indices of the max values in a tensor as follows:
import numpy as np
import torch

x = x.numpy()
maxVals = np.amax(x, axis=(2,3))
n, p = x.shape[0], x.shape[1]   # batch size and number of feature maps
max_indices = np.zeros((n,p,2),dtype=np.int64)
for index in np.ndindex(x.shape[0],x.shape[1]):
    max_indices[index] = np.asarray(indexFunc(x[index], maxVals[index]),dtype=np.int64)
max_indices = torch.from_numpy(max_indices)
We use np.amax because it can accept a tuple for its axis argument, allowing it to return the max values of each 2D feature map in the 4D input. We initialize max_indices with np.zeros ahead of time because appending to Numpy arrays is expensive, so we allocate the space we need up front. This approach is much faster than the Typical Solution in the question (by an order of magnitude), but it still uses a Python for loop outside the JIT-ed function, so we can improve...
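To see what the axis=(2,3) reduction does, here is a tiny illustrative example (shapes chosen arbitrarily, not from the benchmark):
import numpy as np

x_demo = np.random.rand(2, 3, 4, 5).astype(np.float32)   # [batch, feature maps, height, width]
maxVals_demo = np.amax(x_demo, axis=(2, 3))               # one max per 2D feature map

print(maxVals_demo.shape)                        # (2, 3)
print(maxVals_demo[0, 1] == x_demo[0, 1].max())  # True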
Solution Two
We will use the following solution:
from numba import njit, prange

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc2(x,maxVals):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    # prange parallelizes these loops across the available threads
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

x = x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = torch.from_numpy(indexFunc2(x,maxVals))
Instead of iterating through our feature maps one at a time with a for loop, we can take advantage of parallelization using Numba's prange function (which behaves exactly like range but tells the compiler we want the loop to be parallelized) and the parallel=True decorator argument. Numba also parallelizes the np.zeros call. Because our function is compiled just-in-time and uses no Python objects, Numba can take advantage of all the threads available on our system! It is worth noting that there is now a raise RuntimeError in indexFunc. We need to include this; otherwise the Numba compiler will try to infer the return type of the function and conclude that it is either an array or None. That doesn't square with how indexFunc is used in indexFunc2, so the compiler would throw an error. Of course, from our setup we know that indexFunc will always return an array, so we can simply raise an error in the other logical branch.
This approach is functionally identical to Solution One, but it changes the iteration using np.ndindex into two for loops using prange. This approach is about 4x faster than Solution One.
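If prange is unfamiliar, here is a minimal, standalone sketch of how it is used (a toy function, unrelated to the feature-map code above):
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def row_sums(a):
    # Each iteration of the outer loop may run on a different thread.
    out = np.zeros(a.shape[0], dtype=np.float64)
    for i in prange(a.shape[0]):
        out[i] = a[i].sum()
    return out

print(row_sums(np.ones((4, 5))))   # [5. 5. 5. 5.]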
Solution Three
Solution Two is fast, but it still computes the max values with np.amax outside of the JIT-ed code. Can we speed this up using a more comprehensive JIT-ed function?
@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc3(x):
    maxVals = np.zeros((x.shape[0],x.shape[1]),dtype=np.float32)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxVals[i][j] = np.max(x[i][j])
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

max_indices = torch.from_numpy(indexFunc3(x))
It might look like there is a lot more going on in this solution, but the only change is that instead of calculating the maximum values of each feature map using np.amax, we have now parallelized the operation. This approach is marginally faster than Solution Two.
Solution Four
This solution is the best I've been able to come up with:
@njit(cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            # Unravel the flattened index; x.shape[3] is the width (number of columns) of each feature map
            max_indices[i][j] = [maxTemp // x.shape[3], maxTemp % x.shape[3]]
    return max_indices

max_indices = torch.from_numpy(indexFunc4(x))
This approach is more condensed and also the fastest, at 33% faster than Solution Three and 50x faster than the Typical Solution. We use np.argmax to get the index of the max value of each feature map, but np.argmax only returns the index as if each feature map were flattened. That is, we get a single integer telling us which number the element is in our feature map, not the indices we need to be able to access that element. The math [maxTemp // x.shape[3], maxTemp % x.shape[3]] turns that single int into the [row, column] we need (x.shape[3] is the width of each feature map, i.e. the number of elements per row).
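As a quick sanity check of that conversion (an illustrative snippet, not part of the benchmarked code), the integer-division/modulo pair matches what np.unravel_index produces:
import numpy as np

fm = np.random.rand(64, 64).astype(np.float32)   # a single feature map of shape (height, width)
flat = np.argmax(fm)                             # index into the flattened feature map

row, col = flat // fm.shape[1], flat % fm.shape[1]
assert (row, col) == np.unravel_index(flat, fm.shape)
assert fm[row, col] == fm.max()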
Benchmarking
All approaches were benchmarked together against a random input of shape [32,d,64,64]
, where d was incremented from 5 to 245. For each d, 15 samples were gathered and the times were averaged. An equality test ensured that all solutions provided identical values. An example of the benchmark output is:
[Image: example benchmark output]
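The exact benchmark harness isn't shown here, but a minimal sketch of a loop producing this kind of output could look like the following (the step size for d and the use of time.perf_counter are my assumptions; indexFunc4 stands in for whichever solution is being timed):
import time
import numpy as np
import torch

for d in range(5, 250, 10):                  # d swept from 5 to 245 (step size assumed)
    x = torch.rand(32, d, 64, 64)            # random float32 input of shape [32, d, 64, 64]
    samples = []
    for _ in range(15):                      # 15 samples per value of d
        xn = x.numpy()
        start = time.perf_counter()
        indexFunc4(xn)                       # swap in whichever solution is being timed
        samples.append(time.perf_counter() - start)
    print(d, np.mean(samples))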
A plot of the benchmarking times as d increases (leaving out the Typical Solution so the graph isn't squashed):
[Image: plot of benchmark times as d increases]
Woah! What is going on at the start with those spikes?
Solution Five
Numba allows us to produce just-in-time compiled functions, but it doesn't compile them until the first time we use them; it then caches the result for when we call the function again. This means the very first time we call our JIT-ed functions we get a spike in compute time as the function is compiled. Luckily, there is a way around this: if we specify ahead of time what our function's return type and argument types will be, the function will be eagerly compiled instead of compiled just-in-time. Applying this knowledge to Solution Four we get:
@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[3], maxTemp % x.shape[3]]
    return max_indices

max_indices6 = torch.from_numpy(indexFunc4(x))
And if we restart our kernel and rerun our benchmark, we can compare the first result, where d==5, with the second result, where d==10, and note that all of the JIT-ed solutions were slower when d==5 because they had to be compiled, except for Solution Five, since we explicitly provided the function signature ahead of time:
[Image: benchmark output for d==5 and d==10]
There we go! That's the best solution I have so far for this problem.
EDIT #1
Solution Six
An improved solution has been developed which is 33% faster than the previously posted best solution. This solution only works if the input array is C-contiguous, but this isn't a big restriction since numpy arrays and torch tensors will be contiguous unless they have been transposed, permuted, or sliced, and both libraries have functions to make the array/tensor contiguous if needed, as shown below.
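For example, checking and restoring contiguity on both sides looks like this (a small illustrative snippet using standard Pytorch/Numpy calls):
import numpy as np
import torch

t = torch.rand(32, 8, 64, 64).permute(0, 1, 3, 2)   # permuting the axes makes the tensor non-contiguous
print(t.is_contiguous())                             # False

t = t.contiguous()                                   # Pytorch: materialize a contiguous copy
a = np.ascontiguousarray(t.numpy())                  # Numpy: ensure C-contiguity
print(a.flags['C_CONTIGUOUS'])                       # True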
This solution is the same as the previous best, but the function decorator which specifies the input and return types is changed from
@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)
to
@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)
The only difference is that the last : in each array typing becomes ::1, which signals to the numba njit compiler that the input arrays are C-contiguous, allowing it to better optimize.
The full solution six is then:
@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)
def indexFunc5(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[3], maxTemp % x.shape[3]]
    return max_indices

max_indices7 = torch.from_numpy(indexFunc5(x))
The benchmark including this new solution confirms the speedup:
[Image: benchmark output including Solution Six]