I am trying to use scipy.integrate.ode or scipy.integrate.odeint to solve a system of ODEs for a large set (more than a thousand) of initial conditions. Looping over the initial conditions is extremely slow, and scipy does not seem to provide an option for passing a 2D array (a stack of 1D initial-condition arrays). The vectorized option of scipy.integrate.solve_ivp does not appear to mean that it accepts a 2D array of initial conditions either (https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html).
I have read a thread asking a similar question (Vectorized SciPy ode solver); one of the answers suggests scipy.integrate.odeint, but it doesn't seem to accept multidimensional arrays either, so I don't understand how to implement this at all. Are there any ways to speed up the process? Other than vectorization, I have thought about parallel computing techniques, but I am not familiar with them, and I suspect they wouldn't speed up the program as significantly as vectorization would.
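To make the setup concrete, here is a minimal sketch of the looped approach the question describes, using a hypothetical two-state system and made-up parameter ranges. The Python-level loop of one `solve_ivp` call per initial condition is what dominates the runtime for large ensembles:

```python
import numpy as np
from scipy.integrate import solve_ivp

# A hypothetical two-state system, just to illustrate the looped approach.
def f(t, y):
    return [y[0] - y[0] * y[1], y[0] * y[1] - y[1]]

t_eval = np.linspace(0.0, 20.0, 201)
rng = np.random.default_rng(0)
u0_all = rng.uniform(0.5, 1.5, size=(50, 2))  # 50 initial conditions

# One solver call per initial condition -- this Python-level loop is
# the bottleneck once the ensemble grows to thousands of members.
sols = [solve_ivp(f, (0.0, 20.0), u0, t_eval=t_eval).y for u0 in u0_all]
```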
Asked by Sato
- Why would 2D initial conditions be an improvement? Is the ODE function as fast as it could be? That's where you have the most control over evaluation speed. – hpaulj Aug 29 '19 at 11:34
- For large-scale production the scipy codes are rather sub-optimal. Use the bindings to Sundials or DifferentialEquations.jl (there is some overlap) to get access to the most recent and tested implementations. Use JIT or directly a compiled language for the ODE functions. Note that the parallel solution might influence the accuracy of single solutions, especially if distinct parts of the state space are visited by the solution ensemble. – Lutz Lehmann Aug 29 '19 at 12:13
- You don't need a 2D array function: you can formulate the ODE function using a 1D array input and then reshape it in the first line of the function code. As hpaulj says, you can increase speed by working on the ODE function. Other things you can try: loosen the convergence criteria, use an explicit integrator like `RK23` or `RK45` with a minimum step, or pass the Jacobian. – Tarifazo Aug 29 '19 at 12:15
- As for the "parallel computing" bit, see https://stackoverflow.com/questions/4682429/parfor-for-python (assuming there is no cross-thread comms involved). – Leporello Aug 29 '19 at 14:26
- Not quite as convenient as a single function call that can accept multiple initial values, but multiprocessing is an option: https://stackoverflow.com/questions/34291639/multiple-scipy-integrate-ode-instances – Warren Weckesser Aug 29 '19 at 15:39
- If anyone is able to give an answer this would be great! – falematte Mar 03 '21 at 10:15
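Tarifazo's reshape suggestion from the comments can be sketched like this: stack all copies of the system into one flat state vector, reshape it inside the right-hand side, and integrate everything in a single `solve_ivp` call. This is only a sketch with a hypothetical predator-prey system, and note (per Lutz Lehmann's comment) that it couples the solver's error control and step size across the whole ensemble:

```python
import numpy as np
from scipy.integrate import solve_ivp

n = 100  # number of initial conditions

def f_batched(t, y):
    # Reshape the flat state into (2, n): one column per initial condition.
    u = y.reshape(2, n)
    du = np.empty_like(u)
    du[0] = u[0] - u[0] * u[1]
    du[1] = u[0] * u[1] - u[1]
    return du.ravel()

rng = np.random.default_rng(0)
y0 = rng.uniform(0.5, 1.5, size=(2, n)).ravel()  # all ICs in one flat vector
sol = solve_ivp(f_batched, (0.0, 20.0), y0, t_eval=np.linspace(0.0, 20.0, 201))
u = sol.y.reshape(2, n, -1)  # axes: (state, initial condition, time)
```

Because the solver sees one large system, a stiff or fast-moving member of the ensemble forces small steps for every member, which is the accuracy/step-size coupling the comments warn about.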
1 Answer
Here are two examples of solving ODEs for a large set of initial conditions in parallel with Python. The fastest approach uses Numba multi-threading with the NumbaLSODA ODE integrator: Numba compiles all of the code, so the loops are very fast. This example takes 0.175 seconds.
from NumbaLSODA import lsoda_sig, lsoda
from matplotlib import pyplot as plt
import numpy as np
import numba as nb

@nb.cfunc(lsoda_sig)
def f(t, u, du, p):
    # Predator-prey right-hand side, compiled to a C callback for lsoda
    du[0] = u[0] - u[0]*u[1]
    du[1] = u[0]*u[1] - u[1]
funcptr = f.address

t_eval = np.linspace(0.0, 20.0, 201)
np.random.seed(0)

@nb.njit(parallel=True)
def main(n):
    u1 = np.empty((n, len(t_eval)), np.float64)
    u2 = np.empty((n, len(t_eval)), np.float64)
    # Each thread integrates an independent initial condition
    for i in nb.prange(n):
        u0 = np.empty((2,), np.float64)
        u0[0] = np.random.uniform(4.5, 5.5)
        u0[1] = np.random.uniform(0.7, 0.9)
        usol, success = lsoda(funcptr, u0, t_eval, rtol=1e-8, atol=1e-8)
        u1[i] = usol[:, 0]
        u2[i] = usol[:, 1]
    return u1, u2

u1, u2 = main(10000)

# Plot the median and 95% interval across the ensemble
plt.rcParams.update({'font.size': 15})
fig, ax = plt.subplots(1, 1, figsize=[7, 5])
low, med, high = np.quantile(u1, (.025, .5, .975), axis=0)
ax.plot(t_eval, med)
ax.fill_between(t_eval, low, high, alpha=0.3)
low, med, high = np.quantile(u2, (.025, .5, .975), axis=0)
ax.plot(t_eval, med)
ax.fill_between(t_eval, low, high, alpha=0.3)
plt.show()
The second example uses multiprocessing and scipy.integrate.odeint. It takes 2.9 seconds (16x slower than NumbaLSODA).
from scipy.integrate import odeint
import numba as nb
import numpy as np
from matplotlib import pyplot as plt
from pathos.multiprocessing import ProcessingPool as Pool

@nb.njit
def f_sp(u, t):
    # Same predator-prey right-hand side, in odeint's (u, t) convention
    return np.array([u[0] - u[0]*u[1], u[0]*u[1] - u[1]])

t_eval = np.linspace(0.0, 20.0, 201)

def main(u0):
    usol = odeint(f_sp, u0, t_eval, rtol=1e-8, atol=1e-8)
    return usol[:, 0], usol[:, 1]

# Draw random initial conditions and map them over a pool of 6 processes
n = 10000
u0_1 = np.random.uniform(4.5, 5.5, n).reshape((n, 1))
u0_2 = np.random.uniform(0.7, 0.9, n).reshape((n, 1))
u0_all = np.append(u0_1, u0_2, axis=1)
p = Pool(6)
sol = p.map(main, u0_all)

# Collect the per-process results into arrays
u1 = np.empty((n, len(t_eval)), np.float64)
u2 = np.empty((n, len(t_eval)), np.float64)
for i in range(n):
    u1[i] = sol[i][0]
    u2[i] = sol[i][1]

# Plot the median and 95% interval across the ensemble
plt.rcParams.update({'font.size': 15})
fig, ax = plt.subplots(1, 1, figsize=[7, 5])
low, med, high = np.quantile(u1, (.025, .5, .975), axis=0)
ax.plot(t_eval, med)
ax.fill_between(t_eval, low, high, alpha=0.3)
low, med, high = np.quantile(u2, (.025, .5, .975), axis=0)
ax.plot(t_eval, med)
ax.fill_between(t_eval, low, high, alpha=0.3)
plt.show()

Answered by nicholaswogan