Doing something like
import numpy as np
a = np.random.rand(10**4, 10**4)
b = np.dot(a, a)
uses multiple cores, and it runs nicely.
The elements in a, though, are 64-bit floats (or 32-bit in 32-bit platforms?), and I'd like to multiply 8-bit integer arrays. Trying the following, though:
a = np.random.randint(2, size=(n, n)).astype(np.int8)
results in the dot product not using multiple cores, and thus running ~1000x slower on my PC.
array: np.random.randint(2, size=shape).astype(dtype)

dtype    shape           %time (average)
float32  (2000, 2000)    62.5 ms
float32  (3000, 3000)    219 ms
float32  (4000, 4000)    328 ms
float32  (10000, 10000)  4.09 s
int8     (2000, 2000)    13 s
int8     (3000, 3000)    3 min 26 s
int8     (4000, 4000)    12 min 20 s
int8     (10000, 10000)  did not finish in 6 hours
float16  (2000, 2000)    2 min 25 s
float16  (3000, 3000)    not tested
float16  (4000, 4000)    not tested
float16  (10000, 10000)  not tested
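The table reports IPython %time results. A minimal standalone sketch of a comparable measurement could look like the following. It assumes time.perf_counter as a stand-in for %time, uses a hypothetical helper named time_dot, and keeps the sizes small so the int8 case finishes quickly:

```python
import time

import numpy as np


def time_dot(dtype, n, repeats=3):
    """Average wall-clock seconds for np.dot(a, a) on an n x n 0/1 matrix.

    Hypothetical helper for illustration; not part of NumPy.
    """
    a = np.random.randint(2, size=(n, n)).astype(dtype)
    elapsed = []
    for _ in range(repeats):
        start = time.perf_counter()
        np.dot(a, a)
        elapsed.append(time.perf_counter() - start)
    return sum(elapsed) / len(elapsed)


# Small sizes only; scale n up to approach the shapes in the table.
for dtype in (np.float32, np.int8):
    for n in (200, 400):
        seconds = time_dot(dtype, n)
        print(f"{np.dtype(dtype).name:8s} ({n}, {n})  {seconds:.4f} s")
```

Absolute numbers will depend on the machine and on which BLAS NumPy is linked against, so only the ratio between dtypes is meaningful.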
I understand NumPy uses BLAS, which doesn't support integers, but if I use the SciPy BLAS wrappers, i.e.
import scipy.linalg.blas as blas
a = np.random.randint(2, size=(n, n)).astype(np.int8)
b = blas.sgemm(alpha=1.0, a=a, b=a)
the computation is multi-threaded. Now, blas.sgemm runs with exactly the same timing as np.dot for float32's, but for non-floats it converts everything to float32 and outputs floats, which is something np.dot doesn't do. (In addition, b is now in F_CONTIGUOUS order, which is a lesser issue).
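The difference described above can be checked on a tiny matrix. This is a minimal sketch, assuming SciPy is installed; the variable names b_dot, b_gemm and b_c are only for illustration:

```python
import numpy as np
import scipy.linalg.blas as blas

n = 4
a = np.random.randint(2, size=(n, n)).astype(np.int8)

# np.dot keeps the input dtype for integer arrays.
b_dot = np.dot(a, a)

# sgemm is the single-precision BLAS routine, so the int8 input is
# converted to float32 and the result comes back as float32.
b_gemm = blas.sgemm(alpha=1.0, a=a, b=a)

print(b_dot.dtype, b_dot.flags["C_CONTIGUOUS"])
print(b_gemm.dtype, b_gemm.flags["F_CONTIGUOUS"])

# If C order is needed afterwards, this makes a C-ordered copy.
b_c = np.ascontiguousarray(b_gemm)
```

The copy made by np.ascontiguousarray costs extra memory, which is why the layout is only a lesser issue next to the float32 conversion itself.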
So, if I want to do integer matrix multiplication, I have to do one of the following:
- Use NumPy's painfully slow np.dot and be glad I get to keep the 8-bit integers.
- Use SciPy's sgemm and use up 4x memory.
- Use NumPy's np.float16 and only use up 2x memory, with the caveat that np.dot is much slower on float16 arrays than on float32 arrays, more so than int8.
- Find an optimized library for multi-threaded integer matrix multiplication (actually, Mathematica does this, but I'd prefer a Python solution), ideally supporting 1-bit arrays, although 8-bit arrays are also fine... (I'm actually aiming to do multiplication of matrices over the finite field Z/2Z, and I know I can do this with Sage, which is quite Pythonic, but, again, is there something strictly Python?)
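For the Z/2Z goal specifically, the float32 route (option 2) could in principle be reduced mod 2 and stored back as 8-bit integers. This is a minimal sketch, not an optimized library: matmul_mod2 is a hypothetical helper, it assumes the inputs contain only 0s and 1s, and it still pays the 4x memory cost for the float32 temporaries:

```python
import numpy as np


def matmul_mod2(a, b):
    """Multiply two 0/1 matrices over Z/2Z via a float32 product.

    Hypothetical helper for illustration; assumes entries are 0 or 1.
    """
    inner = a.shape[1]
    if inner != b.shape[0]:
        raise ValueError("inner dimensions do not match")
    # float32 represents every integer up to 2**24 exactly, and each entry
    # of the product of 0/1 matrices is an integer no larger than `inner`.
    if inner > 2**24:
        raise ValueError("inner dimension too large for exact float32 sums")
    prod = np.dot(a.astype(np.float32), b.astype(np.float32))
    # Reduce mod 2 and go back to 8-bit storage.
    return np.mod(prod, 2).astype(np.int8)


a = np.random.randint(2, size=(200, 200)).astype(np.int8)
c = matmul_mod2(a, a)
print(c.dtype, c.shape)
```

Only the final result is 8-bit here; the float32 copies of the inputs and the float32 product exist at the same time while the function runs.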
Can I follow option 4? Does such a library exist?
Disclaimer: I'm actually running NumPy + MKL, but I've tried a similar test on vanilla NumPy, with similar results.