I wrote some Python code that sorts the rows of each of four 3x3 tables by the values in their first column. Is there a simpler way to do this, i.e. with less code and possibly more efficiently? Here is my code:
import numpy as np
np.random.seed(4)
a = np.random.randint(10, size=(4, 3, 3))
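ind = a[:,:,0].argsort()  # row order of each table, shape (4, 3)
ind = np.stack(a.shape[2]*[ind], axis=1)  # repeat that order for every column -> shape (4, 3, 3)
b = np.take_along_axis(a.transpose(0, 2, 1), ind, axis=2).transpose(0, 2, 1)  # reorder rows via the transposed view, then transpose back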
print(a)
print("----------------")
print(b)

Output:
[[[7 5 1]
  [8 7 8]
  [2 9 7]]

 [[7 7 9]
  [8 4 2]
  [6 4 3]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[5 8 1]
  [2 7 0]
  [8 3 1]]]
----------------
[[[2 9 7]
  [7 5 1]
  [8 7 8]]

 [[6 4 3]
  [7 7 9]
  [8 4 2]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[2 7 0]
  [5 8 1]
  [8 3 1]]]
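A shorter variant, sketched here assuming NumPy 1.15 or newer (where `np.take_along_axis` was introduced): the index array passed to `np.take_along_axis` only has to broadcast against `a`, so giving the row order a trailing length-1 axis is enough to move whole rows along axis 1. That removes the `np.stack` copy and both transposes. The names `order` and `b2` are just illustrative, and the fancy-indexing line is an equivalent alternative, not something from the original code.

```python
import numpy as np

np.random.seed(4)
a = np.random.randint(10, size=(4, 3, 3))

# Row order for each table: argsort of its first column, shape (4, 3)
order = a[:, :, 0].argsort(axis=1)

# Index array of shape (4, 3, 1) broadcasts across the 3 columns,
# so every column of a row is moved together along axis 1.
b = np.take_along_axis(a, order[:, :, None], axis=1)

# Equivalent fancy indexing: pair each table index (shape (4, 1))
# with that table's row order (shape (4, 3)).
b2 = a[np.arange(a.shape[0])[:, None], order]

assert np.array_equal(b, b2)
print(b)
```

Note that `argsort` uses an unstable sort by default; if rows with equal first-column values must keep their original relative order, pass `kind="stable"`.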