What is the most efficient way in Python to select the rows of a matrix whose indices are divisible by a certain number (using numpy, torch, or any other library)?

Example 1:

mat=[[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15]]
num=3
output=[[1,2,3],[10,11,12]] #Rows 0,3

Example 2:

mat=[[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15]]
num=2
output=[[1,2,3],[7,8,9],[13,14,15]] #Rows 0,2,4
max
  • There is no substitute for `mat[::num]`, but on `numpy` only. It outperforms `mat[::num]` on lists by a factor of thousands, if not millions :) – mathfux Oct 12 '20 at 17:28
  • More to the point, there is no substitute for basic research. Your questions at this stage in your studies are extremely unlikely to be earth-shatteringly new. – Mad Physicist Oct 12 '20 at 18:46

2 Answers

Let's set up an experiment! Let a, b, and c be 1000 x 1000 matrices stored as a PyTorch tensor, a NumPy array, and a plain Python list of lists, respectively. Hardware: Intel(R) Xeon(R) CPU @ 2.00GHz and a Tesla V100-SXM2 GPU. Software: PyTorch version 1.6.0 with CUDA 10.1:

import torch
import numpy as np

k=1000
a = torch.arange(k**2, device='cuda').view(k, k) #On GPU
b = np.arange(k**2).reshape(k, k)
c = [[j * k + m % k for m in range(k)] for j in range(k)]

Using the IPython built-in magic command %timeit (with the -o flag, so that the result object is returned), we measure how long it takes to slice each of the three. We take the worst time out of 100000 iterations, to avoid noise from caching. We check num in range(2, 100); here are the results (the lower the curve, the better):

import matplotlib.pyplot as plt
ts_torch = []
ts_numpy = []
ts_lists = []
strides = range(2, 100)
for num in strides:
  t_torch = %timeit -o a[::num,:]
  t_numpy = %timeit -o b[::num,:]
  t_lists = %timeit -o c[::num]
  ts_torch.append(t_torch.worst)
  ts_numpy.append(t_numpy.worst)
  ts_lists.append(t_lists.worst)

plt.rcParams['figure.figsize'] = (10,10)
plt.plot(strides, ts_torch, c='r', label='PyTorch')
plt.plot(strides, ts_numpy, c='g', label='Numpy')
plt.plot(strides, ts_lists, c='b', label='Lists')
plt.legend(loc="upper right")
plt.xlabel('stride')
plt.ylabel('time [s]')
plt.show()

And the winners (at least in this not 100% scientific experiment) are NumPy and Lists, which run head-to-head; slicing Lists with strides is slightly faster than slicing a numpy.ndarray in this setting.
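
The %timeit magic only works inside IPython. For anyone who wants to repeat the list part of the measurement in a plain script, here is a minimal sketch using the standard-library timeit module. The helper name time_list_slice and the repeat/number values are my own assumptions for illustration, not the settings used for the plot, so the absolute numbers will differ:

```python
import timeit

k = 1000
# Same layout as `c` above: row j holds the values j*k .. j*k + k - 1.
c = [[j * k + m for m in range(k)] for j in range(k)]

def time_list_slice(rows, num, repeat=5, number=1000):
    """Return the worst per-call time, in seconds, of rows[::num].

    `repeat` and `number` are illustrative defaults, not the values
    behind the plot in this answer.
    """
    timer = timeit.Timer(lambda: rows[::num])
    totals = timer.repeat(repeat=repeat, number=number)  # one total per repeat
    return max(totals) / number  # worst repeat, averaged over its loops

# A few sample strides instead of the full range(2, 100).
worst = {num: time_list_slice(c, num) for num in (2, 10, 50)}
for num, seconds in worst.items():
    print(f"stride {num}: {seconds:.2e} s per slice")
```

Taking the maximum over the repeats mirrors the idea of reporting the worst run; timing the tensor and the array would follow the same pattern with a[::num, :] and b[::num, :] in the lambda.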

[Figure: slicing time in seconds versus stride for PyTorch, NumPy, and Lists]

Gil Pinsky

If you just want a vanilla Python way to do it, I would recommend a list comprehension. Using enumerate lets you check whether each index is one of the indices you want.

mat = [[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15]]
num = 2

out = [row for i, row in enumerate(mat) if not i % num]

>>> [[1, 2, 3], [7, 8, 9], [13, 14, 15]]
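
For plain lists, the same rows can also be picked with an extended slice (as the comment on the question suggests) or, lazily, with itertools.islice. A small sketch comparing the three; the variable names are made up for illustration:

```python
from itertools import islice

mat = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]
num = 2

# Keep rows whose index is divisible by num, three equivalent ways.
by_comprehension = [row for i, row in enumerate(mat) if i % num == 0]
by_slice = mat[::num]                        # start at 0, step by num
by_islice = list(islice(mat, 0, None, num))  # same selection, lazy iterator

print(by_slice)  # [[1, 2, 3], [7, 8, 9], [13, 14, 15]]
```

All three give the same rows; the slice is the shortest to write, while the comprehension is easier to adapt if the condition on the index becomes more complicated.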

Also if you want a numpy solution you can use:

import numpy as np
mat = np.array(mat)  # a plain ndarray; NumPy no longer recommends np.matrix
out = mat[::num]     # every num-th row, starting from row 0
>>> [[1, 2, 3], [7, 8, 9], [13, 14, 15]]
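
One thing worth keeping in mind with the NumPy version: basic slicing returns a view onto the original array rather than a copy, which is part of why it is cheap. A minimal sketch, assuming a plain ndarray built with np.array; the names view and copy are only for illustration:

```python
import numpy as np

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]])
num = 2

view = mat[::num]         # basic slicing: no row data is copied
copy = mat[::num].copy()  # explicit, independent copy of the same rows

print(np.shares_memory(mat, view))  # True
print(np.shares_memory(mat, copy))  # False

view[0, 0] = 100          # writes through to the original array
print(mat[0, 0])          # 100
print(copy[0, 0])         # 1
```

So if the selected rows will be modified later and the original must stay intact, adding .copy() is the safer choice.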
jlardy