
I'm experimenting with Numba and tried to run this code (mapping over a big list):

from numba import njit
from datetime import datetime


big_list = [(i, i + 10000) for i in range(1, 100000000)]

# Just a number of arithmetic operations (using Numba).
@njit(cache=True)
def just_calc_jit(row):
    exp_1 = row[1] / row[0]
    exp_2 = (row[0] + 10000) / row[1]
    exp_3 = (exp_2 - row[0]) / exp_1
    exp_3 *= exp_3
    return exp_3

# Same function without Numba.
def just_calc(row):
    exp_1 = row[1] / row[0]
    exp_2 = (row[0] + 10000) / row[1]
    exp_3 = (exp_2 - row[0]) / exp_1
    exp_3 *= exp_3
    return exp_3

# Print execution times (with and without Numba) 5 times for each function.
for i in range(5):
    start = datetime.now()
    result = list(map(just_calc, big_list))
    execution_time = datetime.now() - start
    print("execution time:", execution_time)

    start = datetime.now()
    result = list(map(just_calc_jit, big_list))
    execution_time = datetime.now() - start
    print("execution time jit:", execution_time)

This is the output of the script (you can see the execution times with and without Numba, five runs each):

execution time: 0:00:17.643550
execution time jit: 0:00:19.780514
execution time: 0:00:19.072673
execution time jit: 0:00:18.961395
execution time: 0:00:20.567786
execution time jit: 0:00:20.119370
execution time: 0:00:21.254276
execution time jit: 0:00:20.034304
execution time: 0:00:20.219750
execution time jit: 0:00:19.237941

What am I missing/doing wrong?

  • You need to put the `for` loop inside your functions and use Numba's `parallel=True`; then you can get a better result. Right now the loop (the `map`) is outside your functions. – I'mahdi Jun 15 '22 at 12:42

1 Answer


I changed two things in your code and got a better run time with Numba (note that the first run of the Numba function is slow because it includes JIT compilation):

  1. Loop over the rows of `big_list` inside the functions.
  2. Pass `big_list` into the functions as a NumPy array (`np.asarray`).

from datetime import datetime
import numba as nb
import numpy as np


big_list = [(i, i + 10_000) for i in range(1, 100_000)]

# Just a number of arithmetic operations (using Numba).
@nb.njit(parallel=True)
def just_calc_jit(arr):
    num_row = len(arr)
    res = np.empty(num_row)
    for i in nb.prange(num_row):
        exp_1 = arr[i, 1] / arr[i, 0]
        exp_2 = (arr[i, 0] + 10000) / arr[i, 1]
        exp_3 = (exp_2 - arr[i, 0]) / exp_1
        exp_3 *= exp_3
        res[i] = exp_3
    return res

# Same function without Numba.
def just_calc(arr):
    num_row = len(arr)
    res = np.empty(num_row)
    for i in range(num_row):
        exp_1 = arr[i, 1] / arr[i, 0]
        exp_2 = (arr[i, 0] + 10000) / arr[i, 1]
        exp_3 = (exp_2 - arr[i, 0]) / exp_1
        exp_3 *= exp_3
        res[i] = exp_3
    return res

# Print execution times (with and without Numba) 5 times for each function.
for i in range(5):
    start = datetime.now()
    result = just_calc(np.asarray(big_list))
    execution_time = datetime.now() - start
    print("execution time:", execution_time)

    start = datetime.now()
    result = just_calc_jit(np.asarray(big_list))
    execution_time = datetime.now() - start
    print("execution time jit:", execution_time)

Output (benchmark on Colab):

execution time: 0:00:00.323675
execution time jit: 0:00:00.639346
execution time: 0:00:00.237574
execution time jit: 0:00:00.046685
execution time: 0:00:00.222264
execution time jit: 0:00:00.048550
execution time: 0:00:00.223323
execution time jit: 0:00:00.049903
execution time: 0:00:00.222570
execution time jit: 0:00:00.049623
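
A note on the first "jit" timing above: the 0.639 s run includes Numba's one-time JIT compilation, so a common benchmarking practice is to warm the function up once before timing it. Here is a minimal sketch of that idea (the warm-up call and the hoisted `np.asarray` conversion are my additions, assuming the `just_calc_jit` defined in this answer):

arr = np.asarray(big_list)   # convert once, outside the timed region

# Warm-up call on a tiny slice: triggers JIT compilation so the
# timings below measure execution only, not compilation.
just_calc_jit(arr[:2])

for i in range(5):
    start = datetime.now()
    result = just_calc_jit(arr)
    print("execution time jit (warmed up):", datetime.now() - start)

Hoisting the conversion also matters: in the benchmark above, `np.asarray(big_list)` runs inside the timed region, so part of each measured time is list-to-array conversion rather than computation.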
  • First of all, thank you for the answer. But I have a problem with that: I want to use Numba to accelerate pyspark, which means I'll be using its `map` or `mapPartitions` functions. So Numba can't help me with that? – idan ahal Jun 15 '22 at 13:44 (see the sketch after this comment thread)
  • And why didn't you use the `cache=True` option? – idan ahal Jun 15 '22 at 13:45
  • 1
    @idanahal,Welcome. A_Q_1: Do you hear about [`JAX`](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html), In JAX, we can use map and get a very good run-time [(here)](https://stackoverflow.com/questions/69099847/jax-vectorization-vmap-and-or-numpy-vectorize). But I don't know about supporting in `pyspark`. A_Q_2: You don't need to use `cache=True` in your problem. You can read about `cache=True` [here](https://stackoverflow.com/questions/59427775/numba-cache-true-has-no-effect#:~:text=The%20point%20of%20using%20cache,significantly%20reduce%20the%20run%2Dtime.) – I'mahdi Jun 15 '22 at 14:01
  • So Numba is relevant only when there is a loop in the function? In pyspark I generally give several workers partitioned data to process using its own `map` function. – idan ahal Jun 15 '22 at 14:03
  • 1
    @idanahal, Yes, In Numba, We can get a better result in for-loop and if we can use parallelism. – I'mahdi Jun 15 '22 at 14:13
  • @I'mahdi Can `fastmath=True` improve the performance? – Ali_Sh Jun 15 '22 at 21:32
  • 1
    @Ali_Sh, thanks for the recommendation. I check this code with your recommendation on [`colab`](https://colab.research.google.com/) and get run_time like in the code and with `%%time` on the cell and `%timeit` on each line and don't see any improvement. – I'mahdi Jun 16 '22 at 07:48
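
Regarding the pyspark follow-up in the comments: a pattern that is often suggested (a sketch only, untested here, assuming pyspark is installed and `just_calc_jit` is defined as in the answer) is to apply the jitted function per partition with `mapPartitions`, so each worker converts its whole partition to a NumPy array and makes one call into compiled code instead of one Python call per row:

import numpy as np
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
rdd = spark.sparkContext.parallelize(big_list, numSlices=8)

def calc_partition(rows):
    # Materialize the partition as one 2-D array and make a single
    # call into the Numba-compiled function for the whole chunk.
    arr = np.asarray(list(rows))
    if len(arr) == 0:  # empty partitions must yield nothing
        return
    yield from just_calc_jit(arr)

result = rdd.mapPartitions(calc_partition).collect()

Two caveats: `parallel=True` inside Spark workers can oversubscribe cores, so a serial `@nb.njit` version of the function may fit this setting better; and whether the Numba dispatcher serializes cleanly to the workers depends on your pyspark/cloudpickle versions, which is part of why this is only a sketch. The `jax.vmap` approach linked in the comments is the map-style alternative.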