One option is to use multiprocessing (i.e., use processes instead of threads). Here's an example that uses the map function of the multiprocessing.Pool class.
The function solve takes a set of initial conditions and returns a solution generated by odeint. The "serial" version of the code in the main section calls solve repeatedly, once for each set of initial conditions in ics. The "multiprocessing" version uses the map function of a multiprocessing.Pool instance to run several processes simultaneously, each calling solve. The map function takes care of doling out the arguments to solve.
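To make the "doling out" a bit more concrete, here is a rough mental model, not the actual implementation of Pool.map: the argument list is cut into chunks, each chunk is handed to a worker, and the results come back in input order. Pool.map does accept an optional chunksize argument; the helper names split_into_chunks and map_in_chunks are made up for this illustration, and the loop below processes the chunks one after another in a single process rather than in parallel.

```python
def split_into_chunks(items, chunksize):
    """Cut items into consecutive lists of at most chunksize elements."""
    items = list(items)
    return [items[i:i + chunksize] for i in range(0, len(items), chunksize)]


def map_in_chunks(func, items, chunksize):
    """Serial stand-in for chunked dispatch: same results, same order.

    A real Pool would hand each chunk to a worker process; here the
    chunks are simply processed one after another.
    """
    results = []
    for chunk in split_into_chunks(items, chunksize):
        results.extend(func(arg) for arg in chunk)
    return results


if __name__ == "__main__":
    print(split_into_chunks(range(10), 4))
    # → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(map_in_chunks(abs, [-3, 2, -1], 2))
    # → [3, 2, 1]
```

The point of the model is only that the caller never assigns work to a particular process; it passes the whole list and gets an ordered list back.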
My computer has four cores, and as I increase num_processes, the speedup maxes out at about 3.6.
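Since the speedup levels off near the number of cores, one reasonable starting point for num_processes is the core count reported by the operating system. The sketch below assumes that; choose_num_processes is a hypothetical helper, not part of multiprocessing, and the best value for a given machine is still worth measuring.

```python
import multiprocessing as mp


def choose_num_processes(num_tasks, max_processes=None):
    """Pick a pool size: no more workers than tasks or (by default) cores.

    Hypothetical helper for illustration. mp.cpu_count() reports logical
    cores, which may exceed the number of physical cores.
    """
    if max_processes is None:
        max_processes = mp.cpu_count()
    return max(1, min(num_tasks, max_processes))


if __name__ == "__main__":
    # With 100 initial conditions this is just the core count.
    print("num_processes =", choose_num_processes(100))
```

The returned value can be passed directly to mp.Pool.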
from __future__ import division, print_function

import sys
import time
import multiprocessing as mp

import numpy as np
from scipy.integrate import odeint


def lorenz(q, t, sigma, rho, beta):
    # Right-hand side of the Lorenz system.
    x, y, z = q
    return [sigma*(y - x), x*(rho - z) - y, x*y - beta*z]


def solve(ic):
    # Integrate the Lorenz system starting from the initial condition ic.
    t = np.linspace(0, 200, 801)
    sigma = 10.0
    rho = 28.0
    beta = 8/3
    sol = odeint(lorenz, ic, t, args=(sigma, rho, beta), rtol=1e-10, atol=1e-12)
    return sol


if __name__ == "__main__":
    # 100 random initial conditions, one per row.
    ics = np.random.randn(100, 3)

    print("multiprocessing:", end='')
    sys.stdout.flush()
    tstart = time.time()
    num_processes = 5
    p = mp.Pool(num_processes)
    mp_solutions = p.map(solve, ics)
    tend = time.time()
    tmp = tend - tstart
    print(" %8.3f seconds" % tmp)
    # Shut the worker processes down before timing the serial run.
    p.close()
    p.join()

    print("serial: ", end='')
    sys.stdout.flush()
    tstart = time.time()
    serial_solutions = [solve(ic) for ic in ics]
    tend = time.time()
    tserial = tend - tstart
    print(" %8.3f seconds" % tserial)

    print("num_processes = %i, speedup = %.2f" % (num_processes, tserial/tmp))

    # The two runs perform the same computations, so the results should match exactly.
    check = [(sol1 == sol2).all()
             for sol1, sol2 in zip(serial_solutions, mp_solutions)]
    if not all(check):
        print("There was at least one discrepancy in the solutions.")
On my computer, the output is:
multiprocessing: 6.904 seconds
serial: 24.756 seconds
num_processes = 5, speedup = 3.59
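As a quick check of the arithmetic, the reported speedup is just the ratio of the two timings, and dividing it by the four cores mentioned earlier gives a rough parallel efficiency. The function names here are illustrative, and using the core count (rather than num_processes) as the denominator is an assumption.

```python
def speedup(t_serial, t_parallel):
    """Ratio of serial to parallel wall-clock time."""
    return t_serial / t_parallel


def efficiency(t_serial, t_parallel, num_cores):
    """Speedup per core; 1.0 would be perfect scaling."""
    return speedup(t_serial, t_parallel) / num_cores


if __name__ == "__main__":
    t_serial, t_mp = 24.756, 6.904  # timings from the output above
    print("speedup = %.2f" % speedup(t_serial, t_mp))           # → speedup = 3.59
    print("efficiency = %.2f" % efficiency(t_serial, t_mp, 4))  # → efficiency = 0.90
```

So on this machine roughly 90% of the ideal four-fold gain is realized; the rest is presumably lost to process startup, data transfer between processes, and uneven task lengths.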