
Why are the "batch" axes always the leading axes in NumPy? I designed all my packages to use the trailing axes as batch axes because this seems more natural to me. Now I'm thinking about switching to NumPy's convention - just to make things more intuitive for NumPy users. Any ideas on that?

Performance-wise, this could be a really bad idea:

```python
import numpy as np

np.random.seed(6512)
a = np.random.rand(50000, 8, 3, 3)

np.random.seed(85742)
b = np.random.rand(50000, 8, 3, 3)

c = a @ b
# 19.8 ms ± 543 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

d = np.einsum("...ik,...kj->...ij", a, b)
# 84.1 ms ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# now use the trailing axes (ensure C-contiguous arrays for transposed data)
A = np.ascontiguousarray(np.transpose(a, [2, 3, 0, 1])) # A_ijab
B = np.ascontiguousarray(np.transpose(b, [2, 3, 0, 1])) # B_ijab

C = (B.T @ A.T).T # (C^T)_baji = B_bajk A_baki -> C_ijab
# 16.9 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

D = np.einsum("ik...,kj...->ij...", A, B)
# 17.2 ms ± 842 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

assert np.allclose(c, d)
assert np.allclose(C, D)
assert np.allclose(np.transpose(D, [2, 3, 0, 1]), d)
assert np.allclose(np.transpose(C, [2, 3, 0, 1]), c)
```
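The timing comments above have the format of IPython's `%timeit` output, but the measuring code itself isn't shown. A rough plain-Python equivalent, sketched with the standard `timeit` module, could look like this. `best_ms` and the candidate labels are hypothetical names, and the helper reports the best per-call time rather than `%timeit`'s mean ± std. dev., so the numbers won't match exactly.

```python
import timeit

import numpy as np

# Same inputs as in the question (leading batch axes)
np.random.seed(6512)
a = np.random.rand(50000, 8, 3, 3)
np.random.seed(85742)
b = np.random.rand(50000, 8, 3, 3)

# Trailing batch axes, copied into C-contiguous memory
A = np.ascontiguousarray(np.transpose(a, [2, 3, 0, 1]))
B = np.ascontiguousarray(np.transpose(b, [2, 3, 0, 1]))


def best_ms(fn, number=5, repeat=3):
    """Hypothetical helper: best per-call time in milliseconds over `repeat` runs."""
    times = timeit.repeat(fn, number=number, repeat=repeat)
    return min(times) / number * 1e3


candidates = {
    "a @ b": lambda: a @ b,
    "einsum, leading batch": lambda: np.einsum("...ik,...kj->...ij", a, b),
    "(B.T @ A.T).T": lambda: (B.T @ A.T).T,
    "einsum, trailing batch": lambda: np.einsum("ik...,kj...->ij...", A, B),
}

results = {name: best_ms(fn) for name, fn in candidates.items()}
for name, ms in results.items():
    print(f"{name:<24} {ms:8.2f} ms")
```

Taking the minimum is a common choice for micro-benchmarks because it is least affected by background noise; `%timeit` in IPython remains the more convenient option interactively.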

Or more complicated einsums:

```python
# crossed-dyadic product
# ----------------------

E = np.einsum("ik...,jl...->ijkl...", A, B)
# 76.5 ms ± 2.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

e = np.einsum("...ik,...jl->...ijkl", a, b)
# 207 ms ± 3.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

assert np.allclose(np.transpose(E, [4, 5, 0, 1, 2, 3]), e)
```
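For the crossed-dyadic product specifically, each output element is a single product `a[..., i, k] * b[..., j, l]` with no summation, so plain broadcasting can express it without `einsum`. A minimal sketch for both layouts (the function names are hypothetical, and a smaller batch is used for a quick correctness check); it has not been benchmarked against the `einsum` versions above.

```python
import numpy as np

np.random.seed(6512)
a = np.random.rand(1000, 8, 3, 3)
np.random.seed(85742)
b = np.random.rand(1000, 8, 3, 3)
A = np.ascontiguousarray(np.transpose(a, [2, 3, 0, 1]))
B = np.ascontiguousarray(np.transpose(b, [2, 3, 0, 1]))


def crossed_dyadic_leading(a, b):
    """e[..., i, j, k, l] = a[..., i, k] * b[..., j, l] (batch axes first)."""
    # a -> (..., i, 1, k, 1), b -> (..., 1, j, 1, l); broadcasting fills the 1s
    return a[..., :, None, :, None] * b[..., None, :, None, :]


def crossed_dyadic_trailing(A, B):
    """E[i, j, k, l, ...] = A[i, k, ...] * B[j, l, ...] (batch axes last)."""
    # A -> (i, 1, k, 1, ...), B -> (1, j, 1, l, ...)
    return A[:, None, :, None] * B[None, :, None, :]


e = crossed_dyadic_leading(a, b)
E = crossed_dyadic_trailing(A, B)
assert np.allclose(e, np.einsum("...ik,...jl->...ijkl", a, b))
assert np.allclose(E, np.einsum("ik...,jl...->ijkl...", A, B))
assert np.allclose(np.transpose(E, [4, 5, 0, 1, 2, 3]), e)
```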
Magoo
adtzlr
  • With C order, the leading axis is the outermost one, changing slowest. `matmul` is the most obvious case of 'batch'. Other binary operators don't have a clear batch dimension. Some linalg functions take (...,N,M) shapes. – hpaulj Aug 17 '23 at 19:09
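hpaulj's point about linear algebra is the practical catch for a trailing-batch design: `np.matmul` and `np.linalg` functions such as `det` and `inv` treat the last two axes as the matrix and everything before them as the stack. A small sketch of what that means for data stored with trailing batch axes (the question's `A` layout, built here with `np.moveaxis`, which for this 4-D case is equivalent to `np.transpose(a, [2, 3, 0, 1])`):

```python
import numpy as np

np.random.seed(6512)
a = np.random.rand(1000, 8, 3, 3)                            # batch axes first
A = np.ascontiguousarray(np.moveaxis(a, [-2, -1], [0, 1]))   # batch axes last: (3, 3, 1000, 8)

# np.linalg functions treat the LAST two axes as the matrix axes
det_leading = np.linalg.det(a)          # shape (1000, 8)
inv_leading = np.linalg.inv(a)          # shape (1000, 8, 3, 3)

# Trailing-batch data has to be viewed with the matrix axes moved to the end.
# np.moveaxis returns a view (no copy), but that view is not C-contiguous.
A_view = np.moveaxis(A, [0, 1], [-2, -1])                            # (1000, 8, 3, 3)
det_trailing = np.linalg.det(A_view)                                 # (1000, 8)
inv_trailing = np.moveaxis(np.linalg.inv(A_view), [-2, -1], [0, 1])  # (3, 3, 1000, 8)

assert np.allclose(det_leading, det_trailing)
assert np.allclose(np.moveaxis(inv_leading, [-2, -1], [0, 1]), inv_trailing)
```

So a trailing-batch package can still call into `np.linalg`, but every such call needs an axis shuffle in and, for array-valued results, back out.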


NumPy is written in C and uses the C convention for arrays, that is, row-major ordering. Thus, applying an operation along the last axes (i.e. the right-most ones, which are also the most contiguous in memory) makes better use of CPU caches. For large arrays, transposing (whether by copying the data or by iterating over a strided view) significantly increases the pressure on RAM, so it often results in a much slower computation (RAM bandwidth is often the limiting factor for NumPy operations).
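To make the row-major point concrete: in a C-ordered float64 array the last axis has the smallest stride (8 bytes), so consecutive elements along it are adjacent in memory. Transposing only permutes the strides (a view, no copy), and `np.ascontiguousarray` is what actually pays for the copy into the new layout. A small illustration with a toy `(4, 3, 3)` array:

```python
import numpy as np

x = np.zeros((4, 3, 3))          # float64, C (row-major) order, batch axis first
print(x.strides)                  # (72, 24, 8): the last axis is contiguous

t = np.transpose(x, [1, 2, 0])    # trailing-batch view, no data copied
print(t.strides, t.flags["C_CONTIGUOUS"])   # (24, 8, 72) False

c = np.ascontiguousarray(t)       # explicit copy into row-major order
print(c.strides, c.flags["C_CONTIGUOUS"])   # (96, 32, 8) True

print(np.shares_memory(x, t), np.shares_memory(x, c))   # True False
```

Walking `t` in its logical order jumps 72 bytes per step along its last axis, which is the cache-unfriendly access pattern the answer refers to; the copy in `c` restores contiguity at the cost of one extra pass through memory.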

That being said, in your case NumPy is clearly not optimized for 3x3 matrices. The overhead of NumPy's internal generic iterators (which enable broadcasting) is so large here that it dominates the computation. Most BLAS libraries are also not optimized for such tiny matrices. Some linear algebra libraries provide batched operations for this (e.g. AFAIK cuBLAS does). However, NumPy does not support them yet.
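One pure-NumPy way to sidestep the per-matrix iterator overhead, assuming the trailing-batch layout from the question, is to loop over the 3x3 structure in Python and let each arithmetic operation run over the whole batch at once. In the sketch below (`matmul3_trailing` is a hypothetical name), that amounts to 9 × (3 multiplies + 2 adds) = 45 array operations, each on a contiguous slice such as `A[i, k]` of shape `(50000, 8)`. It creates temporaries, so it is likely memory-bound, and it has not been measured against the variants in the question.

```python
import numpy as np


def matmul3_trailing(A, B):
    """C[i, j, ...] = sum_k A[i, k, ...] * B[k, j, ...] for 3x3 leading axes.

    Assumes A and B have the same shape (3, 3, *batch) and dtype.
    """
    C = np.empty_like(A)
    for i in range(3):
        for j in range(3):
            # Each term is an elementwise product of two contiguous batch slices
            C[i, j] = A[i, 0] * B[0, j] + A[i, 1] * B[1, j] + A[i, 2] * B[2, j]
    return C


np.random.seed(6512)
A = np.random.rand(3, 3, 1000, 8)
np.random.seed(85742)
B = np.random.rand(3, 3, 1000, 8)
C = matmul3_trailing(A, B)
assert np.allclose(C, np.einsum("ik...,kj...->ij...", A, B))
```

This only works because the batch axes are trailing: with leading batch axes, `a[..., i, k]` is a strided slice and each of the 45 operations would walk memory with a stride of 72 bytes.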

Modern mainstream CPUs can compute a 3x3 matrix multiplication in only a few nanoseconds, so the overhead of generic code is too big to compute them efficiently. To get a fast implementation, you need compiled code specialized for fixed-size 3x3 matrices. Compilers can then generate efficient instructions for this specific case. Hand-written assembly (or compiler SIMD intrinsics) can certainly be significantly faster for this use case, but it is hard to write, hard to maintain, and bug-prone. A good solution here is to use Cython (with memory views and the right compilation flags) or Numba (with the fast-math flag, if possible). You can find an example of Numba code used to solve a similar problem here.
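The linked Numba example is not reproduced here, but a minimal sketch of the idea, assuming Numba is installed, might look like the following. `batched_matmul_3x3` is a hypothetical name; it expects C-contiguous arrays of shape `(N, 3, 3)` (so the question's `(50000, 8, 3, 3)` arrays are reshaped to `(-1, 3, 3)`), uses `prange` to spread the batch over cores, and relies on the fixed loop bounds to give the compiler a chance to unroll and vectorize. If Numba is missing, the fallback runs the same loops as slow pure Python, which is only useful for checking correctness.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback for environments without Numba: plain (slow) Python loops
    def njit(*args, **kwargs):
        def decorator(func):
            return func
        return decorator

    prange = range


@njit(parallel=True, fastmath=True)
def batched_matmul_3x3(a, b):
    """c[m] = a[m] @ b[m] for arrays of shape (N, 3, 3)."""
    c = np.empty_like(a)
    for m in prange(a.shape[0]):       # batch loop, parallelized by Numba
        for i in range(3):             # fixed bounds known at compile time
            a0 = a[m, i, 0]
            a1 = a[m, i, 1]
            a2 = a[m, i, 2]
            for j in range(3):
                c[m, i, j] = a0 * b[m, 0, j] + a1 * b[m, 1, j] + a2 * b[m, 2, j]
    return c


np.random.seed(6512)
a = np.random.rand(200, 8, 3, 3)
np.random.seed(85742)
b = np.random.rand(200, 8, 3, 3)
c = batched_matmul_3x3(a.reshape(-1, 3, 3), b.reshape(-1, 3, 3)).reshape(a.shape)
assert np.allclose(c, a @ b)
```

The first call includes JIT compilation time, so any timing should be done on a second call.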

Vitalizzare
Jérôme Richard
  • The given example in my current implementation (result `D`) takes 43 ns per matrix multiplication. What do you mean by a few ns? About one order of magnitude faster? – adtzlr Aug 18 '23 at 18:39
  • Modern processors can often reach >100 GFlops. For example, my i5-9600KF can reach 400 GFlops. Computing a 3x3 matrix product requires `3*3*(3+2)=45` floating-point operations. This means up to 8-9 billion matrices per second. However, this does not imply <1 ns per matrix, because that throughput is only reached by computing matrices in parallel on *multiple cores*, and floating-point operations have a latency (so the computation of several matrices needs to be *pipelined*). On top of that, the peak computing power is very hard to reach, and sometimes impossible given the input layout and the exact way the processor operates. – Jérôme Richard Aug 18 '23 at 19:03
  • One issue with 3x3 matrices is that they do not fit well in *SIMD registers*, which typically hold 2, 4 or 8 double-precision floating-point numbers (DP-FP). This means compilers may generate scalar code instead of SIMD code, resulting in code that is 2 to 8 times slower (4 on an average mainstream x86-64 CPU). Still, even in this case, I think my CPU can easily compute more than 0.25 billion matrices per second, that is, at least 10 times faster than your current code. Regarding latency, I think my processor can compute 1 matrix in less than 5 ns on 1 core using clever SIMD code. – Jérôme Richard Aug 18 '23 at 19:22
  • I wrote a naive implementation in C and it takes 31-35 cycles/matrix/core, that is, about 8 ns/matrix/core. I have 6 cores, so a naive C implementation can be much faster. The bad news is that the arrays are too big to fit in my CPU cache, so they are stored in RAM, which is much slower and should be the limiting factor. Still, a naive implementation should be 5-6 times faster on my machine, and a clever implementation 7-8 times faster. If you can compute the resulting matrices on the fly, it can be even faster. Note that it can be about twice as fast with single-precision numbers. – Jérôme Richard Aug 18 '23 at 19:43
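The back-of-the-envelope numbers in these comments can be reproduced with a few lines of arithmetic. This only re-derives figures already quoted (400 GFlops peak, 43 ns per matrix for `D`, the 0.25 billion matrices/s estimate, the `(50000, 8, 3, 3)` float64 shape); it measures nothing.

```python
# Figures quoted in the comments (not measured here)
peak_flops = 400e9          # quoted peak of the i5-9600KF
current_ns = 43             # asker's time per matrix for result D
estimated_rate = 0.25e9     # "at least 0.25 billion matrices per second"

# One 3x3 product: 9 outputs, each needing 3 multiplies and 2 additions
flops_per_matmul = 3 * 3 * (3 + 2)                 # 45
peak_rate = peak_flops / flops_per_matmul          # ~8.9e9 matrices/s (all cores)
current_rate = 1e9 / current_ns                    # ~2.3e7 matrices/s
speedup = estimated_rate / current_rate            # ~10.75

# Memory traffic per product: read a[m] and b[m], write c[m], 9 values each
bytes_per_matmul = 3 * 9 * 8                       # 216 bytes with float64
bytes_per_matmul_f32 = 3 * 9 * 4                   # 108 bytes with float32
array_mb = 50000 * 8 * 9 * 8 / 1e6                 # 28.8 MB per (50000, 8, 3, 3) array

print(f"{flops_per_matmul} flop/matrix, peak {peak_rate:.2e} matrices/s")
print(f"current {current_rate:.2e} matrices/s, estimated speedup {speedup:.1f}x")
print(f"{bytes_per_matmul} B/matrix (float64), {bytes_per_matmul_f32} B (float32), {array_mb} MB/array")
```

Halving the bytes moved per matrix with single precision is consistent with the "about twice as fast" remark, if RAM bandwidth is indeed the bottleneck for arrays of this size.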