I'm trying to use scipy.optimize.fsolve to find the zeros of a function that I've compiled with numba. The function is below. The exact details aren't important, but the gist is that F_curr is a 2D NumPy array that stores the occupancy of some states. The collision integral "I" is the right-hand side of the differential equation dF/dt = I. I want to find steady-state solutions of this differential equation (i.e. F with I(F) = 0), which is what I'm using scipy.optimize.fsolve for.
```python
import numba
import numpy as np
from numba import njit

# phonon_dispersion, bose_einstein and delta1D are jitted helper functions not shown here

@njit
def collision_integral(F_curr, mu, beta, Omega, G_matrix, momenta_arr, m_cutoff, floquet_spectrum_dict, num_bands=2):
    num_k_pts = momenta_arr.shape[0]
    I_curr = np.zeros(F_curr.shape, dtype=numba.double)
    for k1_idx in range(0, num_k_pts):
        for alpha in range(0, num_bands):
            # at this level, we are about to set the value of the collision integral at index k, alpha
            temp = 0
            for k2_idx in range(0, num_k_pts):
                for alpha_p in range(0, num_bands):
                    for n in range(0, 2*m_cutoff + 1):
                        # this is the G term
                        G_term = np.abs(G_matrix[k1_idx, k2_idx, alpha, alpha_p, n])**2
                        # this is to compute the middle curly bracket terms
                        q = momenta_arr[k1_idx] - momenta_arr[k2_idx]
                        omega = phonon_dispersion(q, 1)
                        phonon_occupancy = bose_einstein(omega + 0.01, mu, beta)
                        F_term_1 = F_curr[k2_idx, alpha_p] * (1 - F_curr[k1_idx, alpha]) * phonon_occupancy
                        F_term_1 -= F_curr[k1_idx, alpha] * (1 - F_curr[k2_idx, alpha_p]) * (1 + phonon_occupancy)
                        F_term_2 = F_curr[k2_idx, alpha_p] * (1 - F_curr[k1_idx, alpha]) * (1 + phonon_occupancy)
                        F_term_2 -= F_curr[k1_idx, alpha] * (1 - F_curr[k2_idx, alpha_p]) * phonon_occupancy
                        # this is the last delta function term to enforce energy conservation
                        # because of how discretized my k points are, I will use a larger broadening parameter
                        delta_1 = delta1D(floquet_spectrum_dict[k1_idx, alpha] - floquet_spectrum_dict[k2_idx, alpha_p]
                                          - omega - (n - m_cutoff) * Omega, eta=0.02)
                        delta_2 = delta1D(floquet_spectrum_dict[k2_idx, alpha_p] - floquet_spectrum_dict[k1_idx, alpha]
                                          - omega + (n - m_cutoff) * Omega, eta=0.02)
                        temp += G_term*(F_term_1 * delta_1 + F_term_2 * delta_2)
            # update the collision integral once the triple sum is performed
            I_curr[k1_idx, alpha] = np.real(temp)
            #print(f"the point {k1_idx, alpha} has been updated")
    return I_curr
```
This function runs fine on its own if I initialize some parameters and call it directly.
However, when I try to run fsolve to find zeros of the collision integral, which happens in this code block:
```python
import scipy.optimize

def update_procedure(F_init, args):
    F_final = scipy.optimize.fsolve(collision_integral, F_init, args=args)
    return F_final
```
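I get the following error (this traceback comes from the split version of the function described further down, in which the inner loops live in a separate jitted function, collision_integral_entry):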
```text
TypingError Traceback (most recent call last)
Input In [38], in <cell line: 9>()
5 F_init[:,0] = 1
6 args = (mu, beta, Omega, G_matrix, momenta_arr, m_cutoff, floquet_spectrum_dict, 2)
----> 9 F_final = update_procedure(F_init, args)
Input In [37], in update_procedure(F_init, args)
1 def update_procedure(F_init, args):
2
3 #F_final, infodict, ier, mesg = scipy.optimize.fsolve(collision_integral, F_init, args)
5 print("got here")
----> 7 F_final = scipy.optimize.fsolve(collision_integral, F_init, args = args)
9 return F_final
File ~\anaconda3\lib\site-packages\scipy\optimize\minpack.py:160, in fsolve(func, x0, args, fprime, full_output, col_deriv, xtol, maxfev, band, epsfcn, factor, diag)
49 """
50 Find the roots of a function.
51
(...)
150
151 """
152 options = {'col_deriv': col_deriv,
153 'xtol': xtol,
154 'maxfev': maxfev,
(...)
157 'factor': factor,
158 'diag': diag}
--> 160 res = _root_hybr(func, x0, args, jac=fprime, **options)
161 if full_output:
162 x = res['x']
File ~\anaconda3\lib\site-packages\scipy\optimize\minpack.py:226, in _root_hybr(func, x0, args, jac, col_deriv, xtol, maxfev, band, eps, factor, diag, **unknown_options)
224 if not isinstance(args, tuple):
225 args = (args,)
--> 226 shape, dtype = _check_func('fsolve', 'func', func, x0, args, n, (n,))
227 if epsfcn is None:
228 epsfcn = finfo(dtype).eps
File ~\anaconda3\lib\site-packages\scipy\optimize\minpack.py:24, in _check_func(checker, argname, thefunc, x0, args, numinputs, output_shape)
22 def _check_func(checker, argname, thefunc, x0, args, numinputs,
23 output_shape=None):
---> 24 res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
25 if (output_shape is not None) and (shape(res) != output_shape):
26 if (output_shape[0] != 1):
Input In [35], in collision_integral(F_curr, mu, beta, Omega, G_matrix, momenta_arr, m_cutoff, floquet_spectrum_dict, num_bands)
6 for k1_idx in range(0, num_k_pts):
7 for alpha in range(0, num_bands):
8 #print(f"starting {k1_idx}, {alpha}")
9 # at this level, we are about to set the value of the collision integral at index k, alpha
---> 10 temp = collision_integral_entry(k1_idx, alpha, F_curr,
11 mu, beta, Omega, G_matrix, momenta_arr, m_cutoff, floquet_spectrum_dict,
12 num_bands)
14 # update the collision integral once the triple sum is performed
15 I_coll[k1_idx, alpha] = np.real(temp)
File ~\anaconda3\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File ~\anaconda3\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
>>> getitem(array(float64, 1d, C), UniTuple(int64 x 2))
There are 22 candidate implementations:
- Of which 20 did not match due to:
Overload of function 'getitem': File: <numerous>: Line N/A.
With argument(s): '(array(float64, 1d, C), UniTuple(int64 x 2))':
No match.
- Of which 2 did not match due to:
Overload in function 'GetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 166.
With argument(s): '(array(float64, 1d, C), UniTuple(int64 x 2))':
Rejected as the implementation raised a specific error:
NumbaTypeError: cannot index array(float64, 1d, C) with 2 indices: UniTuple(int64 x 2)
raised from C:\Users\Brandon\anaconda3\lib\site-packages\numba\core\typing\arraydecl.py:88
During: typing of intrinsic-call at C:\Users\Brandon\AppData\Local\Temp\ipykernel_13188\2614269881.py (19)
File "..\..\AppData\Local\Temp\ipykernel_13188\2614269881.py", line 19:
<source missing, REPL/exec in use?>
```
This looks as though there's an issue with accessing or setting an array entry inside the numba-compiled code.
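One way to check whether the array that reaches `collision_integral` still has the shape of `F_init` is to record the shape of whatever `fsolve` passes in. The sketch below uses a hypothetical toy residual, `record_shape`, and a 4x2 guess standing in for `F_init` (both assumptions, not part of the original code). The traceback hints at the same thing: `_check_func` calls the function with `x0[:numinputs]` and expects an output of shape `(n,)`, which suggests `fsolve` works on a flattened 1D vector. That would explain why numba sees `array(float64, 1d, C)`.

```python
import numpy as np
from scipy.optimize import fsolve

seen_shapes = []  # shapes of every array fsolve passes to the residual function


def record_shape(x):
    # Hypothetical toy residual: log the input shape; the root is x = 0.5 everywhere.
    seen_shapes.append(x.shape)
    return x - 0.5


# 2D initial guess built like F_init in the question (4 k-points x 2 bands assumed)
F_init = np.zeros((4, 2))
F_init[:, 0] = 1

sol = fsolve(record_shape, F_init)
print(seen_shapes[0])  # (8,): the 2D guess arrives as a flat vector
print(sol.shape)
```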
One important point about the function: I benchmarked my code with and without numba, and numba gives roughly a 50x speedup, so I don't think un-jitting my code is an option without a serious rewrite.
The first thing I tried was running a similar notebook with a jitted function passed to scipy.optimize.fsolve, and this ran fine, so numba-compiled functions inside fsolve don't seem to be the problem on their own:
```python
import numpy as np
import scipy.optimize
from numba import njit

@njit
def my_func(a):
    return np.vdot(a, np.ones(len(a))) * a - np.ones(len(a))

def solver(initial_guess):
    return scipy.optimize.fsolve(my_func, initial_guess)

initial_guess = np.ones(10)
x = solver(initial_guess)
print(x)
```
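One difference between this test and the real code is that `my_func` only ever receives and returns 1D arrays, whereas `collision_integral` indexes `F_curr` with two indices and returns a 2D array. If `fsolve` does hand over a flat vector, a possible workaround is a thin plain-Python wrapper that reshapes the vector before calling the 2D function and flattens the result afterwards. In this sketch, `toy_rhs` is a hypothetical pure-NumPy stand-in for `collision_integral` (simple relaxation towards 0.25 at an assumed rate), and `flat_wrapper` is likewise a made-up name:

```python
import numpy as np
from scipy.optimize import fsolve


def toy_rhs(F, rate):
    # Hypothetical stand-in for collision_integral: it indexes F with two
    # indices, so it only works when F is 2D.
    out = np.zeros(F.shape)
    for k in range(F.shape[0]):
        for a in range(F.shape[1]):
            out[k, a] = rate * (0.25 - F[k, a])
    return out


def flat_wrapper(F_flat, shape, *args):
    # fsolve supplies a 1D vector: restore the 2D shape for the model,
    # then flatten the 2D result back into the 1D residual fsolve expects.
    F = F_flat.reshape(shape)
    return toy_rhs(F, *args).ravel()


F_init = np.zeros((4, 2))
F_init[:, 0] = 1
F_flat = fsolve(flat_wrapper, F_init.ravel(), args=(F_init.shape, 2.0))
F_final = F_flat.reshape(F_init.shape)  # back to (k-points, bands)
print(F_final)
```

With the question's code, the call would presumably become `fsolve(flat_wrapper, F_init.ravel(), args=(F_init.shape,) + args)`, with `collision_integral` called inside the wrapper instead of `toy_rhs`. Only the reshape and ravel run in plain Python, so the jitted function itself is unchanged and most of the numba speedup should be kept.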
Then I wondered whether scipy might be changing the datatypes, causing an error when setting the array entry in the line `I_curr[k1_idx, alpha] = np.real(temp)`.
I suspected this because I'd run into errors before, for example when initializing np.zeros() in a numba function, which were resolved by adding something like `dtype=numba.double`. I tried using a wrapper function to fix the datatypes with something like `.astype(numba.double)`, but that doesn't seem to work either.
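Casting only changes the element type, not the number of dimensions, so a dtype-fixing wrapper cannot turn a flattened vector back into a 2D array. A small illustration, with `np.float64` standing in for `numba.double` and a flattened 4x2 array standing in for `F_curr` (both assumptions for the example):

```python
import numpy as np

F_flat = np.zeros((4, 2)).ravel()   # 1D vector of 8 entries, as fsolve appears to supply it
F_cast = F_flat.astype(np.float64)  # dtype "fix" analogous to .astype(numba.double)
print(F_cast.dtype, F_cast.ndim)    # float64 1: still one-dimensional
```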
I then thought it might be an issue with setting array entries, so I split the numba function into two parts: the assignment `I_curr[k1_idx, alpha] = np.real(temp)` was in a plain Python function, and only the innermost three for loops were jitted. With no array assignment happening inside a numba function, the issue still wasn't resolved. Could it be a problem with accessing array entries instead?
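The numba message `cannot index array(float64, 1d, C) with 2 indices` has a direct plain-NumPy analogue: a 1D array cannot be indexed with two indices, which is exactly what an access like `F_curr[k2_idx, alpha_p]` does. NumPy raises this at run time, while numba rejects it earlier, when it types the function. A minimal sketch, assuming a 4x2 `F_curr` that has been flattened to 8 entries:

```python
import numpy as np

F_flat = np.zeros((4, 2)).ravel()  # flattened stand-in for F_curr, shape (8,)

raised = False
try:
    F_flat[1, 0]  # two indices into a 1D array, like F_curr[k2_idx, alpha_p]
except IndexError as err:
    raised = True
    print("IndexError:", err)

F_2d = F_flat.reshape(4, 2)  # restoring the 2D shape makes the same access valid
print(F_2d[1, 0])
```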
At this point I'm really not sure what to do or where the exact issue is arising. Any help would be greatly appreciated!