
I have a Python program that uses numba to compile code to native machine code so that it runs faster.

I want to speed it up even further by caching function results: if a function is called twice with the same parameters, the first call should run the calculation and return the result, and the second call should return the result from the cache.

I tried to implement this with a dict, where the keys are tuples containing the function parameters, and the values are the function return values.

However, numba doesn't support regular Python dictionaries, and its support for global variables is limited, so my solution didn't work.

I can't use a numpy.ndarray indexed by the parameters, since some of my parameters are floats.

The problem is that both the function with cached results and the calling function are compiled with numba. If the calling function were a regular Python function, I could implement the cache in plain Python instead of in numba.

How can I implement this result cache with numba?

============================================

The following code fails with an error saying that the `Memoize` class is not recognized:

```python
from __future__ import annotations

from numba import njit


class Memoize:
    def __init__(self, f):
        self.f = f
        self.memo = {}

    def __call__(self, *args):
        if args not in self.memo:
            self.memo[args] = self.f(*args)
        # Warning: You may wish to do a deepcopy here if returning objects
        return self.memo[args]


@Memoize
@njit
def bla(a: int, b: float):
    for i in range(1_000_000_000):
        a *= b
    return a


@njit
def caller(x: int):
    s = 0
    for j in range(x):
        s += bla(j % 5, (j + 1) % 5)
    return s


if __name__ == "__main__":
    print(caller(30))
```

The error:

```
Untyped global name 'bla': Cannot determine Numba type of <class '__main__.Memoize'>
File "try_numba2.py", line 30:
def caller(x: int):
    <source elided>
    for j in range(x):
        s += bla(j % 5, (j + 1) % 5)
        ^
```

Changing the order of the decorators for bla gives the following error:

```
TypeError: The decorated object is not a function (got type <class '__main__.Memoize'>).
```
user107511
  • See https://stackoverflow.com/questions/1988804/what-is-memoization-and-how-can-i-use-it-in-python . The decorator needs to be applied first so that the Numba function does not care about memoization. Besides, globals are constants in Numba, so you would need to pass the cache as a parameter if you want to use Numba, which is not needed here (the benefit of using Numba should be small if the function is expensive; otherwise the benefit of memoization is small anyway). – Jérôme Richard Sep 23 '22 at 08:35
  • I tried applying the `Memoize` decorator to an `njit` function, but I get an error saying that Numba does not recognize the `Memoize` class type. – user107511 Sep 23 '22 at 12:39
  • I have heavy functions that `numba` accelerates well (lots of calculations over vectors), but they still take time, so memoizing could help here. – user107511 Sep 23 '22 at 12:41
  • Additionally, special scipy functions used through `numba_scipy` take time, so even when I call a `numba_scipy` function wrapped with memoization from pure Python (the caller is not compiled with numba), it is accelerated. – user107511 Sep 23 '22 at 12:41
  • This is certainly because the order is not correct, as stated in the previous comment. Alternatively, you might need a pure-Python wrapping function, but Numba already does that under the hood anyway. A decorator is just a function operating on the following function. – Jérôme Richard Sep 23 '22 at 13:05
  • No, if I change the order, the `@njit` decorator gives an error that it does not recognize the `Memoize` class – user107511 Sep 23 '22 at 15:41
  • This is unexpected. Did you try the proposed wrapper-based solution? If so, can you provide a minimal working example for this? – Jérôme Richard Sep 23 '22 at 16:18
  • Added a minimal working example – user107511 Sep 23 '22 at 16:45
  • Ah, you cannot use `@njit` on the caller function together with the `Memoize` class. To do that, a jitclass would be required, but this is an experimental feature and I am not sure it supports this use case (probably not). Removing `@njit` from the caller function fixes the problem. Alternatively, the caller function could check the parameters itself, assuming it is OK for the dictionaries to be cleared on every call to the caller function. – Jérôme Richard Sep 24 '22 at 11:46
  • That's the issue, I want to call `caller` several times but keep the cached values between calls – user107511 Sep 25 '22 at 07:11
