I know the following PyTorch API performs a global random shuffle of the 1D array [0, ..., n-1]:

torch.randperm(n)

but I'm not sure how to quickly generate a random permutation in which every element of the shuffled array satisfies:

K = 10  # should be positive
shuffled_array = rand_perm_func(n, K)  # a mysterious function
for i in range(n):
    print(abs(shuffled_array[i] - i) < K)  # should be True for each i

In other words, each element is moved by a distance less than K. Are there fast implementations for 1D and 2D arrays?


Thanks to @PM 2Ring, I wrote the following code:

import torch

# randomly produces a 1-D permutation index array,
# such that each element of the shuffled array has
# a distance less than K from its original location
def rand_perm(n, K):
    o = torch.arange(0, n)
    if K <= 1:
        return o
    while True:
        p = torch.randperm(n)
        d = abs(p - o) < K
        if bool(d.all()):
            return p

if __name__ == '__main__':
    for i in range(10):
        print(rand_perm(10, 2))

but when n is large and K is small, generation takes a very long time. Is there a more efficient implementation?

BinChen
  • Please go through the [intro tour](https://stackoverflow.com/tour), the [help center](https://stackoverflow.com/help) and [how to ask a good question](https://stackoverflow.com/help/how-to-ask) to see how this site works and to help you improve your current and future questions, which can help you get better answers. Asking for software recommendations or references is *specifically* listed as off-topic. – Prune Mar 29 '21 at 01:53
  • I don't know if there's a fast way to do this. There are various algorithms to compute the rank of a permutation (its lexicographic index), or produce a permutation from its rank, see eg https://stackoverflow.com/q/8940470/4014959 & some of the questions linked there. However, when those algorithms are looking for a slot for the next item they skip over slots that are already occupied, and I think that makes it hard to adapt them to your restriction. – PM 2Ring Mar 29 '21 at 05:56
  • FWIW, it's not too hard to produce your restricted permutations, using recursion. If you need multiple permutations for a given (n, K) pair, it's more efficient to produce them using a (recursive) generator, rather than producing each permutation from scratch. That also has the benefit of not producing duplicates. – PM 2Ring Mar 29 '21 at 06:00
  • Thank you~ I wrote a small function above, but the generating speed could be very slow in some cases. – BinChen Mar 29 '21 at 07:40
  • No worries. Your approach is (probably) ok when K is large (relative to n) and thus the probability of `randperm` producing a valid permutation is high. But otherwise, you end up rejecting most permutations, which wastes a lot of time. – PM 2Ring Mar 29 '21 at 08:20
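To put a rough number on the rejection problem described in the last comment, here is a minimal brute-force sketch. The helper name `acceptance_rate` is made up for this illustration. It enumerates every permutation of `range(n)`, so it is only practical for small n.

```python
from itertools import permutations
from math import factorial

def acceptance_rate(n, k):
    """Fraction of all permutations of range(n) with |p[i] - i| < k for every i.

    Hypothetical helper for illustration; it enumerates all n! permutations.
    """
    valid = sum(
        all(abs(v - i) < k for i, v in enumerate(p))
        for p in permutations(range(n))
    )
    return valid / factorial(n)

# The chance that a uniformly random permutation passes the check
# drops quickly as n grows while k stays fixed.
for n in (4, 6, 8):
    print(n, acceptance_rate(n, 2))
```

The expected number of `randperm` calls in the rejection loop is the reciprocal of this fraction, which is why small K with large n becomes impractical.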

1 Answer

Here's a recursive generator in plain Python (i.e. not using PyTorch or NumPy) that produces permutations of `range(n)` satisfying the given constraint.

First, we create a list `out` to contain the output sequence, setting each slot in `out` to -1 to indicate that the slot is unused. Then, for each value `i`, we create a list `avail` of indices in the permitted range that aren't already occupied. For each `j` in `avail`, we set `out[j] = i` and recurse to place the next `i`. When `i == n`, all the values have been placed, so we've reached the end of the recursion, and `out` contains a valid solution, which gets propagated back up the recursion tree.

from random import shuffle

def rand_perm_gen(n, k):
    out = [-1] * n
    def f(i):
        if i == n:
            yield out
            return

        lo, hi = max(0, i-k+1), min(n, i+k)
        avail = [j for j in range(lo, hi) if out[j] == -1]
        if not avail:
            return
        shuffle(avail)
        for j in avail:
            out[j] = i
            yield from f(i+1)
            out[j] = -1

    yield from f(0)

def test(n=10, k=3, numtests=10):
    for j, a in enumerate(rand_perm_gen(n, k), 1):
        print("\n", j, a)
        for i, u in enumerate(a):
            print(f"{i}: {u} -> {(u - i)}")
        if j == numtests:
            break

test()

Typical output


 1 [1, 0, 3, 2, 6, 5, 4, 8, 9, 7]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 8 -> 1
8: 9 -> 1
9: 7 -> -2

 2 [1, 0, 3, 2, 6, 5, 4, 9, 8, 7]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 9 -> 2
8: 8 -> 0
9: 7 -> -2

 3 [1, 0, 3, 2, 6, 5, 4, 9, 7, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 9 -> 2
8: 7 -> -1
9: 8 -> -1

 4 [1, 0, 3, 2, 6, 5, 4, 8, 7, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 8 -> 1
8: 7 -> -1
9: 9 -> 0

 5 [1, 0, 3, 2, 6, 5, 4, 7, 8, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 7 -> 0
8: 8 -> 0
9: 9 -> 0

 6 [1, 0, 3, 2, 6, 5, 4, 7, 9, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 7 -> 0
8: 9 -> 1
9: 8 -> -1

 7 [1, 0, 3, 2, 6, 7, 4, 5, 9, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 7 -> 2
6: 4 -> -2
7: 5 -> -2
8: 9 -> 1
9: 8 -> -1

 8 [1, 0, 3, 2, 6, 7, 4, 5, 8, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 7 -> 2
6: 4 -> -2
7: 5 -> -2
8: 8 -> 0
9: 9 -> 0

 9 [1, 0, 3, 2, 5, 6, 4, 7, 9, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 5 -> 1
5: 6 -> 1
6: 4 -> -2
7: 7 -> 0
8: 9 -> 1
9: 8 -> -1

 10 [1, 0, 3, 2, 5, 6, 4, 7, 8, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 5 -> 1
5: 6 -> 1
6: 4 -> -2
7: 7 -> 0
8: 8 -> 0
9: 9 -> 0

Here's a live version running on SageMathCell.

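This approach is faster than generating all permutations and filtering them, but it is still slow for large n, because a wrong early choice can force a lot of backtracking. You can improve the speed by removing the `shuffle` call, in which case the permutations are yielded in a fixed order (lexicographic in the positions chosen for the values 0, 1, 2, ..., i.e. in the inverse permutation).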

If you just want a single solution, use `next`, e.g.

perm = next(rand_perm_gen(10, 3))

Note that all solutions share the same `out` list. So if you need to save those solutions in a list, you have to copy them, e.g.

perms = [seq.copy() for seq in rand_perm_gen(5, 2)]
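If speed matters more than enumerating every solution, a different technique may be worth trying. It is not the recursive generator above and should be read as a sketch under stated assumptions: add random noise in [0, k) to each index and sort by the noisy keys. With integer k >= 1 and a stable sort, an element can only be overtaken by elements less than k positions away, so every displacement stays below k. The helper name `noisy_sort_perm` is made up for this illustration. The result is not uniformly distributed over the valid permutations, and not every valid permutation is necessarily reachable.

```python
import random

def noisy_sort_perm(n, k):
    """Return a permutation p of range(n) with abs(p[i] - i) < k for all i.

    Hypothetical helper, assuming an integer k >= 1. Runs in O(n log n),
    but the output is NOT uniform over the valid permutations.
    """
    # Key for index i lies in [i, i + k], so i can never be sorted past
    # an index that is k or more positions away (ties keep index order).
    keys = [i + random.random() * k for i in range(n)]
    return sorted(range(n), key=keys.__getitem__)

perm = noisy_sort_perm(20, 3)
assert sorted(perm) == list(range(20))             # it is a permutation
assert all(abs(v - i) < 3 for i, v in enumerate(perm))
print(perm)
```

The same idea presumably carries over to PyTorch by calling `torch.argsort` on `torch.arange(n) + k * torch.rand(n)`, though that is untested here, and tie-breaking and float precision would need checking for large n.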
PM 2Ring