In my benchmarks, a jitted numba solution is the fastest one I could find.

Benchmarks for `a` and `m` with shape (10000, 200) (all methods produce equal result tensors):
| # | Method | Time (speedup vs. baseline) |
|---|--------|-----------------------------|
| 1 | `@numba.jit` | 13.2 ms (3.46x) |
| 2 | list comprehension | 31.3 ms (1.46x) |
| 3 | baseline | 45.7 ms (1.00x) |
Generating sufficiently large sample data for benchmarking:
import torch
import numpy as np
def generate_data(rows=500, columns=100, seed=None):
    """Build random benchmark inputs ``(a, m)``.

    Parameters
    ----------
    rows, columns : int
        Shape of both outputs.
    seed : int or None, optional
        Seed for a private ``np.random.RandomState``, which makes the data
        reproducible across runs. With the default ``None`` the global
        ``np.random`` state is used with the same calls in the same order as
        before, so ``np.random.seed`` keeps working.

    Returns
    -------
    a : torch.Tensor
        ``(rows, columns)`` float32 tensor with values drawn from
        ``uniform(1, 10)``.
    m : numpy.ndarray
        C-contiguous ``(rows, columns)`` boolean mask. Every column is an
        independent shuffle of one shared random 0/1 vector, so all columns
        contain the same number of True entries.
    """
    # np.random (the module) and RandomState expose the same uniform/randint/rand API.
    rng = np.random if seed is None else np.random.RandomState(seed)
    a = torch.from_numpy(rng.uniform(1, 10, (rows, columns)).astype(np.float32))

    # argsort trick by @divakar https://stackoverflow.com/a/55317373/14277722
    def shuffle_along_axis(arr, axis):
        # Argsorting random keys yields an independent permutation per slice along `axis`.
        idx = rng.rand(*arr.shape).argsort(axis=axis)
        return np.take_along_axis(arr, idx, axis=axis)

    # Repeat one random 0/1 vector for every column, shuffle each copy independently,
    # then transpose so the copies become columns (equal True count per column).
    m = shuffle_along_axis(
        np.full((columns, rows), rng.randint(2, size=rows)), 1
    ).astype('bool').T
    # The transpose is a non-contiguous view; make it C-contiguous for the consumers.
    return a, np.ascontiguousarray(m)
# Benchmark-sized inputs: a is a (10000, 200) float32 tensor and m a
# C-contiguous (10000, 200) boolean NumPy mask with the same number of True
# entries in every column.
a, m = generate_data(10000,200)
A jitted numba implementation:
import numba as nb
@nb.njit
def gather2d(arr1, arr2):
    """Gather, column by column, the entries of ``arr1`` where ``arr2`` is True.

    Parameters
    ----------
    arr1 : 2-D array of values.
    arr2 : 2-D boolean mask with the same shape as ``arr1``. Every column must
        contain the same number ``k`` of True entries.

    Returns
    -------
    ``(k, arr1.shape[1])`` array with ``arr1``'s dtype. Column ``j`` holds
    ``arr1[:, j][arr2[:, j]]`` in ascending row order.

    Raises
    ------
    ValueError
        If the shapes differ, or if the columns of ``arr2`` hold different
        numbers of True entries. Numba does not bounds-check by default, so
        without these checks a bad mask would silently write outside ``res``.
    """
    n_rows, n_cols = arr1.shape
    if arr2.shape[0] != n_rows or arr2.shape[1] != n_cols:
        raise ValueError("arr1 and arr2 must have the same shape")
    # Every column must match column 0's count (enforced in the loops below).
    k = np.count_nonzero(arr2[:, 0]) if n_cols > 0 else 0
    res = np.zeros((k, n_cols), arr1.dtype)
    # counter[j] is the next free output row in column j.
    counter = np.zeros(n_cols, dtype=np.intp)
    # Walking rows in the outer loop keeps each output column in ascending input-row order.
    for i in range(n_rows):
        for j in range(n_cols):
            if arr2[i, j]:
                if counter[j] >= k:
                    raise ValueError("every column of arr2 must have the same number of True entries")
                res[counter[j], j] = arr1[i, j]
                counter[j] += 1
    # A column with fewer True entries would otherwise leave trailing zero rows.
    for j in range(n_cols):
        if counter[j] != k:
            raise ValueError("every column of arr2 must have the same number of True entries")
    return res
# Run the numba kernel on the NumPy counterpart of a and wrap the result back as a tensor.
torch.from_numpy(gather2d(a.numpy(),m))
Output of the numba version:
# %timeit 10 loops, best of 5: 13.2 ms per loop
tensor([[2.1846, 7.8890, 8.8218, ..., 4.8309, 9.2853, 6.4404],
[5.8842, 3.7332, 6.7436, ..., 1.2914, 3.2983, 3.5627],
[9.5128, 2.4283, 2.2152, ..., 4.9512, 9.7335, 9.6252],
...,
[7.3193, 7.8524, 9.6654, ..., 3.3665, 8.8926, 4.7660],
[1.3829, 1.3347, 6.6436, ..., 7.1956, 4.0446, 6.4633],
[6.4264, 3.6283, 3.6385, ..., 8.4152, 5.8498, 5.0281]])
Compared against a vectorized baseline solution:
# %timeit 10 loops, best of 5: 45.7 ms per loop
# np.nonzero(m.T)[1] lists the row indices of m's True entries, column by
# column of m. Reshaping in Fortran order places each column's indices into one
# column of the index matrix, and gather(0, ...) then uses it to pick rows of a.
# Assumes every column of m has the same number of True entries.
a.gather(0, torch.from_numpy(np.nonzero(m.T)[1].reshape(-1, m.shape[1], order='F')))
A Python list comprehension turns out to be surprisingly fast:
def g(arr1, arr2):
    """Select masked entries of ``arr1`` column by column and stack them.

    Column ``j`` of the result is ``arr1[:, j][arr2[:, j]]``. The selected
    columns are stacked side by side, so every column of ``arr2`` should
    contain the same number of True entries.
    """
    picked = []
    for values, keep in zip(arr1.T, arr2.T):
        picked.append(values[keep])
    return np.array(picked).T
# %timeit 10 loops, best of 5: 31.3 ms per loop
# Apply the list-comprehension version to the NumPy counterpart of a and wrap the result as a tensor.
torch.from_numpy(g(a.numpy(), m))