This has to deal with a similar problem here: Calling BLAS / LAPACK directly using the SciPy interface and Cython but is different because I'm using the actual code in the SciPy example here _test_dgemm
: https://github.com/scipy/scipy/blob/master/scipy/linalg/cython_blas.pyx which is extremely fast (5x faster than numpy.dot
when an output matrix is input or 20x faster if not). It produces no results if Mx1 1xN vectors are passed. It produces the same values as numpy.dot
with matrices passed. I've minimized the code since no answers have been posted for clarity. Here's the dgemm.pyx.
:
import numpy as np
cimport numpy as np
from scipy.linalg.cython_blas cimport dgemm
from cython cimport boundscheck
@boundscheck(False)
cpdef int fast_dgemm(double[:,::1] a, double[:,::1] b, double[:,::1] c, double alpha=1.0, double beta=0.0) nogil except -1:
cdef:
char *transa = 'n'
char *transb = 'n'
int m, n, k, lda, ldb, ldc
double *a0=&a[0,0]
double *b0=&b[0,0]
double *c0=&c[0,0]
ldb = (&a[1,0]) - a0 if a.shape[0] > 1 else 1
lda = (&b[1,0]) - b0 if b.shape[0] > 1 else 1
k = b.shape[0]
if k != a.shape[1]:
with gil:
raise ValueError("Shape mismatch in input arrays.")
m = b.shape[1]
n = a.shape[0]
if n != c.shape[0] or m != c.shape[1]:
with gil:
raise ValueError("Output array does not have the correct shape.")
ldc = (&c[1,0]) - c0 if c.shape[0] > 1 else 1
dgemm(transa, transb, &m, &n, &k, &alpha, b0, &lda, a0,
&ldb, &beta, c0, &ldc)
return 0
Here's a sample test script:
import numpy as np;
a=np.random.randn(1000);
b=np.random.randn(1000);
a.resize(len(a),1);
a=np.array(a, order='c');
b.resize(1,len(b));
b=np.array(b, order='c');
c = np.empty((a.shape[0],b.shape[1]), float, order='c');
from dgemm import _test_dgemm;
_test_dgemm(a,b,c);
And if you want to play with it on Windows with Python 3.5 x64 here's the setup.py
to build it via the command prompt typing python setup.py build_ext --inplace --compiler=msvc
from Cython.Distutils import build_ext
import numpy as np
import os
try:
from setuptools import setup
from setuptools import Extension
except ImportError:
from distutils.core import setup
from distutils.extension import Extension
module = 'dgemm'
ext_modules = [Extension(module, sources=[module + '.pyx'],
include_dirs=['C://Program Files (x86)//Windows Kits//10//Include//10.0.10240.0//ucrt','C://Program Files (x86)//Microsoft Visual Studio 14.0//VC//include','C://Program Files (x86)//Windows Kits//8.1//Include//shared'],
library_dirs=['C://Program Files (x86)//Windows Kits//8.1//bin//x64', 'C://Windows//System32', 'C://Program Files (x86)//Microsoft Visual Studio 14.0//VC//lib//amd64', 'C://Program Files (x86)//Windows Kits//8.1//Lib//winv6.3//um//x64', 'C://Program Files (x86)//Windows Kits//10//Lib//10.0.10240.0//ucrt//x64'],
extra_compile_args=['/Ot', '/favor:INTEL64', '/EHsc', '/GA'],
language='c++')]
setup(
name = module,
ext_modules = ext_modules,
cmdclass = {'build_ext': build_ext},
include_dirs = [np.get_include(), os.path.join(np.get_include(), 'numpy')]
)
Any help is much appreciated!