Based on the explanation provided here, I am trying to use the same idea to speed up the following integral:

import numpy as np
import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def integrand(t, *args):
    a = args[0]
    # c is the root of a*x**2 - exp(-x**2/a); it depends only on a, not on t
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    # Integrate func(t, a) over t in [0, 1]
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0])

Following that reference, I tried to use numba's jit and modified the previous block as follows:

import numpy as np
import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def jit_integrand_function(integrand_function):
    jitted_function = numba.jit(integrand_function, nopython=True)

    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1])

    return LowLevelCallable(wrapped.ctypes)


@jit_integrand_function
def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

do_integrate(integrand, 2.)

However, this implementation gives me the following error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'a' in a function that will escape.
File "<ipython-input-16-3d98286a4be7>", line 20:
def integrand(t, *args):
<source elided>
a = args[0]
c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
^
During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)
During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)
The error seems to come from the fact that I am calling fsolve from scipy.optimize, with a lambda that captures a, inside the jitted integrand function.
I would like to know whether there is a workaround for this error, and whether it is possible to use scipy.optimize.fsolve with numba in this context.