In the thread "numpy - evaluate function on a grid of points", I found a trick by usethedeathstar for bypassing meshgrid when evaluating simple NumPy expressions.
I had a similar problem, but my function uses a list comprehension. I gave the trick a shot anyway, not expecting it to work:
```python
import math
import timeit

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (only needed on Matplotlib < 3.2 to register '3d')

def rastrigin(dim, x, A=10):
    # Rastrigin function: A*n + sum(x_i**2 - A*cos(2*pi*x_i))
    return A * dim + np.sum([(xi**2 - A * np.cos(2 * math.pi * xi)) for xi in x])

def main():
    x = np.linspace(-4, 4, 100)
    y = np.linspace(-4, 4, 100)

    # Option 1 - bypass meshgrid - MAGIC!
    # https://stackoverflow.com/questions/22774726/numpy-evaluate-function-on-a-grid-of-points/22778484#22778484
    Z1 = rastrigin(2, [x[:, None], y[None, :]])

    # Option 2 - traditional way using meshgrid, one rastrigin call per grid point
    X, Y = np.meshgrid(x, y)
    Z = np.array([rastrigin(2, [xv, yv])
                  for xv, yv in zip(np.ravel(X), np.ravel(Y))]).reshape(X.shape)

    # timeit shows Option 1 is ridiculously faster than Option 2
    t_option2 = timeit.timeit(lambda: np.array([rastrigin(2, [xv, yv])
                                                for xv, yv in zip(np.ravel(X), np.ravel(Y))]).reshape(X.shape),
                              number=100)
    t_option1 = timeit.timeit(lambda: rastrigin(2, [x[:, None], y[None, :]]), number=100)
    print("Option 1:", t_option1, "Option 2:", t_option2)

    # Plot the Option 2 result, which has the same layout as X and Y
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')  # fig.gca(projection=...) was removed in Matplotlib 3.6
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.plasma, linewidth=0, antialiased=False)
    plt.show()

if __name__ == "__main__":
    main()
```
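A side note on how Option 1 and Option 2 lay out the grid: broadcasting the column x[:, None] against the row y[None, :] puts x along the rows, which is the layout np.meshgrid gives with indexing='ij', whereas meshgrid's default indexing='xy' puts x along the columns. In the question the two layouts coincide only because x and y are identical and the function is symmetric. The sketch below uses deliberately unequal grids and a hypothetical helper, term (the summand from rastrigin), to make the difference visible.

```python
import numpy as np

def term(v, A=10):
    # Hypothetical helper: one Rastrigin summand, works on scalars or arrays
    return v**2 - A * np.cos(2 * np.pi * v)

x = np.linspace(-4, 4, 50)    # deliberately different lengths so the
y = np.linspace(-2, 2, 30)    # axis order shows up in the shapes

# Option 1 style: column against row -> shape (len(x), len(y)), like indexing='ij'
Z1 = term(x[:, None]) + term(y[None, :])

# Option 2 style with meshgrid's default indexing='xy' -> shape (len(y), len(x))
X, Y = np.meshgrid(x, y)
Z2 = term(X) + term(Y)

print(Z1.shape, Z2.shape)        # (50, 30) (30, 50)
print(np.allclose(Z1, Z2.T))     # True

# Broadcasting in meshgrid's 'xy' layout instead: y along the rows
Z1_xy = term(x[None, :]) + term(y[:, None])
print(np.allclose(Z1_xy, Z2))    # True
```

So a result produced the Option 1 way should be transposed (or built with y as the column) before plotting it against X and Y from a default meshgrid.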
Option 1 not only works with a list comprehension, which I suspect even the original author did not intend, but is also blazing fast: with timeit (number=100), Option 1 takes 0.003 seconds and Option 2 takes 8.7 seconds.
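One way to see where the speed gap comes from is to count how often the summand expression runs. In Option 1 the list comprehension iterates over the list of coordinates (two entries), not over grid points, so the expression is evaluated twice, each time on a whole array; Option 2 calls the function once per grid point. The sketch below counts evaluations rather than measuring time. The counter, term and rastrigin_counted are hypothetical names, and rastrigin_counted adds the summands with Python's built-in sum instead of np.sum, an assumption made so the sketch does not depend on how np.sum treats a list of differently shaped arrays.

```python
import numpy as np

calls = 0  # hypothetical counter: number of times the summand is evaluated

def term(v, A=10):
    # Same expression as inside the question's list comprehension
    global calls
    calls += 1
    return v**2 - A * np.cos(2 * np.pi * v)

def rastrigin_counted(dim, x, A=10):
    # Same structure as rastrigin, but summed with the built-in sum
    return A * dim + sum(term(xi, A) for xi in x)

x = np.linspace(-4, 4, 100)
y = np.linspace(-4, 4, 100)

# Option 1: a single call; the generator loops over 2 arrays
calls = 0
Z1 = rastrigin_counted(2, [x[:, None], y[None, :]])
option1_calls = calls

# Option 2: one call per grid point, 2 scalar evaluations each
calls = 0
X, Y = np.meshgrid(x, y)
Z2 = np.array([rastrigin_counted(2, [xv, yv])
               for xv, yv in zip(X.ravel(), Y.ravel())]).reshape(X.shape)
option2_calls = calls

print(option1_calls, option2_calls)  # 2 20000
```

Both options do the same arithmetic on 10,000 points; the difference is whether that work happens inside a couple of vectorized NumPy calls or across 20,000 Python-level evaluations.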
My question is: how? I can't see why this approach works with a list comprehension.
I understand that the list comprehension [(xi**2 - A * np.cos(2 * math.pi * xi)) for xi in x] produces two arrays, one of shape (100, 1) and one of shape (1, 100). Is numpy.sum then broadcasting them against each other to produce a (100, 100) result? Is this expected behavior?
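A sketch of what the sum inside rastrigin does with those two arrays. np.sum first converts its list argument into an array. Summands of shapes (100, 1) and (1, 100) cannot be stacked into one numeric array, so older NumPy releases fell back to a one-dimensional object array holding the two arrays (from NumPy 1.19 on with a VisibleDeprecationWarning), and summing an object array applies Python's + to its elements, which broadcasts them to (100, 100). NumPy 1.24 and later raise a ValueError for such ragged input instead. The sketch below rebuilds the object-array path explicitly, shows two alternatives that do not depend on the NumPy version, and shows that with equally shaped summands np.sum without axis collapses everything to a single number.

```python
import numpy as np

A = 10
x = np.linspace(-4, 4, 100)
y = np.linspace(-4, 4, 100)

# The list comprehension from rastrigin, applied to [x[:, None], y[None, :]]:
# each summand keeps the shape of the array it came from
terms = [xi**2 - A * np.cos(2 * np.pi * xi) for xi in [x[:, None], y[None, :]]]
print([t.shape for t in terms])          # [(100, 1), (1, 100)]

# What older NumPy built implicitly: a 1-D object array of the two arrays.
# Summing it applies Python's + to the elements: terms[0] + terms[1]
packed = np.empty(len(terms), dtype=object)
for i, t in enumerate(terms):
    packed[i] = t
Z_object = packed.sum()                  # broadcasts to (100, 100)

# Alternatives that never create a ragged array
Z_builtin = sum(terms)                                   # 0 + terms[0] + terms[1]
Z_stacked = np.sum(np.broadcast_arrays(*terms), axis=0)  # stack to (2, 100, 100), sum axis 0

# Pitfall: with equally shaped summands (e.g. meshgrid arrays),
# np.sum without axis adds every element and returns one scalar
X, Y = np.meshgrid(x, y, indexing='ij')
same_shape = [v**2 - A * np.cos(2 * np.pi * v) for v in [X, Y]]
print(np.ndim(np.sum(same_shape)), np.sum(same_shape, axis=0).shape)  # 0 (100, 100)
```

In other words, the (100, 100) result in Option 1 comes from broadcasting in the element-wise +, not from np.sum itself; passing axis=0 over explicitly broadcast summands gives the same grid on any NumPy version.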