
This is a contrived test case but, hopefully, it can suffice to convey the point and ask the question. Inside of a Numba njit function, I noticed that it is very costly to assign a locally computed value to an array element. Here are two example functions:

from numba import njit
import numpy as np

@njit
def slow_func(x, y):
    result = y.sum()
    
    for i in range(x.shape[0]):
        if x[i] > result:
            x[i] = result
        else:
            x[i] = result

@njit
def fast_func(x, y):
    result = y.sum()
    
    for i in range(x.shape[0]):
        if x[i] > result:
            z = result
        else:
            z = result

if __name__ == "__main__":
    x = np.random.rand(100_000_000)
    y = np.random.rand(100_000_000)

    %timeit slow_func(x, y)  # 177 ms ± 1.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    %timeit fast_func(x, y)  # 407 ns ± 12.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

I understand that the two functions aren't doing quite the same thing, but let's not worry about that for now and stay focused on the "slow assignment". Also, due to Numba's lazy compilation, the timings above were taken after the functions had already been JIT-compiled. Notice that both functions assign result either to x[i] or to z, and the number of assignments is the same in both cases. However, the assignment of result to z is substantially faster. Is there a way to make slow_func as fast as fast_func?

slaw
  • Not a compiler expert, but I wouldn't be surprised if most of your example functions get optimized away. Compilers are pretty smart these days. In particular, the assignment to `z` has no effect so may be dropped by the jit. – Paul Panzer Jul 23 '20 at 19:40
  • I just compared your `fast_func` against a function that does nothing and returns `None`. They take the same time to execute. – Paul Panzer Jul 23 '20 at 19:56
  • @PaulPanzer I think you may be right. If I simply return `z` at the end of `fast_func`, then the timing is about the same as `slow_func`. Nonetheless, I didn't expect array assignment to be so slow – slaw Jul 23 '20 at 20:07
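The comparison described in the comments can be reproduced with a minimal sketch. The `noop` baseline and the plain-Python fallback are illustrative additions, not part of the original post:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch runs without numba; with the fallback,
    # of course, no JIT optimization happens.
    njit = lambda f: f

@njit
def fast_func(x, y):
    result = y.sum()
    for i in range(x.shape[0]):
        if x[i] > result:
            z = result
        else:
            z = result

@njit
def noop(x, y):  # hypothetical baseline that does no work at all
    return None

x = np.random.rand(1000)
y = np.random.rand(1000)
fast_func(x, y)  # warm-up call to trigger JIT compilation
noop(x, y)
# Timing both (e.g. with %timeit) should show nearly identical per-call
# costs, since the optimizer can drop the dead loop and the unused sum.
```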

1 Answer


As @PaulPanzer has already pointed out, your fast function does nothing once optimized, so what you see is basically the overhead of calling a numba function.

The interesting part is that, in order to perform this optimization, numba must replace np.sum with its own sum implementation. Otherwise the optimizer could not eliminate the call, because it cannot look into the implementation of np.sum and would have to assume that calling it has side effects.

Let's measure only the summation with numba:

from numba import njit
@njit
def only_sum(x, y):
    return y.sum()

%timeit only_sum(x, y)
# 112 ms ± 623 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Well, that is disappointing: I know my machine can do more than 10^9 additions per second and read up to 13 GB/s from RAM (there are about 0.8 GB of data, so it doesn't fit in the cache), which means I would expect the summation to take between 60 and 80 ms.
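The back-of-envelope estimate can be checked with a quick calculation (the 13 GB/s figure is the machine-specific bandwidth assumed above):

```python
data_bytes = 100_000_000 * 8   # 1e8 float64 values, about 0.8 GB
bandwidth = 13e9               # assumed RAM read bandwidth in bytes/s
expected_ms = data_bytes / bandwidth * 1e3
# about 61.5 ms, i.e. within the 60-80 ms window quoted above
```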

And if I use numpy's version, it really does:

%timeit y.sum()
# 57 ms ± 444 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

That sounds about right! I assume numba doesn't use pairwise summation and is thus slower (when the RAM is fast enough that the additions become the bottleneck) and less precise than numpy's version.
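Pairwise summation can be sketched in a few lines. This is only an illustrative recursive version (numpy's actual implementation uses blocking and unrolling), meant to show why it is more accurate than a naive left-to-right loop:

```python
import numpy as np

def naive_sum32(a):
    # left-to-right accumulation in float32, like a simple jitted loop;
    # rounding error grows roughly with n
    s = np.float32(0.0)
    for v in a:
        s = np.float32(s + v)
    return s

def pairwise_sum32(a, block=128):
    # split the array in halves recursively; rounding error grows
    # only roughly with log(n)
    n = a.shape[0]
    if n <= block:
        return naive_sum32(a)
    mid = n // 2
    return np.float32(pairwise_sum32(a[:mid], block) +
                      pairwise_sum32(a[mid:], block))

np.random.seed(0)
a = np.random.rand(100_000).astype(np.float32)
ref = a.astype(np.float64).sum()  # high-precision reference
err_naive = abs(float(naive_sum32(a)) - ref)
err_pairwise = abs(float(pairwise_sum32(a)) - ref)
# err_pairwise is typically orders of magnitude smaller than err_naive
```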

If we just look at the writing of the values:

@njit
def only_assign(x, y):
    res = y[0]
    for i in range(x.shape[0]):
        x[i] = res

%timeit only_assign(x,y)
# 85.2 ms ± 417 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So we see that writing really is slower than reading. The reason for that (and how it can be fixed) is explained in this great answer: the updating of caches, which numba (rightly?) doesn't bypass.
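The difference between per-element stores and a bulk fill can be illustrated without numba. Here `assign_loop` is a plain-Python stand-in for the jitted loop above; whether `ndarray.fill` actually uses non-temporal stores depends on the numpy build, so this only sketches the two access patterns:

```python
import numpy as np

def assign_loop(x, res):
    # element-wise stores, mirroring the jitted only_assign loop
    for i in range(x.shape[0]):
        x[i] = res

x = np.zeros(1_000)
assign_loop(x, 0.5)  # per-element writes through the cache
x.fill(0.25)         # numpy's bulk fill; a vectorized memset-like path
```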


In a nutshell: while assigning values in numba isn't really slow (even though it could be sped up by using non-temporal memory accesses), the really slow part is the summation (which seems not to use pairwise summation) and is inferior to numpy's version.

ead