
I am using numba and I would like to know if an overflow has occurred when I multiply two integers. Say the integers are positive for simplicity.

I have written the following function to try and achieve this:


from numba import njit
import numpy as np

@njit
def safe_mul(a, b):
    c = a * b

    print('a: ', a)
    print('b: ', b)
    print('c: ', c)
    
    print('c//a: ', c//a)
    print('0//a: ', 0//a)
    
    if c // a != b or c // b != a:
        # do something else or raise error
        raise ValueError()
    return c

@njit
def safe_mul_2(a, b):
    if (np.log2(np.abs(a)) + np.log2(np.abs(b))) >= 63:
        # do something else or raise error
        raise ValueError()

    return a * b

print(safe_mul(2**21, 2**51))

The code prints:


a:  2097152
b:  2251799813685248
c:  0
c//a:  2251799813685248
0//a:  0
0

safe_mul does not catch the overflow when 2**21 and 2**51 are passed in. Perhaps numba is compiling out the integer division since it knows c has just been multiplied by what it is being divided by? I am not sure about this since when you enter something where the arguments are not both powers of 2 then the error is caught.

safe_mul_2 does catch the error, and is surprisingly not much slower than safe_mul (when the prints are removed). I would like to know what is happening in safe_mul and if something that is faster than safe_mul_2 can be written.

Patrick
    I believe signed integer overflow in Numba is outright undefined behavior, like in C: anything can happen, including this. – user2357112 Mar 14 '23 at 01:27

2 Answers


Numba inherits the behaviour of the C language when it comes to integer overflow. Overflow in a signed integer operation produces an undefined result (not merely a platform-dependent one). The JIT compiler (LLVM) is free to optimize the code on the assumption that this never happens, and it does. Mathematically, c//a = (a*b)//a = b, so the compiler simply replaces the computation of c//a with b. In practice c is 0 because of the overflow, so one would expect c//a = 0 as well, but the division is never actually performed.
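One way to sidestep the folding described above is to test the operands before multiplying, so that no overflowing product is ever formed. The answer does not propose this; it is a minimal sketch in plain Python, and `checked_mul_precheck` and `INT64_MAX` are hypothetical names. The body uses only integer comparison, division and multiplication, so it would presumably compile under `@njit` as well, but that has not been verified here, and it still costs one division.

```python
INT64_MAX = (1 << 63) - 1  # largest value a signed 64-bit integer can hold


def checked_mul_precheck(a, b):
    """Multiply two positive integers, refusing products above INT64_MAX.

    Hypothetical helper for illustration, not part of the original post.
    """
    if a <= 0 or b <= 0:
        raise ValueError("this sketch only handles positive operands")
    # For positive integers, a * b > INT64_MAX  <=>  a > INT64_MAX // b.
    # The right-hand side never leaves the 64-bit range, so the test itself
    # cannot overflow.
    if a > INT64_MAX // b:
        raise OverflowError("a * b does not fit in a signed 64-bit integer")
    return a * b
```

With the operands from the question, `checked_mul_precheck(2**21, 2**51)` raises `OverflowError` instead of returning 0.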

One way to avoid this undefined result is to use unsigned integers, whose arithmetic is well defined: unsigned integers do not overflow, they wrap around. Based on this, you can safely detect an overflow using a division, as you did.
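The wrap-around check can be modeled in plain Python by masking the product to 64 bits, which mimics what unsigned 64-bit multiplication produces. This is only a sketch of the logic under the assumption that both operands are non-negative and fit in a signed 64-bit integer; `MASK64`, `wrapping_mul_u64` and `mul_positive_checked` are hypothetical names, and a Numba version would use unsigned 64-bit values instead of the mask.

```python
MASK64 = (1 << 64) - 1     # keeps the low 64 bits, like unsigned 64-bit arithmetic
INT64_MAX = (1 << 63) - 1  # largest signed 64-bit value


def wrapping_mul_u64(a, b):
    """Model of unsigned 64-bit multiplication: the product wraps modulo 2**64."""
    return (a * b) & MASK64


def mul_positive_checked(a, b):
    """Return a * b, or raise if it does not fit in a signed 64-bit integer.

    Assumes 0 <= a, b <= INT64_MAX (hypothetical helper for illustration).
    """
    if a == 0 or b == 0:
        return 0
    c = wrapping_mul_u64(a, b)
    # If the product wrapped, c is smaller than a * b, so c // a cannot equal b.
    if c // a != b:
        raise OverflowError("product wrapped around 2**64")
    # A product in [2**63, 2**64) fits in 64 unsigned bits but not in int64.
    if c > INT64_MAX:
        raise OverflowError("product does not fit in a signed 64-bit integer")
    return c
```

The second test matters for inputs such as `2**21` and `2**42`: their product is exactly `2**63`, which does not wrap as an unsigned value but is one past the signed maximum.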

Note that divisions are quite expensive, even on modern processors, and the wider the integer, the more expensive the operation. If you can, consider using 32-bit integers instead of 64-bit ones. Fairly recent processors such as Intel Cannon Lake/Ice Lake/Alder Lake, AMD Zen 3/Zen 4 and the Apple M1 compute integer divisions significantly faster (at the expense of significantly more transistors), but 64-bit division is still the more expensive one on such processors.

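You can explicitly cast a number to a given type using, for example, np.uint64(a). Thus, you can compute c = np.uint64(a) * np.uint64(b) and then c // np.uint64(a). Note that this only detects wrap-around past 2**64; if the result must also fit in a signed 64-bit integer, additionally check that c does not exceed 2**63 - 1. If you deal with negative numbers, you can compute the sign manually, since isNegative(c) = isNegative(a) xor isNegative(b).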
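The sign rule can be combined with an unsigned check on the magnitudes. The following plain-Python sketch models 64-bit unsigned multiplication with a mask; `checked_mul_signed` is a hypothetical name, and the operands are assumed to fit in a signed 64-bit integer. It illustrates the logic only and is not a tested Numba implementation.

```python
MASK64 = (1 << 64) - 1     # keeps the low 64 bits, like unsigned 64-bit arithmetic
INT64_MAX = (1 << 63) - 1  # largest signed 64-bit value
INT64_MIN_MAG = 1 << 63    # magnitude of the most negative signed 64-bit value


def checked_mul_signed(a, b):
    """Multiply two signed 64-bit values, raising OverflowError on overflow."""
    # The product is negative exactly when one operand is negative (xor).
    negative = (a < 0) != (b < 0)
    ua, ub = abs(a), abs(b)        # magnitudes are at most 2**63
    if ua == 0 or ub == 0:
        return 0
    c = (ua * ub) & MASK64         # wrapped product of the magnitudes
    if c // ua != ub:
        raise OverflowError("magnitude wrapped around 2**64")
    # Negative results may reach 2**63 in magnitude, positive ones 2**63 - 1.
    limit = INT64_MIN_MAG if negative else INT64_MAX
    if c > limit:
        raise OverflowError("product does not fit in a signed 64-bit integer")
    return -c if negative else c
```

The asymmetric limit reflects the two's-complement range: `-2**31 * 2**32` is representable, while `2**31 * 2**32` is not.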

For more information about the undefined result in C, please read Why is unsigned integer overflow defined behavior but signed integer overflow isn't?

Jérôme Richard

On the Numba discourse site, the user sschaer was able to provide a very fast solution. See this: https://numba.discourse.group/t/catch-overflow-in-integer-multiplication/1827 for the original discussion. The solution is copied here with sschaer's permission.

LLVM has integer operations that return an overflow bit, and they are already known to Numba's intermediate representation. The following function exposes the multiplication operations:

import numpy as np
from numba import njit, types
from numba import TypingError
from numba.extending import intrinsic

@intrinsic
def mul_with_overflow(typingctx, a, b):
    if not (isinstance(a, types.Integer) and isinstance(b, types.Integer)):
        raise TypingError("both arguments must be integers")
    if a.signed != b.signed:
        raise TypingError("can only multiply integers of equal signedness")
    
    if a.signed:
        ext = lambda builder, a, b: builder.sext(a, b)
        mul = lambda builder, a, b: builder.smul_with_overflow(a, b)
    else:
        ext = lambda builder, a, b: builder.zext(a, b)
        mul = lambda builder, a, b: builder.umul_with_overflow(a, b)

    retint_ty = max(a, b, key=lambda ty: ty.bitwidth)
    sig = types.Tuple([retint_ty, types.boolean])(a, b)
    
    def codegen(context, builder, signature, args):
        int_ty = context.get_value_type(retint_ty)
        a = ext(builder, args[0], int_ty)
        b = ext(builder, args[1], int_ty)
        prod_and_flag = mul(builder, a, b)
        return prod_and_flag
    
    return sig, codegen

@njit
def foo(a, b):
    return mul_with_overflow(a, b)

foo(np.int8(1), np.int32(-2)) # returns int32
foo(np.uint8(1), np.uint16(2)) # returns uint16
foo(np.int32(2), np.int64(np.iinfo(np.int64).max)) # overflow
foo(np.uint64(1), np.int64(2)) # error
foo(np.uint64(1), np.float32(2)) # error
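To make the returned pair concrete, here is a plain-Python model of what a signed multiply-with-overflow operation yields: the product truncated to the given bit width, plus a flag saying whether truncation changed the value. `smul_with_overflow_model` is a hypothetical name for illustration; it does not call LLVM or Numba.

```python
def smul_with_overflow_model(a, b, bits=64):
    """Model a signed multiply that reports overflow.

    Returns (wrapped_product, overflowed), where wrapped_product is the
    exact product reduced to a two's-complement integer of width `bits`.
    """
    half = 1 << (bits - 1)
    full = a * b  # exact product, Python integers never overflow
    # Shift into the unsigned range, keep the low `bits` bits, shift back.
    wrapped = ((full + half) & ((1 << bits) - 1)) - half
    return wrapped, wrapped != full


print(smul_with_overflow_model(2**21, 2**51))       # (0, True)
print(smul_with_overflow_model(2, 2**63 - 1))       # (-2, True)
print(smul_with_overflow_model(1, -2, bits=32))     # (-2, False)
```

A caller would typically test the flag and raise or fall back to another code path, rather than inspect the wrapped product.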
Patrick