I wrote some Python code to sort the rows of each of four 3x3 tables by the values in their first column. Is there an easier way to do it, with less code and maybe more efficiently? Here is my code:

import numpy as np 

np.random.seed(4)

a = np.random.randint(10, size=(4, 3, 3))
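ind = a[:,:,0].argsort()  # (4, 3): row order of each table by its first column
ind = np.stack(a.shape[2]*[ind], axis=1)  # repeat that order for every column -> (4, 3, 3)
b = np.take_along_axis(a.transpose(0, 2, 1), ind, axis=2).transpose(0, 2, 1)  # gather rows, transpose back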

print(a)
print("----------------")
print(b)

Output:
[[[7 5 1]
  [8 7 8]
  [2 9 7]]

 [[7 7 9]
  [8 4 2]
  [6 4 3]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[5 8 1]
  [2 7 0]
  [8 3 1]]]
----------------
[[[2 9 7]
  [7 5 1]
  [8 7 8]]

 [[6 4 3]
  [7 7 9]
  [8 4 2]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[2 7 0]
  [5 8 1]
  [8 3 1]]]
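
For reference, the stack-and-transpose steps above amount to gathering rows along axis 1. Since np.take_along_axis broadcasts its index array against the data (the function exists from NumPy 1.15 on), a minimal sketch that just adds a trailing length-1 axis to the argsort result should give the same b without np.stack or the transposes. The names b_orig and ind3 are only illustrative.

```python
import numpy as np

np.random.seed(4)
a = np.random.randint(10, size=(4, 3, 3))

# Row permutation for each table, shape (4, 3): sorts each table by column 0.
ind = a[:, :, 0].argsort(axis=1)

# ind[:, :, None] has shape (4, 3, 1); take_along_axis broadcasts it across
# the 3 columns, so b[i, j, k] == a[i, ind[i, j], k].
b = np.take_along_axis(a, ind[:, :, None], axis=1)

# The stack/transpose version from the question, for comparison.
ind3 = np.stack(a.shape[2] * [ind], axis=1)
b_orig = np.take_along_axis(a.transpose(0, 2, 1), ind3, axis=2).transpose(0, 2, 1)

print(np.array_equal(b, b_orig))  # True
```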

1 Answer

You have a nice solution. Here is a shorter (and probably faster) one:

b = np.einsum('iijk->ijk', a[:,a[:,:,0].argsort()])

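The einsum does what you are achieving with stack, transpose, and take_along_axis. Indexing with a[:, a[:,:,0].argsort()] applies every table's row order to every table, which gives an array of shape (4, 4, 3, 3); 'iijk->ijk' then keeps only the diagonal, i.e. the i-th element of the i-th element, which is table i sorted by its own first column.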
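
To make the diagonal step concrete, here is a small sketch that builds the intermediate array explicitly and compares the einsum result with taking the diagonal by plain indexing. The names r, b_einsum and b_diag are just for illustration. Note that for n tables the intermediate holds n * n tables, so the speed advantage may shrink as n grows.

```python
import numpy as np

np.random.seed(4)
a = np.random.randint(10, size=(4, 3, 3))
ind = a[:, :, 0].argsort(axis=1)  # shape (4, 3)

# Indexing axis 1 with the whole (4, 3) permutation array applies every
# table's row order to every table: r[i, m] is table i reordered by ind[m].
r = a[:, ind]
print(r.shape)  # (4, 4, 3, 3)

# 'iijk->ijk' keeps only the diagonal i == m: each table with its own order.
b_einsum = np.einsum('iijk->ijk', r)

# The same diagonal taken explicitly with paired index arrays.
n = a.shape[0]
b_diag = r[np.arange(n), np.arange(n)]

print(np.array_equal(b_einsum, b_diag))  # True
```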

b:

[[[2 9 7]
  [7 5 1]
  [8 7 8]]

 [[6 4 3]
  [7 7 9]
  [8 4 2]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[2 7 0]
  [5 8 1]
  [8 3 1]]]
Answered by Ehsan
  • Thanks very much Ehsan, elegant solution, and much faster. Mine: 22.8 µs ± 682 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each). Yours: 6.13 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each). – Delio Siret Apr 28 '20 at 00:35
  • @DelioSiret Glad it helped. Please go ahead and accept the answer if it solved the issue so others find it helpful too. Thank you. – Ehsan Apr 28 '20 at 00:51
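
The timings in the comment look like output from IPython's %timeit. A plain-script sketch of the same comparison with the standard timeit module is below; absolute numbers depend on the machine and NumPy version, so only the ratios are likely to carry over. The fancy_index variant is an extra alternative not taken from the thread: it pairs each table index with its own row order, so it avoids the (n, n, 3, 3) intermediate that einsum_diag builds. Re-running with a larger first dimension may change the ranking.

```python
import timeit
from functools import partial

import numpy as np

np.random.seed(4)
a = np.random.randint(10, size=(4, 3, 3))


def stack_transpose(a):
    # Approach from the question: repeat the row order per column, gather, transpose back.
    ind = a[:, :, 0].argsort()
    ind = np.stack(a.shape[2] * [ind], axis=1)
    return np.take_along_axis(a.transpose(0, 2, 1), ind, axis=2).transpose(0, 2, 1)


def einsum_diag(a):
    # Approach from the answer: index with all row orders, keep the diagonal.
    return np.einsum('iijk->ijk', a[:, a[:, :, 0].argsort()])


def fancy_index(a):
    # Extra variant (not from the thread): table i is paired with ind[i] only.
    ind = a[:, :, 0].argsort()
    return a[np.arange(a.shape[0])[:, None], ind]


NUMBER = 2_000
results = {}
for f in (stack_transpose, einsum_diag, fancy_index):
    # Best of 5 repeats, each calling f NUMBER times; store the per-call time.
    t = min(timeit.repeat(partial(f, a), number=NUMBER, repeat=5)) / NUMBER
    results[f.__name__] = t
    print(f"{f.__name__:16s} {t * 1e6:.2f} µs per call")
```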