Suppose I have a NumPy array A with shape (n,) and a boolean NumPy matrix B with shape (n, n).

If B[i][j] is True, then A[i] should be sorted to a position before A[j].

If B[i][j] is False, then A[i] should be sorted to a position after A[j].

These rules apply only to entries B[i][j] below the main diagonal (that is, i > j). Entries on or above the main diagonal should be ignored.
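To make the rules concrete, here is a minimal sketch that lists the ordering constraints encoded below the diagonal. The helper name `constraints` and the 3×3 matrix are hypothetical, chosen only for illustration.

```python
def constraints(B):
    """List (earlier, later) index pairs implied by the entries below the diagonal."""
    pairs = []
    for i in range(len(B)):
        for j in range(i):  # j < i, so B[i][j] lies below the main diagonal
            # True: A[i] goes before A[j]; False: A[i] goes after A[j].
            pairs.append((i, j) if B[i][j] else (j, i))
    return pairs

# Hypothetical example; entries on or above the diagonal are never read.
B = [[False, False, False],
     [True,  False, False],
     [False, False, False]]
print(constraints(B))
```

Each pair reads "the first index must end up before the second", so for this matrix the only order satisfying all three pairs is 1, 0, 2.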

That being said, what is the most efficient way to sort A according to the matrix B?

I know there are several easy ways to do this, but I must perform this operation thousands of times, so I'm looking for an implementation that is computationally efficient (readability is not my main concern).

Ruan
  • If you know some ways, show them with examples! It's easier to suggest improvements than to create examples and code from scratch. – hpaulj Jul 12 '21 at 03:32

2 Answers


I think this could work; I have tried to keep it in pure Python 3. It could be made simpler with NumPy. I based it on https://stackoverflow.com/a/57003713/3895321.

from functools import cmp_to_key

B = [[True, False, True],
     [True, True, True],
     [False, False, True], ]

A = [2, 1, 3]

tmp = list(range(len(A)))

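def compare(i, j):
    # Only entries below the main diagonal are meaningful,
    # so always read B[row][col] with row > col.
    if i > j:
        before = B[i][j]
    else:
        before = not B[j][i]
    return -1 if before else 1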

tmp = sorted(tmp, key=cmp_to_key(compare))

sorted_A = [A[i] for i in tmp]
print(tmp)
print(sorted_A)
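
As a rough sketch of the NumPy route mentioned above: if the entries below the diagonal of B describe a consistent total order (an assumption; cycles would break this), each element's position follows from how many other elements it must precede. The function name `argsort_by_relation` is hypothetical.

```python
import numpy as np

def argsort_by_relation(B):
    """Return indices that order the elements per the strict lower triangle of B.

    Assumes the lower triangle encodes a consistent total order (no cycles).
    """
    B = np.asarray(B, dtype=bool)
    n = B.shape[0]
    lower = np.tril(np.ones((n, n), dtype=bool), k=-1)  # True where i > j
    # before[i, j] is True when element i must precede element j.
    # Below the diagonal read B directly; above it, negate the mirrored entry.
    before = (B & lower) | (~B.T & lower.T)
    wins = before.sum(axis=1)  # number of elements each index precedes
    return np.argsort(-wins, kind="stable")

B = np.array([[True, False, True],
              [True, True, True],
              [False, False, True]])
A = np.array([2, 1, 3])
order = argsort_by_relation(B)
sorted_A = A[order]
print(order, sorted_A)
```

This does O(n²) work, but all of it happens inside vectorized NumPy calls, with no Python-level comparison callback.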

My handwritten merge sort, compiled with numba.njit, beats the pure Python approach by almost a factor of 10. Note that this is technically an argsort: you pass it the indices and get back indices that, when applied to the array, sort it.

from functools import cmp_to_key

import numba
import numpy as np

@numba.njit
def ceil_log2(x):
    n = 0
    while 2**n < x:
        n += 1
    return n
            
@numba.njit
def merge(arr, relation):
    out = np.zeros(len(arr), dtype='int')
    j = 0
    k = len(arr)//2
    for i in range(len(out)):
        if j == len(arr)//2:
            out[i:] = arr[k:]
            break
        elif k == len(arr):
            out[i:] = arr[j:len(arr)//2]
            break
        elif relation[arr[j], arr[k]]:
            out[i] = arr[j]
            j += 1
        else:
            out[i] = arr[k]
            k += 1
    arr[:] = out[:]

@numba.njit
def merge_sort(arr, relation):
    for i in range(1, 1+ceil_log2(len(arr))):
        idx = np.arange(len(arr))[::2**i]
        idx = [*idx, len(arr)]
        for i, j in zip(idx[:-1], idx[1:]):
            merge(arr[i:j], relation)
    return arr

def example_relation(arr):
    idx = np.arange(len(arr))
    grid = np.meshgrid(idx, idx)
    relation = arr[grid[1]] <= arr[grid[0]]
    return relation

np.random.seed(0)
array = np.random.normal(size=2**14)
relation = example_relation(array)

def compare(i, j):
    if relation[i, j]:
        return -1
    else:
        return 1

%timeit sorted(np.arange(len(array)), key=cmp_to_key(compare))
%timeit merge_sort(np.arange(len(array)), relation)

gives me

61.5 ms ± 785 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
6.98 ms ± 43.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
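
One caveat: the merge step above uses `len(arr)//2` as the midpoint, which only lines up with the two sorted halves when every block is full, as with the power-of-two length in the benchmark. For other lengths the last block can be split in the wrong place. Below is a pure Python sketch of the same bottom-up idea with an explicit midpoint; the name `merge_argsort` is hypothetical, and this version is not compiled with numba or benchmarked.

```python
def merge_argsort(indices, relation):
    """Bottom-up merge sort of indices; relation[i][j] truthy means i goes before j."""
    arr = list(indices)
    n = len(arr)
    width = 1  # length of the already-sorted runs
    while width < n:
        for lo in range(0, n, 2 * width):
            mid = min(lo + width, n)      # end of the left sorted run
            hi = min(lo + 2 * width, n)   # end of the right sorted run
            left, right = arr[lo:mid], arr[mid:hi]
            out, j, k = [], 0, 0
            while j < len(left) and k < len(right):
                if relation[left[j]][right[k]]:
                    out.append(left[j])
                    j += 1
                else:
                    out.append(right[k])
                    k += 1
            out.extend(left[j:])   # at most one of these two is non-empty
            out.extend(right[k:])
            arr[lo:hi] = out
        width *= 2
    return arr

values = [1.0, 2.0, 0.0]  # length is deliberately not a power of two
relation = [[a <= b for b in values] for a in values]
order = merge_argsort(range(len(values)), relation)
print(order)
```

Tracking `mid` explicitly keeps the two runs being merged aligned with what the previous pass actually sorted, including a short final block.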
Lukas S
  • What about wrapping `sorted(np.arange(len(array)), key=cmp_to_key(compare))` in a function and using `@numba.njit` to compile it? – xskxzr Jul 21 '21 at 03:21
  • @xskxzr Well, if it were that easy, Tim Peters (the author of Python's sorting algorithm) would probably have done so himself. The problem is that `sorted` is not actually written in Python but probably in C. Keep in mind that it's not "worse" than my version, since it is required to be far more flexible. – Lukas S Jul 21 '21 at 10:06