I'm trying to take advantage of multiprocessing in Python, so I ran some tests and found that the multiprocessing code runs much slower than the plain version. What am I doing wrong?
Here is the test script:
    import numpy as np
    from datetime import datetime
    from multiprocessing import Pool

    def some_func(argv):
        x = argv[0]
        y = argv[1]
        return np.sum(x * y)

    def other_func(argv):
        x = argv[0]
        y = argv[1]
        f1 = np.fft.rfft(x)
        f2 = np.fft.rfft(y)
        CC = np.fft.irfft(f1 * np.conj(f2))
        return CC

    N = 20000
    X = np.random.randint(0, 10, size=(N, N))
    Y = np.random.randint(0, 10, size=(N, N))

    output_check = np.zeros(N)
    D1 = datetime.now()
    for k in range(len(X)):
        output_check[k] = np.max(some_func((X[k], Y[k])))
    print('Plain: ', datetime.now() - D1)

    output = np.zeros(N)
    D1 = datetime.now()
    with Pool(10) as pool:  # CPUs
        for ind, res in enumerate(pool.imap(some_func, zip(X, Y), chunksize=1)):
            output[ind] = np.max(res)
        pool.close()
        pool.join()
    print('Pool: ', datetime.now() - D1)
Output:
Plain: 0:00:00.904062
Pool: 0:00:15.386251
Why is the difference so big? What is consuming the time?
I have 80 CPUs available and have tried different pool sizes and chunksizes...
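Roughly, what I varied looks like this (just a sketch of the loop I used; N, X, Y and some_func are from the script above, and the exact worker/chunksize combinations differed):

    from datetime import datetime
    from multiprocessing import Pool
    import numpy as np

    # N, X, Y and some_func as defined in the test script above;
    # the (workers, chunksize) pairs below are only examples of what I tried
    for n_workers, chunk in [(10, 1), (40, 50), (80, 250)]:
        output = np.zeros(N)
        D1 = datetime.now()
        with Pool(n_workers) as pool:
            for ind, res in enumerate(pool.imap(some_func, zip(X, Y), chunksize=chunk)):
                output[ind] = np.max(res)
        print(n_workers, chunk, datetime.now() - D1)

None of the combinations made the pooled version faster than the plain loop.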
The actual function is more complex (similar to other_func above); with it the plain and parallel versions take almost the same time, so still no speed-up :(
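For reference, I timed other_func with the same harness, roughly like this (sketch; other_func, X, Y and N are from the script above):

    from datetime import datetime
    from multiprocessing import Pool
    import numpy as np

    # other_func, X, Y and N as defined in the test script above
    output_check = np.zeros(N)
    D1 = datetime.now()
    for k in range(len(X)):
        output_check[k] = np.max(other_func((X[k], Y[k])))
    print('Plain: ', datetime.now() - D1)

    output = np.zeros(N)
    D1 = datetime.now()
    with Pool(10) as pool:
        for ind, res in enumerate(pool.imap(other_func, zip(X, Y), chunksize=1)):
            output[ind] = np.max(res)
    print('Pool: ', datetime.now() - D1)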
The input is a BIG 3D numpy array, and I need pairwise convolutions of its elements.
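To give an idea of the structure (purely hypothetical shapes and names, the real array is much bigger; xcorr here is the same FFT-based operation as other_func):

    import numpy as np

    # hypothetical shapes, only to illustrate the structure of the real problem
    M, K, L = 50, 100, 4096
    data = np.random.randn(M, K, L)

    def xcorr(pair):
        # same FFT-based circular cross-correlation as other_func above
        x, y = pair
        return np.fft.irfft(np.fft.rfft(x) * np.conj(np.fft.rfft(y)))

    # pairwise over the traces of one block (simplified)
    result = np.zeros((K, K))
    for i in range(K):
        for j in range(K):
            result[i, j] = np.max(xcorr((data[0, i], data[0, j])))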