The trick is to reshape data into a 3D tensor and then use np.tensordot against weights[0], thus bypassing the creation of foo, like so -
k = 30 # kernel size
data3D = data.reshape(data.shape[0],k,-1)
out = np.tensordot(data3D, weights[0], axes=(2,1)).reshape(-1,k**2)
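To make the axes=(2,1) argument less opaque, here is a minimal sketch on toy shapes. It assumes small random stand-ins for data and weights[0] (the seed and sizes are my own choices, not from the question) and spells out the same contraction with np.einsum, so the summed index is visible:

```python
import numpy as np

rng = np.random.default_rng(0)   # seed is an assumption, only for repeatability
k = 3                            # kernel size (30 in the original problem)
data = rng.random((4, 6))        # stand-in for data: 4 rows, k*2 columns
w0 = rng.random((3, 2))          # stand-in for weights[0]

data3D = data.reshape(data.shape[0], k, -1)              # shape (4, 3, 2)

# axes=(2,1): sum over the last axis of data3D and axis 1 of w0
via_tensordot = np.tensordot(data3D, w0, axes=(2, 1))    # shape (4, 3, 3)

# The same contraction with explicit index labels:
# n = row of data, i = block within the row, m = summed index, j = row of w0
via_einsum = np.einsum('nim,jm->nij', data3D, w0)

print(via_tensordot.shape)                     # (4, 3, 3)
print(np.allclose(via_tensordot, via_einsum))  # True
```

The leftover axes of data3D come first in the result, followed by the leftover axis of w0, which is why a final reshape flattens each row into block-major order.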
Under the hood, tensordot transposes axes, reshapes, and then calls np.dot. So, doing those steps by hand to avoid the function call to tensordot, we are left with a single dot call, like so -
out = data.reshape(-1,data.shape[1]//k).dot(weights[0].T).reshape(-1,k**2)
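As a rough check that the two one-liners agree, the sketch below wraps each in a small function (via_tensordot and via_dot are hypothetical names, not from the question). One thing to note: the final reshape(-1,k**2) assumes weights[0] has exactly k rows, which holds in the original setup where both are 30. The wrappers use reshape(data.shape[0], -1) instead, so they also cover the case where that assumption does not hold:

```python
import numpy as np

def via_tensordot(data, w, k):
    # Hypothetical wrapper around the tensordot one-liner
    data3D = data.reshape(data.shape[0], k, -1)
    return np.tensordot(data3D, w, axes=(2, 1)).reshape(data.shape[0], -1)

def via_dot(data, w, k):
    # Hypothetical wrapper around the reshape + dot one-liner:
    # every length-(n_cols/k) chunk becomes one row of a 2D array
    chunks = data.reshape(-1, data.shape[1] // k)
    return chunks.dot(w.T).reshape(data.shape[0], -1)

rng = np.random.default_rng(0)          # seed is an assumption, only for repeatability
k = 3
w = rng.random((5, 2))                  # 5 rows, deliberately different from k
data = rng.random((4, k * w.shape[1]))

a = via_tensordot(data, w, k)
b = via_dot(data, w, k)
print(a.shape)              # (4, 15), i.e. k * w.shape[0] columns
print(np.allclose(a, b))    # True
```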
Related post to understand tensordot.
Sample run
Let's use a toy example to explain what's going on, for readers who might not have understood the problem:
In [68]: # Toy setup and run with the original code
...: k = 3 # kernel size, which is 30 in the original case
...:
...: data = np.random.rand(4,6)
...: w0 = np.random.rand(3,2) # this is weights[0]
...: foo = np.kron(np.identity(k), w0)
...: output_first_row = foo.dot(data[0])
So, the question is how to get rid of the foo creation step, still arrive at output_first_row, and do this for all rows of data.
The proposed solution is:
...: data3D = data.reshape(data.shape[0],k,-1)
...: vectorized_out = np.tensordot(data3D, w0, axes=(2,1)).reshape(-1,k**2)
Let's verify the results:
In [69]: output_first_row
Out[69]: array([ 0.11, 0.13, 0.34, 0.67, 0.53, 1.51, 0.17, 0.16, 0.44])
In [70]: vectorized_out
Out[70]:
array([[ 0.11, 0.13, 0.34, 0.67, 0.53, 1.51, 0.17, 0.16, 0.44],
[ 0.43, 0.23, 0.73, 0.43, 0.38, 1.05, 0.64, 0.49, 1.41],
[ 0.57, 0.45, 1.3 , 0.68, 0.51, 1.48, 0.45, 0.28, 0.85],
[ 0.41, 0.35, 0.98, 0.4 , 0.24, 0.75, 0.22, 0.28, 0.71]])
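Comparing two-decimal printouts by eye only goes so far, so here is a self-contained sketch that checks every row numerically with np.allclose. It rebuilds the toy setup with a seeded generator (the seed is my assumption, so the values will differ from the ones printed above) and applies foo to each row in a loop as the reference:

```python
import numpy as np

rng = np.random.default_rng(0)      # seed is an assumption, only for repeatability
k = 3                               # kernel size, 30 in the original case
data = rng.random((4, 6))
w0 = rng.random((3, 2))             # plays the role of weights[0]

# Original approach: build foo once, then apply it to every row in a loop
foo = np.kron(np.identity(k), w0)
looped_out = np.array([foo.dot(row) for row in data])

# Proposed approach: no foo, one tensordot call for all rows
data3D = data.reshape(data.shape[0], k, -1)
vectorized_out = np.tensordot(data3D, w0, axes=(2, 1)).reshape(-1, k**2)

print(looped_out.shape == vectorized_out.shape)   # True
print(np.allclose(looped_out, vectorized_out))    # True
```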
Runtime test for all proposed approaches -
In [30]: import numpy as np
In [31]: sizes = [784,30,10]
In [32]: weights = [np.random.rand(y, x) for x, y in zip(sizes[:-1],sizes[1:])]
In [33]: data = np.random.rand(1666,23520)
In [37]: k = 30 # kernel size
# @Paul Panzer's soln
In [38]: %timeit (weights[0] @ data.reshape(-1, 30, 784).swapaxes(1, 2)).swapaxes(1, 2)
1 loops, best of 3: 707 ms per loop
In [39]: %timeit np.tensordot(data.reshape(data.shape[0],k,-1), weights[0], axes=(2,1)).reshape(-1,k**2)
10 loops, best of 3: 114 ms per loop
In [40]: %timeit data.reshape(-1,data.shape[1]//k).dot(weights[0].T).reshape(-1,k**2)
10 loops, best of 3: 118 ms per loop
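For anyone who wants to repeat the comparison outside IPython, here is a sketch using the standard library's timeit module. The function names are hypothetical, and data has fewer rows than the 1666 used above to keep the run short. Absolute numbers will depend on the machine, the BLAS build and the NumPy version, so treat them only as a relative comparison:

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)           # seed is an assumption, only for repeatability
k = 30                                   # kernel size, as in the timings above
w = rng.random((30, 784))                # same shape as weights[0]
data = rng.random((100, k * 784))        # fewer rows than the 1666 used above

def with_matmul():
    # @Paul Panzer's approach, returns shape (rows, 30, 30)
    return (w @ data.reshape(-1, k, 784).swapaxes(1, 2)).swapaxes(1, 2)

def with_tensordot():
    return np.tensordot(data.reshape(data.shape[0], k, -1), w, axes=(2, 1)).reshape(-1, k**2)

def with_dot():
    return data.reshape(-1, data.shape[1] // k).dot(w.T).reshape(-1, k**2)

timings = {}
for fn in (with_matmul, with_tensordot, with_dot):
    # Best of 3 repeats with 3 calls each, reported per call
    best = min(timeit.repeat(fn, number=3, repeat=3)) / 3
    timings[fn.__name__] = best
    print(f"{fn.__name__}: {best * 1e3:.1f} ms per call")
```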
This Q&A, and the comments under it, might help you better understand how tensordot works with tensors.