I have a complex-valued array of about 750000 elements for which I repeatedly (say 10^6 or more times) update 1000 (or fewer) different elements. In the absolute-squared array I then need to find the index of the maximum. This is part of a larger code which takes ~700 seconds to run; of these, typically ~75% (~550 s) are spent on finding the index of the maximum. Even though ndarray.argmax() is "blazingly fast" according to https://stackoverflow.com/a/26820109/5269892, running it repeatedly on an array of 750000 elements (even though only 1000 elements have been changed) just takes too much time.
Below is a minimal, complete example, in which I use random numbers and indices. You may not make assumptions about how the real-valued array 'b' changes after an update (i.e. the values may be larger, smaller or equal), except, if you must, that the value at the index of the previous maximum ('b[imax]') will likely be smaller after an update.
I tried using sorted arrays, into which only the updated values are inserted (in sorted order) at the correct places to maintain sorting, because then we know the maximum is always at index -1 and do not have to recompute it. The minimal example below includes timings. Unfortunately, selecting the non-updated values and inserting the updated values takes too much time (all other steps combined would require only ~210 µs instead of the ~580 µs of ndarray.argmax()).
Context: This is part of an implementation of the deconvolution algorithm CLEAN (Högbom, 1974) in the efficient Clark (1980) variant. Since I intend to implement the Sequence CLEAN algorithm (Bose et al., 2002), which requires even more iterations, and may want to use larger input arrays, my question is:
Question: What is the fastest way to find the index of the maximum value in the updated array (without applying ndarray.argmax() to the whole array in each iteration)?
Minimal example code (run on python 3.7.6, numpy 1.21.2, scipy 1.6.0):
import numpy as np
# some array shapes ('nnu_use' and 'nm'), number of total values ('nvals'), number of selected values ('nsel'; here
# 'nsel' == 'nvals'; in general 'nsel' <= 'nvals') and number of values to be changed ('nchange')
nnu_use, nm = 10418//2 + 1, 144
nvals = nnu_use * nm
nsel = nvals
nchange = 1000
# fix random seed, generate random 2D 'Fourier transform' ('a', complex-valued), compute power ('b', real-valued), and
# two 2D arrays for indices of axes 0 and 1
np.random.seed(100)
a = np.random.rand(nsel) + 1j * np.random.rand(nsel)
b = a.real ** 2 + a.imag ** 2
inu_2d = np.tile(np.arange(nnu_use)[:,None], (1,nm))
im_2d = np.tile(np.arange(nm)[None,:], (nnu_use,1))
# select 'nsel' random indices and get 1D arrays of the selected 2D indices
isel = np.random.choice(nvals, nsel, replace=False)
inu_sel, im_sel = inu_2d.flatten()[isel], im_2d.flatten()[isel]
def do_update_iter(a, b):
    # find index of maximum, choose 'nchange' indices of which 'nchange - 1' are random and the remaining one is the
    # index of the maximum, generate random complex numbers, update 'a' and compute updated 'b'
    imax = b.argmax()
    ichange = np.concatenate(([imax], np.random.choice(nsel, nchange - 1, replace=False)))
    a_change = np.random.rand(nchange) + 1j * np.random.rand(nchange)
    a[ichange] = a_change
    b[ichange] = a_change.real ** 2 + a_change.imag ** 2
    return a, b, ichange
# do an update iteration on 'a' and 'b'
a, b, ichange = do_update_iter(a, b)
# sort 'a', 'b', 'inu_sel' and 'im_sel'
i_sort = b.argsort()
a_sort, b_sort, inu_sort, im_sort = a[i_sort], b[i_sort], inu_sel[i_sort], im_sel[i_sort]
# do an update iteration on 'a_sort' and 'b_sort'
a_sort, b_sort, ichange = do_update_iter(a_sort, b_sort)
b_sort_copy = b_sort.copy()
ind = np.arange(nsel)
def binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange):
    # binary insertion as an idea to save computation time relative to repeated argmax over entire (large) arrays
    # find updated values for 'a_sort', compute updated values for 'b_sort'
    a_change = a_sort[ichange]
    b_change = a_change.real ** 2 + a_change.imag ** 2
    # sort the updated values for 'a_sort' and 'b_sort' as well as the corresponding indices
    i_sort = b_change.argsort()
    a_change_sort = a_change[i_sort]
    b_change_sort = b_change[i_sort]
    inu_change_sort = inu_sort[ichange][i_sort]
    im_change_sort = im_sort[ichange][i_sort]
    # find indices of the non-updated values, cut out those indices from 'a_sort', 'b_sort', 'inu_sort' and 'im_sort'
    ind_complement = np.delete(ind, ichange)
    a_complement = a_sort[ind_complement]
    b_complement = b_sort[ind_complement]
    inu_complement = inu_sort[ind_complement]
    im_complement = im_sort[ind_complement]
    # find indices where sorted updated elements would have to be inserted into the sorted non-updated arrays to keep
    # the merged arrays sorted and insert the elements at those indices
    i_insert = b_complement.searchsorted(b_change_sort)
    a_updated = np.insert(a_complement, i_insert, a_change_sort)
    b_updated = np.insert(b_complement, i_insert, b_change_sort)
    inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
    im_updated = np.insert(im_complement, i_insert, im_change_sort)
    return a_updated, b_updated, inu_updated, im_updated
# do the binary insertion
a_updated, b_updated, inu_updated, im_updated = binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)
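As a quick consistency check added here for illustration (not part of the timed steps): since 'b_complement' and 'b_change_sort' are each sorted, the merge via searchsorted/insert keeps 'b_updated' sorted, so the maximum indeed sits at index -1:
assert np.all(np.diff(b_updated) >= 0)   # check (illustration): merged array is sorted
assert b_updated[-1] == b_updated.max()  # hence the maximum is at index -1, no argmax needed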
# do all the steps of the binary insertion, just to have the variable names defined
a_change = a_sort[ichange]
b_change = a_change.real ** 2 + a_change.imag ** 2
i_sort = b_change.argsort()
a_change_sort = a_change[i_sort]
b_change_sort = b_change[i_sort]
inu_change_sort = inu_sort[ichange][i_sort]
im_change_sort = im_sort[ichange][i_sort]
ind_complement = np.delete(ind, ichange)
a_complement = a_sort[ind_complement]
b_complement = b_sort[ind_complement]
inu_complement = inu_sort[ind_complement]
im_complement = im_sort[ind_complement]
i_insert = b_complement.searchsorted(b_change_sort)
a_updated = np.insert(a_complement, i_insert, a_change_sort)
b_updated = np.insert(b_complement, i_insert, b_change_sort)
inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
im_updated = np.insert(im_complement, i_insert, im_change_sort)
# timings for argmax and for sorting
%timeit b.argmax() # 579 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit b_sort.argmax() # 580 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.sort(b) # 70.2 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.sort(b_sort) # 25.2 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit b_sort_copy.sort() # 14 ms ± 78.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# timings for binary insertion
%timeit binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange) # 33.7 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit a_change = a_sort[ichange] # 4.28 µs ± 40.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change = a_change.real ** 2 + a_change.imag ** 2 # 8.25 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit i_sort = b_change.argsort() # 35.6 µs ± 529 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_change_sort = a_change[i_sort] # 4.2 µs ± 62.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change_sort = b_change[i_sort] # 2.05 µs ± 47 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit inu_change_sort = inu_sort[ichange][i_sort] # 4.47 µs ± 38 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit im_change_sort = im_sort[ichange][i_sort] # 4.51 µs ± 48.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit ind_complement = np.delete(ind, ichange) # 1.38 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit a_complement = a_sort[ind_complement] # 3.52 ms ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_complement = b_sort[ind_complement] # 1.44 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_complement = inu_sort[ind_complement] # 1.36 ms ± 6.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_complement = im_sort[ind_complement] # 1.31 ms ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit i_insert = b_complement.searchsorted(b_change_sort) # 148 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_updated = np.insert(a_complement, i_insert, a_change_sort) # 3.08 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_updated = np.insert(b_complement, i_insert, b_change_sort) # 1.37 ms ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_updated = np.insert(inu_complement, i_insert, inu_change_sort) # 1.41 ms ± 28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_updated = np.insert(im_complement, i_insert, im_change_sort) # 1.52 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Update: As suggested below by @Jérôme Richard, a fast way to repeatedly find the index of the maximum in a partially updated array is to split the array into chunks and pre-compute the maxima of all chunks once. In each iteration, only the maxima of the nchange (or fewer) updated chunks are re-computed; the global maximum is then found by an argmax over the chunk maxima, which yields the chunk index, followed by an argmax within that chunk.
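To illustrate the scheme, here is a pure-numpy sketch (illustration only, assuming a chunk size of 32 and that the size of 'b' is a multiple of it; the numba code below is what was actually timed):
# illustrative numpy sketch of the chunked-maximum scheme (not the timed implementation)
chunk = 32
max_per_chunk = b.reshape(-1, chunk).max(axis=1)  # pre-compute the chunk maxima once
# after 'b' has been updated at indices 'ichange', refresh only the touched chunks:
for c in np.unique(ichange // chunk):
    max_per_chunk[c] = b[c * chunk:(c + 1) * chunk].max()
# global argmax = argmax over the chunk maxima, then argmax within that chunk
c = max_per_chunk.argmax()
imax = c * chunk + b[c * chunk:(c + 1) * chunk].argmax()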
I copied the code from @Jérôme Richard's answer. In practice, his solution, when run on my system, gives a speed-up of about 7.3, requiring 46.6 + 33 = 79.6 µs instead of 580 µs for b.argmax().
import numba as nb
@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk(b):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    max_per_chunk = np.empty(b.size // 32)
    for chunk_idx in nb.prange(b.size // 32):
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi
    return max_per_chunk
@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks(b, max_per_chunk):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32
    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    return offset + np.argmax(b[offset:offset+32])
@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk(b, max_per_chunk, ichange):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32
    for idx in ichange:
        chunk_idx = idx // 32
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi
b = np.random.rand(nsel)
max_per_chunk = precompute_max_per_chunk(b)
a, b, ichange = do_update_iter(a, b)
argmax_from_chunks(b, max_per_chunk)
update_max_per_chunk(b, max_per_chunk, ichange)
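A sanity check added here for illustration: once update_max_per_chunk() has brought the chunk maxima in sync with 'b', the chunked argmax must agree with the plain argmax:
assert argmax_from_chunks(b, max_per_chunk) == b.argmax()  # check (illustration): both methods agree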
%timeit max_per_chunk = precompute_max_per_chunk(b) # 77.3 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks(b, max_per_chunk) # 46.6 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk(b, max_per_chunk, ichange) # 33 µs ± 40.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Update 2: I have now modified @Jérôme Richard's solution to work with arrays b whose size is not an integer multiple of the chunk size. In addition, the code now only scans all values of a chunk if an updated value is smaller than the previous chunk maximum; otherwise it directly sets the updated value as the new chunk maximum. The if-queries should take little time compared to the time saved when the updated value is larger than the previous maximum, and in my code this case becomes more and more likely as the iterations proceed (the updated values get closer and closer to noise, i.e. random). In practice, for random numbers, the execution time of update_max_per_chunk() is reduced a bit further, from ~33 µs to ~27 µs. The code and new timings are:
import math
@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk_bp(b):
    nchunks = math.ceil(b.size / 32)
    imod = b.size % 32
    max_per_chunk = np.empty(nchunks)
    for chunk_idx in nb.prange(nchunks):
        offset = chunk_idx * 32
        maxi = b[offset]
        if (chunk_idx != (nchunks - 1)) or (not imod):
            iend = 32
        else:
            iend = imod
        for j in range(1, iend):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi
    return max_per_chunk
@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks_bp(b, max_per_chunk):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    if (chunk_idx != (nchunks - 1)) or (not imod):
        return offset + np.argmax(b[offset:offset+32])
    else:
        return offset + np.argmax(b[offset:offset+imod])
@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk_bp(b, max_per_chunk, ichange):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    for idx in ichange:
        chunk_idx = idx // 32
        if b[idx] < max_per_chunk[chunk_idx]:
            offset = chunk_idx * 32
            if (chunk_idx != (nchunks - 1)) or (not imod):
                iend = 32
            else:
                iend = imod
            maxi = b[offset]
            for j in range(1, iend):
                maxi = max(b[offset + j], maxi)
            max_per_chunk[chunk_idx] = maxi
        else:
            max_per_chunk[chunk_idx] = b[idx]
%timeit max_per_chunk = precompute_max_per_chunk_bp(b) # 74.6 µs ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks_bp(b, max_per_chunk) # 46.6 µs ± 9.92 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk_bp(b, max_per_chunk, ichange) # 26.5 µs ± 19.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
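For completeness, a sketch of how the pieces are intended to fit into the iteration loop ('niter' and the update step are hypothetical placeholders, not from @Jérôme Richard's answer):
max_per_chunk = precompute_max_per_chunk_bp(b)  # once, before the loop
for it in range(niter):  # 'niter' is a placeholder for the number of iterations
    imax = argmax_from_chunks_bp(b, max_per_chunk)  # fast argmax via the chunk maxima
    # ... use 'imax', update 'a' and 'b' at the indices 'ichange' ...
    update_max_per_chunk_bp(b, max_per_chunk, ichange)  # keep the chunk maxima in sync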