Why are the "batch" axes always the leading axes in NumPy? I designed all my packages to use the trailing axes as batch axes because this seems more natural to me. Now I'm thinking about switching to NumPy's convention - just to make things more intuitive for NumPy users. Any ideas on that?
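For context, this is what I mean by NumPy's convention, as a minimal example: matmul (and gufuncs in general) treat the trailing axes as the core dimensions and broadcast over everything in front of them.
import numpy as np
# NumPy's convention: the trailing axes are the core (matrix) axes,
# all leading axes are broadcast as batch axes
x = np.random.rand(10, 4, 3, 3)  # batch shape (10, 4), 3x3 matrices
y = np.random.rand(10, 4, 3, 3)
print((x @ y).shape)  # (10, 4, 3, 3)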
Performance-wise, this could be a really bad idea (the timings in the comments below are IPython %timeit results):
import numpy as np
np.random.seed(6512)
a = np.random.rand(50000, 8, 3, 3)
np.random.seed(85742)
b = np.random.rand(50000, 8, 3, 3)
c = a @ b
# 19.8 ms ± 543 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
d = np.einsum("...ik,...kj->...ij", a, b)
# 84.1 ms ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# now use the trailing axes as batch axes (make the transposed copies C-contiguous)
A = np.ascontiguousarray(np.transpose(a, [2, 3, 0, 1])) # A_ikab: matrix axes first, batch axes last
B = np.ascontiguousarray(np.transpose(b, [2, 3, 0, 1])) # B_kjab
C = (B.T @ A.T).T # (C^T)_baji = (B^T)_bajk (A^T)_baki -> C_ijab
# 16.9 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
D = np.einsum("ik...,kj...->ij...", A, B)
# 17.2 ms ± 842 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
assert np.allclose(c, d)
assert np.allclose(C, D)
assert np.allclose(np.transpose(D, [2, 3, 0, 1]), d)
assert np.allclose(np.transpose(C, [2, 3, 0, 1]), c)
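Presumably the gap, especially for einsum, is a memory-layout effect: after the transpose plus copy, the batch axes are the fastest-varying ones, so each fixed pair of matrix indices addresses one long contiguous block. A minimal check of that layout (float64, shapes as above):
# in the transposed, C-contiguous arrays the batch axes have the smallest strides,
# i.e. A[i, k] is a single contiguous run of 50000 * 8 doubles
print(A.flags["C_CONTIGUOUS"])  # True
print(A.strides)                # (9600000, 3200000, 64, 8)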
Or more complicated einsums:
# crossed-dyadic product
# ----------------------
E = np.einsum("ik...,jl...->ijkl...", A, B)
# 76.5 ms ± 2.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
e = np.einsum("...ik,...jl->...ijkl", a, b)
# 207 ms ± 3.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
assert np.allclose(np.transpose(E, [4, 5, 0, 1, 2, 3]), e)
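If you do end up switching conventions, the conversion itself is just an axis move (a view, no copy); a rough sketch, with a helper name of my own rather than an existing API:
def core_first_to_batch_first(x, n_core=2):
    # hypothetical helper: move the n_core leading (matrix) axes to the end,
    # presenting trailing-batch data in NumPy's leading-batch convention
    return np.moveaxis(x, list(range(n_core)), list(range(-n_core, 0)))

assert core_first_to_batch_first(C).shape == c.shape  # (50000, 8, 3, 3)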