I have a complex-valued array of about 750000 elements, in which I repeatedly (say 10^6 or more times) update 1000 (or fewer) different elements. In the absolute-squared array I then need to find the index of the maximum. This is part of a larger code that takes about 700 seconds to run, of which typically 75% (~550 sec) is spent finding the index of the maximum. Even though ndarray.argmax() is "blazingly fast" according to https://stackoverflow.com/a/26820109/5269892, running it repeatedly on an array of 750000 elements (even though only 1000 elements have changed) simply takes too much time.

Below is a minimal, complete example in which I use random numbers and indices. You may not make assumptions about how the real-valued array 'b' changes after an update (i.e. the values may be larger, smaller or equal), except, if you must, that the value at the index of the previous maximum ('b[imax]') will likely be smaller after an update.

I tried using sorted arrays, into which only the updated values are inserted (in sorted order) at the correct places to maintain the sorting, because then we know the maximum is always at index -1 and do not have to recompute it. The minimal example below includes timings. Unfortunately, selecting the non-updated values and inserting the updated values takes too much time (all other steps combined would require only ~210 µs instead of the ~580 µs of ndarray.argmax()).

Context: This is part of an implementation of the deconvolution algorithm CLEAN (Högbom, 1974) in the efficient Clark (1980) variant. As I intend to implement the Sequence CLEAN algorithm (Bose+, 2002), where even more iterations are required, or may want to use larger input arrays, my question is:

Question: What is the fastest way to find the index of the maximum value in the updated array (without applying ndarray.argmax() to the whole array in each iteration)?

Minimal example code (run on python 3.7.6, numpy 1.21.2, scipy 1.6.0):

import numpy as np

# some array shapes ('nnu_use' and 'nm'), number of total values ('nvals'), number of selected values ('nsel'; here
# 'nsel' == 'nvals'; in general 'nsel' <= 'nvals') and number of values to be changed ('nchange')
nnu_use, nm = 10418//2 + 1, 144
nvals = nnu_use * nm
nsel = nvals
nchange = 1000

# fix random seed, generate random 2D 'Fourier transform' ('a', complex-valued), compute power ('b', real-valued), and
# two 2D arrays for indices of axes 0 and 1
np.random.seed(100)
a = np.random.rand(nsel) + 1j * np.random.rand(nsel)
b = a.real ** 2 + a.imag ** 2
inu_2d = np.tile(np.arange(nnu_use)[:,None], (1,nm))
im_2d = np.tile(np.arange(nm)[None,:], (nnu_use,1))

# select 'nsel' random indices and get 1D arrays of the selected 2D indices
isel = np.random.choice(nvals, nsel, replace=False)
inu_sel, im_sel = inu_2d.flatten()[isel], im_2d.flatten()[isel]

def do_update_iter(a, b):
    # find index of maximum, choose 'nchange' indices of which 'nchange - 1' are random and the remaining one is the
    # index of the maximum, generate random complex numbers, update 'a' and compute updated 'b'
    imax = b.argmax()
    ichange = np.concatenate(([imax],np.random.choice(nsel, nchange-1, replace=False)))
    a_change = np.random.rand(nchange) + 1j*np.random.rand(nchange)
    a[ichange] = a_change
    b[ichange] = a_change.real ** 2 + a_change.imag ** 2
    return a, b, ichange

# do an update iteration on 'a' and 'b'
a, b, ichange = do_update_iter(a, b)

# sort 'a', 'b', 'inu_sel' and 'im_sel'
i_sort = b.argsort()
a_sort, b_sort, inu_sort, im_sort = a[i_sort], b[i_sort], inu_sel[i_sort], im_sel[i_sort]

# do an update iteration on 'a_sort' and 'b_sort'
a_sort, b_sort, ichange = do_update_iter(a_sort, b_sort)
b_sort_copy = b_sort.copy()

ind = np.arange(nsel)

def binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange):
    # binary insertion as an idea to save computation time relative to repeated argmax over entire (large) arrays
    # find updated values for 'a_sort', compute updated values for 'b_sort'
    a_change = a_sort[ichange]
    b_change = a_change.real ** 2 + a_change.imag ** 2
    # sort the updated values for 'a_sort' and 'b_sort' as well as the corresponding indices
    i_sort = b_change.argsort()
    a_change_sort = a_change[i_sort]
    b_change_sort = b_change[i_sort]
    inu_change_sort = inu_sort[ichange][i_sort]
    im_change_sort = im_sort[ichange][i_sort]
    # find indices of the non-updated values, cut out those indices from 'a_sort', 'b_sort', 'inu_sort' and 'im_sort'
    ind_complement = np.delete(ind, ichange)
    a_complement = a_sort[ind_complement]
    b_complement = b_sort[ind_complement]
    inu_complement = inu_sort[ind_complement]
    im_complement = im_sort[ind_complement]
    # find indices where sorted updated elements would have to be inserted into the sorted non-updated arrays to keep
    # the merged arrays sorted and insert the elements at those indices
    i_insert = b_complement.searchsorted(b_change_sort)
    a_updated = np.insert(a_complement, i_insert, a_change_sort)
    b_updated = np.insert(b_complement, i_insert, b_change_sort)
    inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
    im_updated = np.insert(im_complement, i_insert, im_change_sort)

    return a_updated, b_updated, inu_updated, im_updated

# do the binary insertion
a_updated, b_updated, inu_updated, im_updated = binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)

# do all the steps of the binary insertion, just to have the variable names defined
a_change = a_sort[ichange]
b_change = a_change.real ** 2 + a_change.imag ** 2
i_sort = b_change.argsort()
a_change_sort = a_change[i_sort]
b_change_sort = b_change[i_sort]
inu_change_sort = inu_sort[ichange][i_sort]
im_change_sort = im_sort[ichange][i_sort]
ind_complement = np.delete(ind, ichange)
a_complement = a_sort[ind_complement]
b_complement = b_sort[ind_complement]
inu_complement = inu_sort[ind_complement]
im_complement = im_sort[ind_complement]
i_insert = b_complement.searchsorted(b_change_sort)
a_updated = np.insert(a_complement, i_insert, a_change_sort)
b_updated = np.insert(b_complement, i_insert, b_change_sort)
inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
im_updated = np.insert(im_complement, i_insert, im_change_sort)

# timings for argmax and for sorting
%timeit b.argmax()             # 579 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit b_sort.argmax()        # 580 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.sort(b)             # 70.2 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.sort(b_sort)        # 25.2 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit b_sort_copy.sort()     # 14 ms ± 78.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# timings for binary insertion
%timeit binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)          # 33.7 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit a_change = a_sort[ichange]                                         # 4.28 µs ± 40.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change = a_change.real ** 2 + a_change.imag ** 2                 # 8.25 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit i_sort = b_change.argsort()                                        # 35.6 µs ± 529 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_change_sort = a_change[i_sort]                                   # 4.2 µs ± 62.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change_sort = b_change[i_sort]                                   # 2.05 µs ± 47 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit inu_change_sort = inu_sort[ichange][i_sort]                        # 4.47 µs ± 38 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit im_change_sort = im_sort[ichange][i_sort]                          # 4.51 µs ± 48.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit ind_complement = np.delete(ind, ichange)                           # 1.38 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit a_complement = a_sort[ind_complement]                              # 3.52 ms ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_complement = b_sort[ind_complement]                              # 1.44 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_complement = inu_sort[ind_complement]                          # 1.36 ms ± 6.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_complement = im_sort[ind_complement]                            # 1.31 ms ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit i_insert = b_complement.searchsorted(b_change_sort)                # 148 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_updated = np.insert(a_complement, i_insert, a_change_sort)       # 3.08 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_updated = np.insert(b_complement, i_insert, b_change_sort)       # 1.37 ms ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_updated = np.insert(inu_complement, i_insert, inu_change_sort) # 1.41 ms ± 28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_updated = np.insert(im_complement, i_insert, im_change_sort)    # 1.52 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Update: As suggested below by @Jérôme Richard, a fast way to repeatedly find the index of the maximum in a partially updated array is to split the array into chunks and pre-compute the maxima of the chunks. In each iteration, one then re-computes the maxima of only the (at most nchange) updated chunks, takes the argmax over the chunk maxima to obtain the chunk index, and finally takes the argmax within that chunk.

I copied the code from @Jérôme Richard's answer. In practice, his solution, when run on my system, results in a speed-up by a factor of about 7.3, requiring 46.6 + 33 = 79.6 µs instead of 580 µs for b.argmax().

import numba as nb

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk(b):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    max_per_chunk = np.empty(b.size // 32)

    for chunk_idx in nb.prange(b.size//32):
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk
# OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks(b, max_per_chunk):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    return offset + np.argmax(b[offset:offset+32])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk(b, max_per_chunk, ichange):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    for idx in ichange:
        chunk_idx = idx // 32
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

b = np.random.rand(nsel)
max_per_chunk = precompute_max_per_chunk(b)
a, b, ichange = do_update_iter(a, b)
argmax_from_chunks(b, max_per_chunk)
update_max_per_chunk(b, max_per_chunk, ichange)

%timeit max_per_chunk = precompute_max_per_chunk(b)     # 77.3 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks(b, max_per_chunk)            # 46.6 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk(b, max_per_chunk, ichange) # 33 µs ± 40.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Update 2: I have now modified @Jérôme Richard's solution to work with arrays b whose size is not an integer multiple of the chunk size. In addition, the code only accesses all values of a chunk if an updated value is smaller than the previous chunk maximum; otherwise it directly sets the updated value as the new chunk maximum. The if-checks should cost little compared to the time saved when the updated value is larger than the previous maximum. In my code this case becomes more and more likely as the iterations progress (the updated values get closer and closer to noise, i.e. random). In practice, for random numbers, the execution time of update_max_per_chunk() is reduced a bit further, from ~33 µs to ~27 µs. The code and new timings are:

import math

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk_bp(b):
    nchunks = math.ceil(b.size/32)
    imod = b.size % 32
    max_per_chunk = np.empty(nchunks)
    
    for chunk_idx in nb.prange(nchunks):
        offset = chunk_idx * 32
        maxi = b[offset]
        if (chunk_idx != (nchunks - 1)) or (not imod):
            iend = 32
        else:
            iend = imod
        for j in range(1, iend):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks_bp(b, max_per_chunk):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    if (chunk_idx != (nchunks - 1)) or (not imod):
        return offset + np.argmax(b[offset:offset+32])
    else:
        return offset + np.argmax(b[offset:offset+imod])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk_bp(b, max_per_chunk, ichange):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    for idx in ichange:
        chunk_idx = idx // 32
        if b[idx] < max_per_chunk[chunk_idx]:
            offset = chunk_idx * 32
            if (chunk_idx != (nchunks - 1)) or (not imod):
                iend = 32
            else:
                iend = imod
            maxi = b[offset]
            for j in range(1, iend):
                maxi = max(b[offset + j], maxi)
            max_per_chunk[chunk_idx] = maxi
        else:
            max_per_chunk[chunk_idx] = b[idx]

%timeit max_per_chunk = precompute_max_per_chunk_bp(b)     # 74.6 µs ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks_bp(b, max_per_chunk)            # 46.6 µs ± 9.92 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk_bp(b, max_per_chunk, ichange) # 26.5 µs ± 19.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
bproxauf
  • Has anyone mentioned [sparse matrices](https://docs.scipy.org/doc/scipy/reference/sparse.html)? You could keep the updates in separate sparse matrices and only find the max of the updated elements that way. (I assume max(sparse_array) only checks the non-zero elements). – Bill Jul 20 '22 at 18:33
  • Finding the max of only the updated elements does not work, as the overall max may be among the non-updated elements. And I guess the conversions from numpy arrays to sparse matrices and vice versa in each iteration would take too long (but I don't know for sure)... – bproxauf Jul 20 '22 at 18:36
  • Indeed. You would need to keep a record of the max from previous iterations and compare the max of the latest sparse array to that like this: `current_max = max(previous_max, sparse_array.max())`. I don't know if creating a sparse matrix is a big overhead. Maybe not. – Bill Jul 20 '22 at 18:39
  • Have not used sparse matrices so far, but I can play around with that. Creating a single sparse matrix may not cost much time, true, but with increasing iterations, due to the changing indices of the updates, the matrix gets less and less sparse and when a lot of the elements are non-zero, it should get quite slow. Unless we reset the matrix to zero once we found the maximum, of course. – bproxauf Jul 20 '22 at 18:45
  • Sorry, I was missing the fact that values are replaced each iteration and therefore the `previous_max` would become invalid. You would have to keep a record of say the 'top N' values and then periodically re-sort the whole array every time the top N become depleted. – Bill Jul 20 '22 at 20:20

1 Answer

ndarray.argmax() is "blazingly fast" according to https://stackoverflow.com/a/26820109/5269892

Argmax is not optimal, since it does not manage to saturate the RAM bandwidth on my machine (which would be possible), but it is quite good: it saturates ~40% of the total RAM throughput in your case and about 65%-70% when run sequentially on my machine (one core cannot saturate the RAM on most machines). Most machines have a lower throughput, so np.argmax should be even closer to optimal on those machines.

Finding the maximum value using multiple threads can help to reach this optimum, but given the current performance of the function, one should not expect a speed-up greater than 2 on most PCs (more on computing servers).

What is the fastest way to find the index of the maximum value in the updated array

Whatever the computation done, reading the whole array from memory takes at least b.size * 8 / RAM_throughput seconds. With very good dual-channel DDR4 RAM, the optimal time is about ~125 µs, while the best single-channel DDR4 RAM achieves ~225 µs. If the array is also written in place, the optimal time is twice as large, and if a new array is created (out-of-place computation), it is three times as large on x86-64 platforms. For the latter it is in fact even worse, because of the large overheads of the OS virtual memory.
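
As a rough sanity check of these numbers (the bandwidth figures here are my own assumption: roughly 48 GB/s for dual-channel and 27 GB/s for single-channel DDR4), the read-only lower bound can be evaluated directly:

nvals = 750_240                 # b.size in the example above
bytes_read = nvals * 8          # float64
for bw in (48e9, 27e9):         # assumed sustained read bandwidth in bytes/s
    print(f"{bw / 1e9:.0f} GB/s -> {bytes_read / bw * 1e6:.0f} µs")
# 48 GB/s -> 125 µs
# 27 GB/s -> 222 µs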

What this means is that no out-of-place computation reading the whole array can beat np.argmax on a mainstream PC. This also explains why the sort solution is so slow: it creates many temporary arrays. Even a perfect sorted-array strategy would not be much faster than np.argmax here (because all items need to be moved in RAM in the worst case, and far more than half of them on average). In fact, the benefit of any in-place method writing the whole array is low (still on a mainstream PC): it would only be slightly faster than np.argmax. The only way to get a significant speed-up is not to operate on the whole array.

One efficient solution to this problem is to use a balanced binary search tree. Indeed, you can remove k nodes from a tree containing n nodes in O(k log n) time, and insert the updated values in the same time. This is much better than an O(n) solution in your case, because n ~= 750_000 and k ~= 1_000. Still, note that there is a hidden constant factor behind the complexity, and binary search trees may not be so fast in practice, especially if they are not well optimized. Also note that it is better to update the tree values than to delete nodes and insert new ones. A pure-Python implementation will hardly be fast enough in this case (and will take a lot of memory). Only Cython or a native solution can be fast (e.g. C/C++, or any natively implemented Python module, but I could not find one that is fast).
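
For illustration only, here is a sketch of that bookkeeping using the third-party sortedcontainers.SortedList (my own choice of module, not one endorsed above; as just noted, a module-based pure-Python tree may still be too slow in practice):

import numpy as np
from sortedcontainers import SortedList  # third-party: pip install sortedcontainers

# Keep (value, index) pairs ordered by value; the global maximum is the last entry.
tree = SortedList(zip(b, range(b.size)))

def update_and_argmax(tree, b, ichange, new_values):
    # Remove the k stale entries and insert the k updated ones: O(k log n).
    for idx, new in zip(ichange, new_values):
        tree.remove((b[idx], idx))  # b[idx] must still hold the old value here
        tree.add((new, idx))
        b[idx] = new
    return tree[-1][1]              # index of the current global maximum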

Another alternative is a static n-ary tree-based partial-maximums data structure. It consists in splitting the array into chunks and pre-computing the maximum of each chunk. When values are updated (and assuming the number of items is constant), you need to (1) recompute the maximum of each updated chunk. To compute the global maximum, you then need to (2) compute the maximum over the per-chunk maxima. This solution also requires a (semi-)native implementation to be fast, since NumPy introduces significant overheads when updating the per-chunk maxima (it is not very optimized for such a case), but one should certainly see a speed-up. Numba and Cython can be used to do so, for example. The chunk size needs to be chosen carefully: in your case, something between 16 and 32 should give you a huge speed-up.

With chunks of size 32, at most 32*k = 32_000 values need to be read to update the per-chunk maxima (up to 1000 values are written). This is far less than 750_000. Computing the global maximum from the per-chunk maxima then requires reading n/32 ~= 23_400 values, which is still relatively small. I expect this to be at least 5 times faster with an optimized implementation, and probably even >10 times faster in practice, especially with a parallel implementation. This is certainly the best solution (without additional assumptions).
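
In code, the worst-case number of values touched per iteration with this scheme is tiny compared to a full scan:

n, k, c = 750_240, 1000, 32
touched = c * k + n // c   # 32_000 + 23_445 = 55_445 values per iteration
# versus n = 750_240 values read by a full np.argmax(b)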


Implementation in Numba

Here is a (barely tested) Numba implementation:

import numba as nb
import numpy as np

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk(arr):
    # Required for this simplified version to work and be simple
    assert arr.size % 32 == 0
    max_per_chunk = np.empty(arr.size // 32)

    for chunk_idx in nb.prange(arr.size // 32):
        offset = chunk_idx * 32
        maxi = arr[offset]
        for j in range(1, 32):
            maxi = max(arr[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks(arr, max_per_chunk):
    # Required for this simplified version to work and be simple
    assert arr.size % 32 == 0
    assert max_per_chunk.size == arr.size // 32

    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    return offset + np.argmax(arr[offset:offset+32])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk(arr, max_per_chunk, ichange):
    # Required for this simplified version to work and be simple
    assert arr.size % 32 == 0
    assert max_per_chunk.size == arr.size // 32

    for idx in ichange:
        chunk_idx = idx // 32
        offset = chunk_idx * 32
        maxi = arr[offset]
        for j in range(1, 32):
            maxi = max(arr[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

Here is an example of how to use it and timings on my (6-core) machine:

# Precomputation (306 µs)
max_per_chunk = precompute_max_per_chunk(b)

# Computation of the global max from the chunks (22.3 µs)
argmax_from_chunks(b, max_per_chunk)

# Update of the chunks (25.2 µs)
update_max_per_chunk(b, max_per_chunk, ichange)

# Initial best implementation: 357 µs
np.argmax(b)

As you can see, it is pretty fast. Updates should take 22.3 + 25.2 = 47.5 µs, while the naive NumPy implementation takes 357 µs. So the Numba implementation is 7.5 times faster! I think it can be optimized a bit further, but it is not simple. Note that the update is sequential and the pre-computation is parallel. Fun fact: the pre-computation followed by a call to argmax_from_chunks is faster than np.argmax thanks to the use of multiple threads!


Further improvements

argmax_from_chunks can be improved thanks to SIMD instructions. Indeed, the current implementation generates scalar maxsd/vmaxsd instructions on x86-64 machines, which is sub-optimal. The operation can be vectorized by using a tile-based argmax that computes the maximum with a 4x unrolled loop (possibly even 8x on recent 512-bit-wide SIMD machines). On my processor supporting the AVX instruction set, experiments show that Numba can generate code running in 6-7 µs (about 4 times faster). That being said, this is tricky to implement and the resulting function is a bit ugly.

The same method can be used to speed up update_max_per_chunk, which is unfortunately also not vectorized by default. I also expect a ~4x speed-up on a recent x86-64 machine. However, Numba generates a very inefficient vectorization in many cases (it tries to vectorize the outer loop instead of the inner one). As a result, my best attempt with Numba reached 16.5 µs.
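
As a rough, untested illustration of the unrolled-accumulator idea (my own sketch, not the 6-7 µs implementation mentioned above), a SIMD-friendly argmax over a 1D array such as max_per_chunk could look like this:

import numba as nb
import numpy as np

@nb.njit('(f8[::1],)')
def argmax_unrolled(x):
    # Reduction with four independent accumulators so the compiler can keep the
    # comparisons in vector registers, followed by a second pass to find the index.
    m0 = -np.inf
    m1 = -np.inf
    m2 = -np.inf
    m3 = -np.inf
    n4 = (x.size // 4) * 4
    for i in range(0, n4, 4):       # 4x unrolled maximum reduction
        m0 = max(m0, x[i])
        m1 = max(m1, x[i + 1])
        m2 = max(m2, x[i + 2])
        m3 = max(m3, x[i + 3])
    m = max(max(m0, m1), max(m2, m3))
    for i in range(n4, x.size):     # remaining tail elements
        m = max(m, x[i])
    for i in range(x.size):         # locate the index of the maximum
        if x[i] == m:
            return i
    return 0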

In theory, the whole update can be made about 4 times faster on a mainstream x86-64 machine; in practice, a code at least 2 times faster should be achievable!

Jérôme Richard
  • Thanks for your detailed answer. Please note that the `n` from your answer is not 75000, but 750000, which I guess would change your `k` for your chunk maxima idea. Unfortunately I did not use Cython or binary search trees so far, and since I use the arrays `a` and `b` in the further stages of the code iterations, maybe I'd have to convert these binary trees back to numpy arrays, which might be costly (I don't know). – bproxauf Jul 19 '22 at 22:03
  • Indeed, this is even better in this case :) . I advise you to use the last solution which is simpler and certainly much faster in practice in a Python environment. The binary tree conversion can indeed be expensive if the implementation is not carefully optimized and most implementation are not (especially Python modules). I tried a basic Numpy solution but the current implementation of Numpy is clearly not optimized for this case and it turns out to be very inefficient. If you are not familiar with Cython, Numba is certainly the best option as you can write everything in Python. – Jérôme Richard Jul 19 '22 at 22:31
  • I wrote a Numba implementation and got a great speed up. Calling `precompute_max_per_chunk` + `argmax_from_chunks` gives me correct values so it should be Ok but I did not tested the code thoroughly so you need to check. Note that you certainly need to tweak the implementation so to support array sizes that are not a multiple of 32. – Jérôme Richard Jul 19 '22 at 23:13
  • First time I'm using `numba`, so a few, maybe stupid questions: I guess you mean `nb.prange` instead of `np.prange`. How should the chunk size `c` be determined for arbitrary `n` (`nsel`) and `k` (`nchange`)? You mentioned 'c*k + n/c' values are read in total, so should one minimize this number with respect to `c`? Is there a specific reason that you used a chunk size being a power of 2? Is this related to cache sizes being a power of 2 or sth. like that? You mentioned that `numpy` incurs significant overhead. Do you mean that np.argmax gets slower than `O(N)` when `N` gets sufficiently small? – bproxauf Jul 20 '22 at 08:35
  • Is this the reason you are using an inner for-loop over the 32 chunk elements, using python's default `max` instead of `np.max`? Would the outer for-loop get slightly faster by using a list comprehension? And what is the `arr` argument? Is it required (seems it's not used)? Should this be `b` instead? Finally I got some OpenMP info after loading the first function: `# OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.` I guess it's not relevant, but anyways... – bproxauf Jul 20 '22 at 08:44
  • `np.max` also work here. Note that I used a loop because Numpy function usually introduce a small overhead even in Numba. While it is generally acceptable in Numba, this is a very critical loop here so I prefer having no surprise when running it. Note that list comprehension tends to be slow even in Numba because they add some overheads (and also because creating a variable-sized list is inherently slower than a fixed-size preallocated array in the first place). – Jérôme Richard Jul 20 '22 at 18:10
  • `arr` is `b` in your example (IDK for the real-world code). This can be seen in the last example code. Note that `arr` is a function parameter while `b` is the argument and they should not be confused (but feel free to change names). The OpenMP warning is weird. It seems to indicate the Numba function is running in a parallel context. If so, I advise you to test to remove the `parallel=True` and `prange` because parallel nested loops are generally inefficient. It might also be just an internal useless warning... – Jérôme Richard Jul 20 '22 at 18:15
  • I guess I meant that `arr` is not used inside the functions. Not a big issue, since it will use the global variable `b`, but anyways. Your other explanations are very helpful as well. I'll play around with removing the `parallel` and `prange` tomorrow back at work to see if the speed-up is increased even further in case it made the solution slower than necessary. Thanks! – bproxauf Jul 20 '22 at 18:22
  • Ha, Indeed. I missed that, sorry! This is now fixed. Thank you. – Jérôme Richard Jul 20 '22 at 18:27
  • Wondering whether we can speed this up even more. We do not really have to compute the max over 32 chunk values, do we? We only need to compare the maximum of the updated values that fall into a single chunk with the previous chunk maximum, so e.g. if only one updated value is in one chunk, one comparison should be sufficient (the updated value with the previous chunk maximum). – bproxauf Jul 20 '22 at 18:53
  • If we have `n/c` chunks with a length `c` and `k` updated values, we require at max to access `2k` values (if each updated value is in a separate chunk), at min `k + 1` (if all updated values fall into the same chunk). We then need to access `n/c` values (1 `np.argmax()` call). The last `np.argmax()` within the chunk having the global maximum value can be avoided, since we know where the updated values fall into that chunk, so we can save the index, when we compute the maximum in the first place. In summary at maximum `2k + n/c` values are accessed. I still need to think a bit more about it. – bproxauf Jul 20 '22 at 19:11
  • Of course, this works only for chunks for which the chunk maximum gets larger through an update. For the chunks for which it gets smaller or stays equal, we need to loop over all their `c` values to determine the maximum from the non-updated values. – bproxauf Jul 20 '22 at 19:24
  • Yeah, assuming the old/new values are known (which is the case here), you can skip the computation if the new value if bigger than the max, or if the old values is *strictly* smaller than the max of the chunk. Note that if you have the index of the max and the indexed value is increased, then you can skip almost the entire computation (certainly not the usual case). There are many improvement like that but I think you need to read the full chunk array in the worst case anyway because if all the values are reduced, then the max lies in unchanged values and they must be read here. – Jérôme Richard Jul 20 '22 at 19:44
  • By the way, conditions are often slow: an unpredictable/misspredicted condition takes about ~3 ns on my machine. So adding conditions should slow down the execution of the worst case but speed up the best cases. In practice it should still worth it though. Also note that the current code is certainly not producing optimal instructions. If the code is not vectorized, then it can certainly be twice faster. That being said, the memory latency also comes into play: the `argmax_from_chunks` is certainly latency bound. So the best is certainly to try optimizations and run them ;) . – Jérôme Richard Jul 20 '22 at 19:50
  • I updated the answer so to provide hints on vectorization. Theoretically, the code can be made 4 times faster but this is tricky to do which is very challenging. On recent Intel computing server (which supports AVX-512), I expect optimized variants of the codes to be properly vectorized by Numba and so a x8 speed up is theoretically plausible. This reach the limit of the hardware ;) . – Jérôme Richard Jul 20 '22 at 20:55
  • Thanks for your continued commitment to the question, your dedication is really amazing :) I don't know the computer architecture details of my system (it's a workstation at my workplace) and whether/which of these SIMD instructions are supported, but I'm running the code on a 32-core machine. Sorry for my ignorance here, I'm a typical physicist programming... – bproxauf Jul 20 '22 at 21:35
  • Regarding chunk size, I guess as worst-case we need `c*k + n/c` values to be accessed, finding the minimum w.r.t. `c` would suggest `c = sqrt(n/k) ~ 27-28`. The speed boost relative to `c=32` is likely negligible though, I assume. Also the 'best' value of `c` might depend on how exactly the actual speed of `ndarray.argmax()` or `ndarray.max()` depends on the number of elements for small `n`, i.e. where the actual speed departs from the asymptotic linear complexity. And whether a decrease of `n/c` balances an increase in `c*k`. I'll run the `argmax` for different `N` tomorrow, out of curiosity. – bproxauf Jul 20 '22 at 21:41
  • With a 32-core machine I guess this is a AMD Zen machine (supporting AVX-2). The wide majority of modern computing server are x86-64 machine support AVX. You can check the support on Linux by typing the command `cat /proc/cpuinfo` (`flags` field). For the chunk size, it is a bit complex but I expect the result to be a bit faster if the chunk size is a multiple of 8 because of cache lines. Division by power of two are also faster (because compiler can just use a shift). This is why developers like power of two :) . IDK if this has a significant performance though. – Jérôme Richard Jul 20 '22 at 22:20
  • It's an Intel Xeon CPU E5-4640. The flags field mentions avx, among some others like sse2, ssse3 etc. Not sure whether this includes AVX-2, AVX-512 or the first AVX, though. – bproxauf Jul 21 '22 at 06:56
  • You should see `avx2` if it is supported but this processor is too old (2012) for AVX2. It has 8 core (and 16 hardware threads). You probably have two of them so the machine has 16 cores (and 32 hardware threads but not 32 cores). It processor can compute 32 (DP float) values in 10 cycles, that is 3.6 ns so updating 1000 chunks of 32 items should takes at least 3.6 us (optimal). The current Numba implementation is latency bound and should take 96 cycles (pure computation not counting overheads), that is 34 us for 1000 chunks! In practice, it is even 56 us due to overheads which is a bit sad. – Jérôme Richard Jul 21 '22 at 11:44
  • From what I can see, we do have 32 cores, but I am using the machine together with several other users, so the CPU and RAM loads may vary. I ran your code again, and the `update_max_per_chunk` now actually takes 33 us, precisely the number you described. Could be that this is related, maybe, Idk. Edit: Seems it was the `arr` argument that I had not yet changed. Apparently accessing the variable `b` from the global ones costs the difference of 23 us. I updated the timings reported in the question. – bproxauf Jul 21 '22 at 14:35
  • Regarding the OMP warning for `precompute_max_per_chunk`, here are some timings: `parallel = True, nb.prange`: 101 µs ± 21.4 µs per loop. `parallel = False, nb.prange`: 354 µs ± 1.41 µs per loop. `parallel=True, np.arange`: 741 µs ± 89.3 µs per loop; `parallel=False, np.arange`: 546 µs ± 170 ns per loop. So the suggested solution (`parallel=True, nb.prange`) is the way to go. – bproxauf Jul 21 '22 at 14:43
  • Ok so the parallel version seems fine to me. Note that `np.arange` is slow as it create a temporary array. Consider using just `range` instead. – Jérôme Richard Jul 21 '22 at 17:25
  • To easily get an overview of the target machine, you can use `hwloc` and more specifically `lstopo`. [Here](https://access.redhat.com/webassets/avalon/d/Red_Hat_Enterprise_Linux-8-Monitoring_and_managing_system_status_and_performance-en-US/images/7ddd66e302a97d0b710e23e0af9d6524/lstopo.png) is an example of output. PU are hardware threads and Core are true cores. Many tools wrongly reports hardware threads as cores, but not this one. You can also easily see the multiple processors with it and more complex stuff. – Jérôme Richard Jul 21 '22 at 17:32
  • Interesting. I can reproduce the difference of performance for the use of `b`. The global access can indeed cause less efficient optimizations. Note this effect also happens in pure Python codes but I did not expect it to happen in Numba. – Jérôme Richard Jul 21 '22 at 17:53
  • I played around with saving the `range(1,32)` in a variable outside the outer for-loop, as the object then would not have to be created a 1000 times (and I assume creating the range takes more time than accessing it), but somehow the performance did not change at all. What is the reason for this? – bproxauf Jul 21 '22 at 18:52
  • This is surprising unless Numba is smart enough to optimize that. Note that `range(1,32)` is optimized by Numba and basically cost nothing since no object are creating in Numba (as opposed to with CPython). I made a quick check with a basic loop and it is slower on my machine (with Numba version 0.55.2) : my Numba does not optimize the operation. – Jérôme Richard Jul 21 '22 at 18:57
  • For Sequence CLEAN (Bose, R., 2002), it is required to find the indices of the `K` largest values (instead of the index of the maximum, with `K` being a small number, say < 20) in each iteration and then recursively call the algorithm itself (spanning a tree). If we did not have that large arrays or that many iterations, I would use `np.argpartition()`, see https://stackoverflow.com/a/23734295/5269892. But again, computation time becomes an issue when going over the full array. Can your solution be easily adjusted for that case and what would be the most efficient way to do so? Thanks a lot! – bproxauf Jul 22 '22 at 12:17
  • Unfortunately, not the last solution. It could be for 2 value with some tricks but then the computational time would not be reasonable (nor the code size). The tree-based solution still works though and should still be fast assuming the possible conversions are not a problem. Note that a binary tree can be stored as an array if needed but it can take twice the space in the worst case due to the balancing. Heaps can maybe help too but I am not sure the invariants are conserved with the removal of any items. – Jérôme Richard Jul 22 '22 at 16:28
  • I posted a separate question (https://stackoverflow.com/q/73106139/5269892) for the general case (`K > 1`), since this post is mostly about the case `K = 1`. If you shared your expertise about the implementation of your binary search tree in practice, I would greatly appreciate it. In any case, I am deeply grateful for your help up to this point. Thanks a lot! – bproxauf Jul 25 '22 at 08:36