I am using numba and I would like to know if an overflow has occurred when I multiply two integers. Say the integers are positive for simplicity.
I have written the following functions to try to achieve this:
from numba import njit
import numpy as np

@njit
def safe_mul(a, b):
    c = a * b
    print('a: ', a)
    print('b: ', b)
    print('c: ', c)
    print('c//a: ', c//a)
    print('0//a: ', 0//a)
    if c // a != b or c // b != a:
        # do something else or raise error
        raise ValueError()
    return c

@njit
def safe_mul_2(a, b):
    if (np.log2(np.abs(a)) + np.log2(np.abs(b))) >= 63:
        # do something else or raise error
        raise ValueError()
    return a * b

print(safe_mul(2**21, 2**51))
The code prints:
a: 2097152
b: 2251799813685248
c: 0
c//a: 2251799813685248
0//a: 0
0
safe_mul does not catch the overflow when 2**21 and 2**51 are passed in: c is 0, so c//a ought to be 0 as well, yet it prints the value of b. Perhaps numba is compiling out the integer division, since it knows c has just been multiplied by what it is being divided by? I am not sure about this, because when the arguments are not both powers of 2 the error is caught.
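As a sanity check that does not depend on numba, the int64 wraparound can be emulated in plain Python, whose integers never overflow. The helper wrap_int64 below is hypothetical, written only for this illustration. Because 2**21 * 2**51 = 2**72 is a multiple of 2**64, the wrapped product is 0, so evaluating c//a on the wrapped value gives 0 and the check in safe_mul fires. That numba prints b for c//a instead is consistent with the division being simplified away rather than evaluated on the wrapped product.

```python
def wrap_int64(x):
    """Return the value a two's-complement int64 holds after storing the Python int x."""
    x &= (1 << 64) - 1                              # keep only the low 64 bits
    return x - (1 << 64) if x >= (1 << 63) else x   # treat bit 63 as the sign bit

a, b = 2**21, 2**51
c = wrap_int64(a * b)   # true product is 2**72; its low 64 bits are all zero, so c is 0
print('c:    ', c)
print('c//a: ', c // a)                              # 0, not b
print('check fires:', c // a != b or c // b != a)    # True
```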
safe_mul_2 does catch the error and, surprisingly, is not much slower than safe_mul (once the prints are removed). I would like to know what is happening in safe_mul, and whether a check that is faster than safe_mul_2 can be written.
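One candidate, sketched under stated assumptions rather than offered as a benchmarked answer: for positive operands, test for overflow before multiplying by comparing a against INT64_MAX // b. Since a and b are integers, a * b <= INT64_MAX holds exactly when a <= INT64_MAX // b, and because the division never consumes the product there is no (a*b)//a pattern for the compiler to fold. The names INT64_MAX, mul_would_overflow and safe_mul_3 are hypothetical. The sketch falls back to plain Python if numba is not installed; I have not timed it against safe_mul_2.

```python
try:
    from numba import njit
except ImportError:
    def njit(f):          # fallback: run as plain Python when numba is unavailable
        return f

INT64_MAX = 2**63 - 1     # largest value an int64 can hold

@njit
def mul_would_overflow(a, b):
    # Assumes a > 0 and b > 0. One integer division, done before any multiplication.
    return a > INT64_MAX // b

@njit
def safe_mul_3(a, b):
    if mul_would_overflow(a, b):
        # do something else or raise error
        raise ValueError('int64 overflow')
    return a * b
```

With this check, safe_mul_3(2**21, 2**51) raises, since INT64_MAX // 2**51 is 4095, while safe_mul_3(2**31, 2**31) returns 2**62. Negative operands would need extra sign handling that this sketch leaves out.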