Is there a fast way of comparing rows for equivalence in an ndarray in Python 2.7? I am applying a symmetry operation to some coordinates that I store in the rows of an array of shape (N,4). I need a way to tell whether my transformation maps the coordinates back to equivalent positions. The caveat is that even though the positions may be the same, they may be stored in different rows of the array, so the arrays have to be sorted prior to comparison. This would be fine if I only needed to call it once, but this function is called ~10,000 times in my code.
Benchmarking my current implementation (the function shown below) shows that each call takes ~60 µs:
%timeit structs_are_equiv_old(a,b)
The slowest run took 6.36 times longer than the fastest. This could mean that
an intermediate result is being cached.
10000 loops, best of 3: 59.6 µs per loop
Is there a way to speed up this type of comparison?
def structs_are_equiv(a, b):
    """
    Return True if two coordinate arrays hold the same rows in any order.

    Compares two numpy arrays row by row to determine whether they contain
    the same coordinates after the application of a transformation
    operation. Both arrays are first sorted lexicographically by row
    (column 0 is the most significant key, the last column the least) and
    are then compared element-wise with ``np.allclose``.

    Parameters
    ----------
    a, b : numpy.ndarray
        2-D arrays of identical shape, one coordinate per row.

    Returns
    -------
    bool
        True when the row-sorted arrays agree within the default
        tolerances of ``np.allclose``.

    Raises
    ------
    AssertionError
        If the two arrays differ in shape.
    """
    assert a.shape == b.shape
    # np.lexsort treats the LAST key as the primary one, so the columns are
    # passed in reverse order. For a 4-column array this is exactly
    # (a[:,3], a[:,2], a[:,1], a[:,0]), but it works for any column count.
    # NOTE: the sort is exact while the comparison is tolerant, so rows that
    # differ only by floating-point noise in a leading column can end up in
    # different positions and be reported as not equivalent.
    a_temp = a[np.lexsort(a.T[::-1])]
    b_temp = b[np.lexsort(b.T[::-1])]
    return np.allclose(a_temp, b_temp)
