I am testing numba performance on some function that takes a numpy
array, and compare:
import numpy as np
from numba import jit, vectorize, float64
import time
from numba.core.errors import NumbaWarning
import warnings
warnings.simplefilter('ignore', category=NumbaWarning)
@jit(nopython=True, boundscheck=False) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
trace = 0.0
for i in range(a.shape[0]): # Numba likes loops
trace += np.tanh(a[i, i]) # Numba likes NumPy functions
return a + trace # Numba likes NumPy broadcasting
class Main(object):
def __init__(self) -> None:
super().__init__()
self.mat = np.arange(100000000, dtype=np.float64).reshape(10000, 10000)
def my_run(self):
st = time.time()
trace = 0.0
for i in range(self.mat.shape[0]):
trace += np.tanh(self.mat[i, i])
res = self.mat + trace
print('Python Diration: ', time.time() - st)
return res
def jit_run(self):
st = time.time()
res = go_fast(self.mat)
print('Jit Diration: ', time.time() - st)
return res
obj = Main()
x1 = obj.my_run()
x2 = obj.jit_run()
The output is:
Python Diration: 0.2164750099182129
Jit Diration: 0.5367801189422607
How can I obtain an enhance version of this example ?