I have written a Python function that computes pairwise electromagnetic interactions between a largish number (N ~ 10^3) of particles and stores the resulting 3x3 blocks in a 3Nx3N complex128 ndarray. It runs, but it is the slowest part of a larger program, taking about 40 seconds when N=900. The original code looks like this:
import numpy as np

def interaction(s, alpha, kprop):
    # s is an Nx3 real array
    # alpha is complex
    # kprop is float
    ndipoles = s.shape[0]
    Amat = np.zeros((ndipoles, 3, ndipoles, 3), dtype=np.complex128)
    I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    im = complex(0,1)
    k2 = kprop*kprop
    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I
            Amat[i,:,j,:] = nxn
    return(Amat.reshape((3*ndipoles,3*ndipoles)))
I had never previously used Cython, but that seemed like a good place to start in my effort to speed things up, so I pretty much blindly adapted the techniques I found in online tutorials. I got some speedup (30 seconds vs. 40 seconds), but not nearly as dramatic as I expected, so I'm wondering whether I'm doing something wrong or am missing a critical step. The following is my best attempt at cythonizing the above routine:
import numpy as np
cimport numpy as np

DTYPE = np.complex128
ctypedef np.complex128_t DTYPE_t

def interaction(np.ndarray s, DTYPE_t alpha, float kprop):
    cdef float k2 = kprop*kprop
    cdef int i,j
    cdef np.ndarray xi, xj, dx, n, nxn
    cdef float R, kR, kR2
    cdef DTYPE_t A
    cdef int ndipoles = s.shape[0]
    cdef np.ndarray Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=DTYPE)
    cdef np.ndarray I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    cdef DTYPE_t im = complex(0,1)
    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I
            Amat[i,:,j,:] = nxn
    return(Amat.reshape((3*ndipoles,3*ndipoles)))
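For comparison, here is a minimal sketch of how the same matrix might be built in plain NumPy with broadcasting instead of explicit loops. It has not been benchmarked against either version above, so no speedup is claimed. The name `interaction_vectorized` is hypothetical, and the sketch assumes (as the loop versions implicitly do) that all particle positions are distinct and that `kprop` is nonzero. It also holds an intermediate array of shape (N, N, 3, 3), so it trades memory for fewer Python-level operations.

```python
import numpy as np

def interaction_vectorized(s, alpha, kprop):
    """Broadcasting sketch of the double loop (hypothetical name).

    s is an Nx3 real array, alpha is complex, kprop is float.
    Assumes all rows of s are distinct and kprop != 0.
    """
    s = np.asarray(s, dtype=float)
    ndipoles = s.shape[0]
    k2 = kprop * kprop
    eye3 = np.eye(3)

    # dx[i, j, :] = s[i] - s[j] for every pair at once
    dx = s[:, None, :] - s[None, :, :]
    R = np.sqrt(np.einsum('ijk,ijk->ij', dx, dx))

    # The diagonal (i == j) has R == 0; use a dummy value of 1 there to
    # avoid 0/0. Those blocks are overwritten with the identity below.
    np.fill_diagonal(R, 1.0)

    n = dx / R[:, :, None]
    kR = kprop * R
    A = 1.0 / (kR * kR) - 1j / kR

    # Outer product n (x) n for every pair: shape (N, N, 3, 3)
    nxn = n[:, :, :, None] * n[:, :, None, :]
    blocks = ((3 * A - 1)[:, :, None, None] * nxn
              + (1 - A)[:, :, None, None] * eye3)
    blocks *= (-alpha * k2 * np.exp(1j * kR) / R)[:, :, None, None]

    # Diagonal blocks are the 3x3 identity, as in the else branch
    idx = np.arange(ndipoles)
    blocks[idx, idx] = eye3

    # Reorder axes (i, j, a, b) -> (i, a, j, b) to match Amat[i,:,j,:]
    return blocks.transpose(0, 2, 1, 3).reshape(3 * ndipoles, 3 * ndipoles)
```

A quick way to check a sketch like this is to run both it and the original loop version on a small random `s` (say N=6) and compare the results with `np.allclose`.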