
I have a function to calculate the log gamma function that I am decorating with numba.njit.

import numpy as np
from numpy import log
from scipy.special import gammaln
from numba import njit

coefs = np.array([
    57.1562356658629235, -59.5979603554754912,
    14.1360979747417471, -0.491913816097620199,
    .339946499848118887e-4, .465236289270485756e-4,
    -.983744753048795646e-4, .158088703224912494e-3,
    -.210264441724104883e-3, .217439618115212643e-3,
    -.164318106536763890e-3, .844182239838527433e-4,
    -.261908384015814087e-4, .368991826595316234e-5
])

@njit(fastmath=True)
def gammaln_nr(z):
    """Numerical Recipes 6.1"""
    y = z
    tmp = z + 5.24218750000000000
    tmp = (z + 0.5) * log(tmp) - tmp
    ser = np.ones_like(y) * 0.999999999999997092

    n = coefs.shape[0]
    for j in range(n):
        y = y + 1
        ser = ser + coefs[j] / y

    out = tmp + log(2.5066282746310005 * ser / z)
    return out

When I call gammaln_nr on a large array, say np.linspace(0.001, 100, 10**7), it runs about 7x slower than scipy's gammaln (see the code in the appendix below). However, for any individual value my numba function is consistently about 2x faster. How can this be?

z = 11.67
%timeit gammaln_nr(z)
%timeit gammaln(z)
>>> 470 ns ± 29.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
>>> 1.22 µs ± 28.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

My intuition is that if my function is faster for one value, it should be faster for an array of values. Of course, this may not be the case because I don't know whether numba is using SIMD instructions or some other sort of vectorization, whereas scipy may be.

Appendix


import time

import matplotlib.pyplot as plt
import seaborn as sns

n_trials = 8
scipy_times = np.zeros(n_trials)
fastats_times = np.zeros(n_trials)

for i in range(n_trials):
    zs = np.linspace(0.001, 100, 10**i) # evaluate gammaln over this range

    # don't record the first call: on the first iteration it includes JIT compilation
    start = time.time()
    gammaln_nr(zs)
    end = time.time()

    start = time.time()
    gammaln_nr(zs)
    end = time.time()
    fastats_times[i] = end - start

    start = time.time()
    gammaln(zs)
    end = time.time()
    scipy_times[i] = end - start

fig, ax = plt.subplots(figsize=(12,8))
sns.lineplot(x=np.logspace(0, n_trials-1, n_trials), y=fastats_times, label="numba", ax=ax);
sns.lineplot(x=np.logspace(0, n_trials-1, n_trials), y=scipy_times, label="scipy", ax=ax);
ax.set(xscale="log");
ax.set_xlabel("Array Size", fontsize=15);
ax.set_ylabel("Execution Time (s)", fontsize=15);
ax.set_title("Execution Time of Log Gamma");


PyRsquared
  • It is probably the for-loop that is slowing you down. The scipy function is AOT compiled ([source](https://github.com/scipy/scipy/tree/master/scipy/special)) whereas numba uses JIT. Not sure about the efficiency of JIT, but your results imply the former is faster. – cvanelteren Mar 07 '19 at 16:21
  • @GlobalTraveler both good points, thanks. But I time my function twice: the first call triggers compilation, and I record the second timing (once it has been compiled), as advised in the numba docs. So compilation shouldn't be an issue in theory. Still, it's true that my for loop could be slowing me down a lot. – PyRsquared Mar 07 '19 at 16:46
  • Most (and maybe all?) NumPy/SciPy numeric functions convert their input to NumPy arrays before computing anything. When you pass a NumPy array to these functions, the conversion via `np.array` or `np.asarray` or `np.asanyarray` is very quick, but if the input is a scalar or list, then the internal conversion to a NumPy array [can take a significant fraction of the total run time](https://stackoverflow.com/a/3651058/190597) of the function. – unutbu Mar 07 '19 at 17:19
  • Compare, for instance, the time it takes merely to convert 11.67 to a NumPy array, `%timeit np.array(11.67)` (~168 ns), with the time it takes to compute `gammaln(11.67)` using `math`, `%timeit math.log(math.gamma(11.67))` (214 ns), and with `%timeit special.gammaln(11.67)` (~567 ns). So it's perhaps not surprising that `special.gammaln(11.67)` is slow compared with `gammaln_nr(11.67)`, but the scalar benchmark has little bearing on array-based benchmarks. – unutbu Mar 07 '19 at 17:19
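The conversion step unutbu describes can be seen without any timing at all: `np.asarray` hands back an existing float64 array unchanged, while a Python float has to be wrapped in a freshly allocated 0-d array on every call. A minimal sketch (the variable names are only illustrative):

```python
import numpy as np

# An existing float64 array: asarray returns the very same object, no copy is made.
zs = np.linspace(0.001, 100, 10**6)
same = np.asarray(zs)
print(same is zs)

# A Python float: a new 0-d ndarray has to be allocated for every call.
scalar = np.asarray(11.67)
print(type(scalar), scalar.ndim)
```

For large arrays this wrapping cost is paid once and amortized over millions of elements; for a single scalar it is a sizeable share of the whole call.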

1 Answer


Implementing gammaln in Numba

Reimplementing frequently used functions can take considerable work, not only to match the performance but also to reach a well-defined level of precision. The most direct approach is therefore to simply wrap an existing, working implementation.

In the case of gammaln, scipy calls a C implementation of the function. The speed of the scipy implementation therefore also depends on the compiler and compiler flags used when building scipy and its dependencies.
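A quick way to confirm that `scipy.special.gammaln` is compiled code rather than Python is to inspect it: it is a NumPy ufunc, and its `types` attribute lists the compiled inner loops it dispatches to. A small sketch; the exact list of loops may differ between SciPy versions:

```python
import numpy as np
from scipy.special import gammaln

# The Python-level object is only a thin ufunc dispatcher in front of
# compiled inner loops, one per supported dtype signature.
print(isinstance(gammaln, np.ufunc))
print(gammaln.nin, gammaln.nout)  # one input, one output
print(gammaln.types)              # e.g. 'd->d' is the float64 loop
```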

It is also not very surprising that the performance for a single value can differ considerably from the performance on large arrays. For a single value the calling overhead (type conversions, input checking, ...) dominates; as the array grows, the performance of the implementation itself becomes more and more important.
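The split between fixed per-call overhead and per-element cost can be made visible by dividing the best call time by the array length. This is a rough sketch (`per_element_seconds` is a hypothetical helper, and absolute numbers depend on the machine): for tiny arrays the overhead is spread over few elements, so the per-element figure is large, and it falls toward the cost of the compiled loop as `n` grows.

```python
import timeit

import numpy as np
from scipy.special import gammaln


def per_element_seconds(n, number=50, repeat=5):
    """Best-of-`repeat` wall time per element for one gammaln call on n values."""
    z = np.linspace(0.001, 100, n)
    best = min(timeit.repeat(lambda: gammaln(z), number=number, repeat=repeat))
    return best / number / n


for n in (1, 100, 10_000, 100_000):
    print(f"n={n:>7}: {per_element_seconds(n) * 1e9:8.1f} ns per element")
```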

Improving your implementation

  • Write explicit loops. In Numba, vectorized operations are expanded into loops, and Numba then tries to fuse those loops. It is often better to write out and fuse these loops manually.
  • Think about the differences between implementations of basic arithmetic. Python always checks for division by zero and raises an exception in that case, which is costly. Numba follows this behaviour by default (error_model='python'), but you can switch to NumPy-style error handling (error_model='numpy'). In that case a division by zero does not raise: x/0 yields ±inf and 0/0 yields NaN. How NaN, Inf and -0/+0 are handled in further calculations is also influenced by the fastmath flag.

Code

import numpy as np
from numpy import log
from scipy.special import gammaln
from numba import njit
import numba as nb

@njit(fastmath=True,error_model='numpy')
def gammaln_nr(z):
    """Numerical Recipes 6.1"""
    # Avoid global arrays: Numba freezes their values at compile time, so later changes are only seen after recompiling.
    coefs = np.array([
    57.1562356658629235, -59.5979603554754912,
    14.1360979747417471, -0.491913816097620199,
    .339946499848118887e-4, .465236289270485756e-4,
    -.983744753048795646e-4, .158088703224912494e-3,
    -.210264441724104883e-3, .217439618115212643e-3,
    -.164318106536763890e-3, .844182239838527433e-4,
    -.261908384015814087e-4, .368991826595316234e-5])

    out = np.empty(z.shape[0])

    for i in range(z.shape[0]):
        y = z[i]
        tmp = z[i] + 5.24218750000000000
        tmp = (z[i] + 0.5) * np.log(tmp) - tmp
        ser = 0.999999999999997092

        n = coefs.shape[0]
        for j in range(n):
            y = y + 1.
            ser = ser + coefs[j] / y

        out[i] = tmp + np.log(2.5066282746310005 * ser / z[i])
    return out

@njit(fastmath=True,error_model='numpy',parallel=True)
def gammaln_nr_p(z):
    """Numerical Recipes 6.1"""
    # Avoid global arrays: Numba freezes their values at compile time, so later changes are only seen after recompiling.
    coefs = np.array([
    57.1562356658629235, -59.5979603554754912,
    14.1360979747417471, -0.491913816097620199,
    .339946499848118887e-4, .465236289270485756e-4,
    -.983744753048795646e-4, .158088703224912494e-3,
    -.210264441724104883e-3, .217439618115212643e-3,
    -.164318106536763890e-3, .844182239838527433e-4,
    -.261908384015814087e-4, .368991826595316234e-5])

    out = np.empty(z.shape[0])

    for i in nb.prange(z.shape[0]):
        y = z[i]
        tmp = z[i] + 5.24218750000000000
        tmp = (z[i] + 0.5) * np.log(tmp) - tmp
        ser = 0.999999999999997092

        n = coefs.shape[0]
        for j in range(n):
            y = y + 1.
            ser = ser + coefs[j] / y

        out[i] = tmp + np.log(2.5066282746310005 * ser / z[i])
    return out


import matplotlib.pyplot as plt
import seaborn as sns
import time

n_trials = 8
scipy_times = np.zeros(n_trials)
fastats_times = np.zeros(n_trials)
fastats_times_p = np.zeros(n_trials)

for i in range(n_trials):
    zs = np.linspace(0.001, 100, 10**i) # evaluate gammaln over this range

    # don't record the first call: on the first iteration it includes JIT compilation
    start = time.time()
    arr_1 = gammaln_nr(zs)
    end = time.time()

    start = time.time()
    arr_1 = gammaln_nr(zs)
    end = time.time()
    fastats_times[i] = end - start

    # the parallel version is compiled separately; discard its first call too
    arr_3 = gammaln_nr_p(zs)

    start = time.time()
    arr_3 = gammaln_nr_p(zs)
    end = time.time()
    fastats_times_p[i] = end - start

    start = time.time()
    arr_2 = gammaln(zs)
    end = time.time()
    scipy_times[i] = end - start
    print(np.allclose(arr_1,arr_2))
    print(np.allclose(arr_1,arr_3))

fig, ax = plt.subplots(figsize=(12,8))
sns.lineplot(x=np.logspace(0, n_trials-1, n_trials), y=fastats_times, label="numba", ax=ax);
sns.lineplot(x=np.logspace(0, n_trials-1, n_trials), y=fastats_times_p, label="numba_parallel", ax=ax);
sns.lineplot(x=np.logspace(0, n_trials-1, n_trials), y=scipy_times, label="scipy", ax=ax);
ax.set(xscale="log");
ax.set_xlabel("Array Size", fontsize=15);
ax.set_ylabel("Execution Time (s)", fontsize=15);
ax.set_title("Execution Time of Log Gamma");
fig.show()
max9111
  • that's amazing that you got such a speed up by writing explicit loops. Thanks for the explanation and code! I have to get used to writing explicit loops in python - it feels weird because I'm so used to doing everything in numpy! Thanks again – PyRsquared Mar 08 '19 at 10:39
  • Please note that setting error_model='numpy' is also very important. But I really recommend using a well-tested implementation of this function. If you are not happy with wrapping the scipy (Cephes) implementation, you can also write a few lines of modern Fortran (the log-gamma function, LOG_GAMMA, is an intrinsic in modern Fortran), build a DLL with the fastest compiler you can get, and wrap the function for Numba (CFFI or ctypes). Also, I would not recommend using the parallel version. If you have nested loops in which gammaln is called, just parallelize the outer loop if possible. – max9111 Mar 08 '19 at 11:14
  • My only issue now is that your implementation works for arrays and not single values. I can't seem to get type checking `(if type(z) in [int, float]...)` to work with numba so that the `gammaln_nr` function is more generalized. I realize this may come at the expense of reducing the performance slightly. – PyRsquared Mar 08 '19 at 11:44
  • I don't know, but handling nd-arrays differently from scalar values within Numba would be a good question. I guess generated_jit may help. https://numba.pydata.org/numba-doc/dev/user/generated-jit.html – max9111 Mar 08 '19 at 12:24
  • `generated_jit` is exactly what I needed. Much appreciated – PyRsquared Mar 08 '19 at 12:37