Example `a` and `b` (note that the first column is not involved in the transformation; it is just a way to denote the type of object stored at that coordinate):
# Example input array `a`: one coordinate per row, 4 columns.
# Per the accompanying description, column 0 denotes the type of object stored
# at the coordinate and is not involved in the transformation; the remaining
# three columns are presumably the position components (confirm).
a = array([[ 1. , 0. , 0.5 , 0.271149],
[ 1. , 0.5 , 0.5 , 0.271149],
[ 1. , 0. , 0. , 0.303063],
[ 1. , 0.5 , 0. , 0.303063],
[ 2. , 0.25 , 0. , 0.358071],
[ 2. , 0.75 , 0. , 0.358071],
[ 1. , 0.25 , 0.306215, 0.358071],
[ 1. , 0.75 , 0.306215, 0.358071],
[ 2. , 0. , 0.5 , 0.358071],
[ 2. , 0.5 , 0.5 , 0.358071],
[ 1. , 0.25 , 0.693785, 0.358071],
[ 1. , 0.75 , 0.693785, 0.358071],
[ 1. , 0. , 0. , 0.413078],
[ 1. , 0.5 , 0. , 0.413078],
[ 1. , 0. , 0.5 , 0.444992],
[ 1. , 0.5 , 0.5 , 0.444992],
[ 2. , 0. , 0. , 0.5 ],
[ 2. , 0.5 , 0. , 0.5 ],
[ 1. , 0.25 , 0.193785, 0.5 ],
[ 1. , 0.75 , 0.193785, 0.5 ],
[ 2. , 0.25 , 0.5 , 0.5 ],
[ 2. , 0.75 , 0.5 , 0.5 ],
[ 1. , 0.25 , 0.806215, 0.5 ],
[ 1. , 0.75 , 0.806215, 0.5 ],
[ 1. , 0. , 0.5 , 0.555008],
[ 1. , 0.5 , 0.5 , 0.555008],
[ 1. , 0. , 0. , 0.586922],
[ 1. , 0.5 , 0. , 0.586922],
[ 2. , 0.25 , 0. , 0.641929],
[ 2. , 0.75 , 0. , 0.641929],
[ 1. , 0.25 , 0.306215, 0.641929],
[ 1. , 0.75 , 0.306215, 0.641929],
[ 2. , 0. , 0.5 , 0.641929],
[ 2. , 0.5 , 0.5 , 0.641929],
[ 1. , 0.25 , 0.693785, 0.641929],
[ 1. , 0.75 , 0.693785, 0.641929],
[ 1. , 0. , 0. , 0.696937],
[ 1. , 0.5 , 0. , 0.696937],
[ 1. , 0. , 0.5 , 0.728851],
[ 1. , 0.5 , 0.5 , 0.728851],
[ 0. , 0.117635, 0.5 , 0.238728],
[ 0. , 0.617635, 0.5 , 0.238728],
[ 0. , 0. , 0.114216, 0.270642],
[ 0. , 0.5 , 0.114216, 0.270642],
[ 0. , 0. , 0. , 0.270642],
[ 0. , 0.5 , 0. , 0.270642],
[ 0. , 0.617635, 0.5 , 0.761272],
[ 0. , 0.117635, 0.5 , 0.761272],
[ 0. , 0.5 , 0.114216, 0.729358],
[ 0. , 0. , 0.114216, 0.729358],
[ 0. , 0.5 , 0. , 0.729358],
[ 0. , 0. , 0. , 0.729358],
[ 0. , 0.25 , 0.306215, 0.401299],
[ 0. , 0.75 , 0.306215, 0.401299],
[ 0. , 0.25 , 0.693785, 0.401299],
[ 0. , 0.75 , 0.693785, 0.401299],
[ 0. , 0.25 , 0.306215, 0.598701],
[ 0. , 0.75 , 0.306215, 0.598701],
[ 0. , 0.25 , 0.693785, 0.598701],
[ 0. , 0.75 , 0.693785, 0.598701],
[ 0. , 0.117635, 0.5 , 0.226923],
[ 0. , 0.117635, 0.5 , 0.773077],
[ 0. , 0. , 0.114216, 0.260279],
[ 0. , 0. , 0.114216, 0.739721],
[ 0. , 0. , 0.885784, 0.260279],
[ 0. , 0. , 0.885784, 0.739721],
[ 0. , 0.5 , 0.885784, 0.260279],
[ 0. , 0.5 , 0.885784, 0.739721],
[ 0. , 0.25 , 0.306215, 0.401299],
[ 0. , 0.25 , 0.306215, 0.598701],
[ 0. , 0.75 , 0.306215, 0.401299],
[ 0. , 0.75 , 0.306215, 0.598701],
[ 0. , 0.75 , 0.693785, 0.401299],
[ 0. , 0.75 , 0.693785, 0.598701]])
# Example input array `b`, the counterpart compared against `a`: one
# coordinate per row, 4 columns. Per the accompanying description, column 0
# denotes the type of object stored at the coordinate and is not involved in
# the transformation. Presumably this is `a` after the symmetry operation,
# with rows stored in a different order (confirm).
b = array([[ 1. , 0.5 , 0.5 , 0.271149],
[ 1. , 0. , 0.5 , 0.271149],
[ 1. , 0.5 , 0. , 0.303063],
[ 1. , 0. , 0. , 0.303063],
[ 2. , 0.75 , 0. , 0.358071],
[ 2. , 0.25 , 0. , 0.358071],
[ 1. , 0.75 , 0.306215, 0.358071],
[ 1. , 0.25 , 0.306215, 0.358071],
[ 2. , 0.5 , 0.5 , 0.358071],
[ 2. , 0. , 0.5 , 0.358071],
[ 1. , 0.75 , 0.693785, 0.358071],
[ 1. , 0.25 , 0.693785, 0.358071],
[ 1. , 0.5 , 0. , 0.413078],
[ 1. , 0. , 0. , 0.413078],
[ 1. , 0.5 , 0.5 , 0.444992],
[ 1. , 0. , 0.5 , 0.444992],
[ 2. , 0.5 , 0. , 0.5 ],
[ 2. , 0. , 0. , 0.5 ],
[ 1. , 0.75 , 0.193785, 0.5 ],
[ 1. , 0.25 , 0.193785, 0.5 ],
[ 2. , 0.75 , 0.5 , 0.5 ],
[ 2. , 0.25 , 0.5 , 0.5 ],
[ 1. , 0.75 , 0.806215, 0.5 ],
[ 1. , 0.25 , 0.806215, 0.5 ],
[ 1. , 0.5 , 0.5 , 0.555008],
[ 1. , 0. , 0.5 , 0.555008],
[ 1. , 0.5 , 0. , 0.586922],
[ 1. , 0. , 0. , 0.586922],
[ 2. , 0.75 , 0. , 0.641929],
[ 2. , 0.25 , 0. , 0.641929],
[ 1. , 0.75 , 0.306215, 0.641929],
[ 1. , 0.25 , 0.306215, 0.641929],
[ 2. , 0.5 , 0.5 , 0.641929],
[ 2. , 0. , 0.5 , 0.641929],
[ 1. , 0.75 , 0.693785, 0.641929],
[ 1. , 0.25 , 0.693785, 0.641929],
[ 1. , 0.5 , 0. , 0.696937],
[ 1. , 0. , 0. , 0.696937],
[ 1. , 0.5 , 0.5 , 0.728851],
[ 1. , 0. , 0.5 , 0.728851],
[ 0. , 0.617635, 0.5 , 0.238728],
[ 0. , 0.117635, 0.5 , 0.238728],
[ 0. , 0.5 , 0.114216, 0.270642],
[ 0. , 0. , 0.114216, 0.270642],
[ 0. , 0.5 , 0. , 0.270642],
[ 0. , 0. , 0. , 0.270642],
[ 0. , 0.117635, 0.5 , 0.761272],
[ 0. , 0.617635, 0.5 , 0.761272],
[ 0. , 0. , 0.114216, 0.729358],
[ 0. , 0.5 , 0.114216, 0.729358],
[ 0. , 0. , 0. , 0.729358],
[ 0. , 0.5 , 0. , 0.729358],
[ 0. , 0.75 , 0.306215, 0.401299],
[ 0. , 0.25 , 0.306215, 0.401299],
[ 0. , 0.75 , 0.693785, 0.401299],
[ 0. , 0.25 , 0.693785, 0.401299],
[ 0. , 0.75 , 0.306215, 0.598701],
[ 0. , 0.25 , 0.306215, 0.598701],
[ 0. , 0.75 , 0.693785, 0.598701],
[ 0. , 0.25 , 0.693785, 0.598701],
[ 0. , 0.117635, 0.5 , 0.226923],
[ 0. , 0.117635, 0.5 , 0.773077],
[ 0. , 0. , 0.114216, 0.260279],
[ 0. , 0. , 0.114216, 0.739721],
[ 0. , 0. , 0.885784, 0.260279],
[ 0. , 0. , 0.885784, 0.739721],
[ 0. , 0.75 , 0.306215, 0.401299],
[ 0. , 0.75 , 0.306215, 0.598701],
[ 0. , 0.25 , 0.306215, 0.401299],
[ 0. , 0.25 , 0.306215, 0.598701],
[ 0. , 0.75 , 0.693785, 0.401299],
[ 0. , 0.75 , 0.693785, 0.598701],
[ 0. , 0.25 , 0.693785, 0.401299],
[ 0. , 0.25 , 0.693785, 0.598701]])