1

I'm writing a script that tracks the shifts of a sample by estimating the displacement of an ensemble of particles. The first implementation, in Python, works, but it is too slow for a large number of samples. To address this, I tried rewriting the method in Cython, but as this is my first time using it, I can't get any performance improvement. I know 3D FFTs exist and are often faster than looped 2D FFTs, but in this case they either use too much memory or run slower than the for-loops.

Python function:

import numpy as np
from scipy.fft import fftshift
import pyfftw
def python_corr(frame_a, frame_b):
    """Cross-correlate two stacks of 2D samples, one sample pair at a time.

    For every index ``i`` this computes the FFT cross-correlation
    ``fftshift(irfft2(conj(rfft2(frame_a[i])) * rfft2(frame_b[i])))``
    using pyFFTW plans built once before the loop. ``fftshift`` moves the
    zero-lag element from the corner to the centre of each slice.

    Parameters
    ----------
    frame_a, frame_b : numpy.ndarray, shape (k, m, n)
        Stacks of ``k`` samples of size ``m`` x ``n``. Both must have the
        same shape (float32 data is what the FFT plans are built for).

    Returns
    -------
    numpy.ndarray, float32, shape (k, m, n)
        Shifted cross-correlation of every sample pair.

    Raises
    ------
    ValueError
        If the inputs are not 3D arrays of identical shape.
    """
    DTYPEf = 'float32'
    DTYPEc = 'complex64'

    if frame_a.ndim != 3 or frame_a.shape != frame_b.shape:
        raise ValueError(
            'frame_a and frame_b must be 3D arrays of the same shape, '
            f'got {frame_a.shape} and {frame_b.shape}'
        )
    k, m, n = frame_a.shape  # samples, size y, size x

    fs = [m, n]  # sample shape
    bs = [m, n // 2 + 1]  # rfft sample shape (last axis halved by symmetry)

    corr = np.zeros([k, m, n], DTYPEf)  # out

    fft_forward = pyfftw.builders.rfft2(
        pyfftw.empty_aligned(fs, dtype=DTYPEf),
        axes=[-2, -1],
    )
    # s=fs pins the real output size. Without it irfft2 returns 2*(n//2)
    # columns, i.e. n - 1 for odd n, which would not fit into corr[ind].
    fft_backward = pyfftw.builders.irfft2(
        pyfftw.empty_aligned(bs, dtype=DTYPEc),
        s=fs,
        axes=[-2, -1],
    )
    for ind in range(k):
        # np.conj returns a new array, so the second fft_forward call (which
        # may hand back the same internal output buffer) cannot clobber it.
        spectrum = np.conj(fft_forward(frame_a[ind])) * fft_forward(frame_b[ind])
        # irfft2 output is already real, so no np.real is needed.
        corr[ind] = fftshift(fft_backward(spectrum), axes=[-2, -1])
    return corr

Cython function:

import numpy as np
from scipy.fft import fftshift
import pyfftw

cimport numpy as np
np.import_array()
cimport cython

# Python-level dtype of the real input/output arrays (passed as dtype= to
# pyfftw.empty_aligned).
DTYPEf = np.float32
# C-level element type matching DTYPEf, used in typed buffer declarations.
ctypedef np.float32_t DTYPEf_t
# Python-level dtype of the complex rfft2 spectra (pairs with DTYPEf).
DTYPEc = np.complex64
# C-level element type matching DTYPEc.
ctypedef np.complex64_t DTYPEc_t

@cython.boundscheck(False)
@cython.nonecheck(False)
def cython_corr(
    np.ndarray[DTYPEf_t, ndim = 3] frame_a,
    np.ndarray[DTYPEf_t, ndim = 3] frame_b,
):
    """Cross-correlate two stacks of 2D float32 samples, one pair at a time.

    Same algorithm as the pure-Python version: for every index i,
    ``fftshift(irfft2(conj(rfft2(frame_a[i])) * rfft2(frame_b[i])))``.
    The loop body consists only of pyFFTW / NumPy / SciPy calls, which
    Cython compiles to ordinary Python-object calls.

    Parameters
    ----------
    frame_a, frame_b : numpy.ndarray[float32], shape (k, m, n)
        Stacks of k samples of size m x n; must have identical shapes.

    Returns
    -------
    numpy.ndarray, float32, shape (k, m, n)
        Shifted cross-correlation of every sample pair (an ndarray, like the
        Python implementation, rather than a Cython memoryview).

    Raises
    ------
    ValueError
        If frame_a and frame_b do not have the same shape.
    """
    cdef Py_ssize_t ind, k, m, n
    cdef object fft_forward, fft_backward, spectrum, out

    k = frame_a.shape[0]
    m = frame_a.shape[1] # size y of sample
    n = frame_a.shape[2] # size x of sample

    # In Cython, .shape on a typed ndarray is a C array, so compare the
    # dimensions one by one rather than comparing the pointers.
    if frame_b.shape[0] != k or frame_b.shape[1] != m or frame_b.shape[2] != n:
        raise ValueError('frame_a and frame_b must have the same shape')

    # Every element is written in the loop, so no zero-fill is needed.
    out = pyfftw.empty_aligned([k, m, n], dtype = DTYPEf)

    fft_forward = pyfftw.builders.rfft2(
        pyfftw.empty_aligned([m, n], dtype = DTYPEf),
        axes = [0, 1],
    )
    # s=[m, n] pins the real output size. Without it irfft2 returns
    # 2*(n//2) columns, i.e. n - 1 for odd n, which does not fit out[ind].
    fft_backward = pyfftw.builders.irfft2(
        pyfftw.empty_aligned([m, n//2+1], dtype = DTYPEc),
        s = [m, n],
        axes = [0, 1],
    )
    for ind in range(k):
        # np.conj returns a new array, so the second fft_forward call (which
        # may reuse the plan's output buffer) cannot overwrite it.
        spectrum = np.conj(fft_forward(frame_a[ind, :, :])) * fft_forward(frame_b[ind, :, :])
        # shift Q1 --> Q3, Q2 --> Q4; irfft2 output is already real
        out[ind, :, :] = fftshift(fft_backward(spectrum), axes = [0, 1])
    return out

Test for methods:

import time

import numpy as np

# Benchmark input: 14000 random 24x24 px float32 samples. A seeded RNG gives
# reproducible, finite data. np.empty would leave arbitrary memory contents
# (possibly NaN/inf or denormals) that make FFT timings unrepresentative.
rng = np.random.default_rng(0)
aa = rng.random([14000, 24, 24], dtype=np.float32)
bb = aa  # correlate each sample with itself, as in the original test
print(f'Number of samples: {aa.shape[0]}')

# perf_counter is the monotonic, high-resolution clock meant for timing.
start = time.perf_counter()
corr_py = python_corr(aa, bb)
print(f'Time for Python: {time.perf_counter() - start}')

start = time.perf_counter()
corr_cy = cython_corr(aa, bb)
print(f'Time for Cython: {time.perf_counter() - start}')

# Timings only mean something if both implementations compute the same thing.
print(f'Results match: {np.allclose(corr_py, np.asarray(corr_cy), rtol=1e-4, atol=1e-4)}')
del corr_py, corr_cy
Zim-Zim
  • 11
  • 1
  • Cython cannot optimize NumPy calls, only basic code. That is why it is not faster: almost all the time is spent in the FFT computation (rfft2, irfft2 and fftshift). The rest is just copies or the creation of views, which Cython cannot really speed up either. However, you can try to run this in *parallel* (using Cython). – Jérôme Richard Jan 31 '22 at 09:00

0 Answers