I read about convolutions being faster when computed into the frequency domain because it's "just" a matrix multiplication (in 2D), while in the time domain it's a lot of small matrix multiplication.
So I made this code we can see that FFT convolution is more complex than "normal" convolution. It's clear that something is wrong in my assumptions.
What is wrong ?
from sympy import exp, log, symbols, init_printing, lambdify
init_printing(use_latex='matplotlib')
import numpy as np
import matplotlib.pyplot as plt
def _complex_mult(n):
"""Complexity of a MatMul of a 2 matrices of size (n, n)"""
# see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
return n**2.5
def _complex_fft(n):
"""Complexity of fft and ifft"""
# see https://en.wikipedia.org/wiki/Fast_Fourier_transform
return n*log(n)
def fft_mult_fft(n, m):
"""Complexity of a convolution in the freq space.
fft -> mult between M and kernel -> ifft
"""
return _complex_fft(n) * 2 + _complex_mult(n)
def conv(n, m):
"""Complexity of a convolution in the time space.
for every n of M, we execute a MatMul of 2 (m, m) matrices
"""
return n*_complex_mult(m)
n = symbols('n') # size of M = (n, n)
m = symbols('m') # size of kernel = (m, m)
M = np.linspace(1, 1e3+1, 1e1)
kernel_size = np.linspace(2, 7, 7-2+1)**2
fft = fft_mult_fft(n, m)
discrete = conv(n, m)
f1 = lambdify(n, fft, 'numpy')
f2 = lambdify([n, m], discrete, 'numpy')
fig, ax = plt.subplots(1, len(kernel_size), figsize=(30, 10))
f1_computed = f1(M) # independant wrt m, do not compute it at each time
for i, size in enumerate(kernel_size):
ax[i].plot(M, f1_computed, c='red', label='freq domain (fft)')
ax[i].plot(M, f2(M, size), c='blue', label='time domain (normal)')
ax[i].legend(loc='upper left')
ax[i].set_title("kernel size = {}".format(size))
ax[i].set_xlabel("Matrix size")
ax[i].set_ylabel("Complexity")
And here is the output: (click to zoom)