Here's a recursive generator in plain Python (i.e. not using PyTorch or NumPy) that produces permutations of `range(n)` satisfying the given constraint.

First, we create a list `out` to contain the output sequence, setting each slot in `out` to -1 to indicate that the slot is unused. Then, for each value `i`, we create a list `avail` of indices in the permitted range that aren't already occupied. For each `j` in `avail`, we set `out[j] = i` and recurse to place the next `i`. When `i == n`, all the values have been placed, so we've reached the end of the recursion, and `out` contains a valid solution, which gets propagated back up the recursion tree.

    from random import shuffle

    def rand_perm_gen(n, k):
        out = [-1] * n  # -1 marks an unused slot

        def f(i):
            if i == n:
                # Every value has been placed, so out holds a valid permutation.
                yield out
                return
            # Value i may only go into slots j with |i - j| < k.
            lo, hi = max(0, i-k+1), min(n, i+k)
            avail = [j for j in range(lo, hi) if out[j] == -1]
            if not avail:
                return
            shuffle(avail)
            for j in avail:
                out[j] = i
                yield from f(i+1)
                out[j] = -1  # backtrack: free the slot again

        yield from f(0)

    def test(n=10, k=3, numtests=10):
        for j, a in enumerate(rand_perm_gen(n, k), 1):
            print("\n", j, a)
            for i, u in enumerate(a):
                print(f"{i}: {u} -> {(u - i)}")
            if j == numtests:
                break

    test()

Typical output:
1 [1, 0, 3, 2, 6, 5, 4, 8, 9, 7]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 8 -> 1
8: 9 -> 1
9: 7 -> -2
2 [1, 0, 3, 2, 6, 5, 4, 9, 8, 7]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 9 -> 2
8: 8 -> 0
9: 7 -> -2
3 [1, 0, 3, 2, 6, 5, 4, 9, 7, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 9 -> 2
8: 7 -> -1
9: 8 -> -1
4 [1, 0, 3, 2, 6, 5, 4, 8, 7, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 8 -> 1
8: 7 -> -1
9: 9 -> 0
5 [1, 0, 3, 2, 6, 5, 4, 7, 8, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 7 -> 0
8: 8 -> 0
9: 9 -> 0
6 [1, 0, 3, 2, 6, 5, 4, 7, 9, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 5 -> 0
6: 4 -> -2
7: 7 -> 0
8: 9 -> 1
9: 8 -> -1
7 [1, 0, 3, 2, 6, 7, 4, 5, 9, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 7 -> 2
6: 4 -> -2
7: 5 -> -2
8: 9 -> 1
9: 8 -> -1
8 [1, 0, 3, 2, 6, 7, 4, 5, 8, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 6 -> 2
5: 7 -> 2
6: 4 -> -2
7: 5 -> -2
8: 8 -> 0
9: 9 -> 0
9 [1, 0, 3, 2, 5, 6, 4, 7, 9, 8]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 5 -> 1
5: 6 -> 1
6: 4 -> -2
7: 7 -> 0
8: 9 -> 1
9: 8 -> -1
10 [1, 0, 3, 2, 5, 6, 4, 7, 8, 9]
0: 1 -> 1
1: 0 -> -1
2: 3 -> 1
3: 2 -> -1
4: 5 -> 1
5: 6 -> 1
6: 4 -> -2
7: 7 -> 0
8: 8 -> 0
9: 9 -> 0
Here's a live version running on SageMathCell.
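This approach is faster than generating all permutations and filtering them, but it is still slow for large `n`. You can improve the speed by removing the `shuffle` call, in which case the permutations are yielded in a fixed order: lexicographic in the slot chosen for each successive value `i` (that is, in the inverse permutation), which is not always lexicographic in `out` itself.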
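
For small `n`, the generator can be cross-checked against the generate-and-filter baseline mentioned above. The sketch below is illustrative only: `brute_force` is a hypothetical helper, and the constraint `abs(u - i) < k` is inferred from the `lo`/`hi` bounds in the code rather than taken from the question. `rand_perm_gen` is repeated so the snippet runs on its own.

```python
from itertools import permutations
from random import shuffle


def rand_perm_gen(n, k):
    """Backtracking generator from the answer, repeated here for self-containment."""
    out = [-1] * n

    def f(i):
        if i == n:
            yield out
            return
        lo, hi = max(0, i - k + 1), min(n, i + k)
        avail = [j for j in range(lo, hi) if out[j] == -1]
        shuffle(avail)
        for j in avail:
            out[j] = i
            yield from f(i + 1)
            out[j] = -1

    yield from f(0)


def brute_force(n, k):
    """Hypothetical baseline: filter all n! permutations (small n only).

    Assumes the constraint is abs(value - index) < k.
    """
    return {
        p
        for p in permutations(range(n))
        if all(abs(u - i) < k for i, u in enumerate(p))
    }


n, k = 6, 3
# Convert each yielded list to a tuple, since the generator reuses one list.
fast = {tuple(p) for p in rand_perm_gen(n, k)}
slow = brute_force(n, k)
print(len(fast), len(slow), fast == slow)
```

If the inferred constraint is right, the two sets should be identical; the brute-force version is only practical for small `n`, since it has to examine all `n!` permutations.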
If you just want a single solution, use `next`, e.g.

    perm = next(rand_perm_gen(10, 3))

Note that all solutions share the same `out` list. So if you need to save those solutions in a list, you have to copy them, e.g.

    perms = [seq.copy() for seq in rand_perm_gen(5, 2)]
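
To see why the copy matters, here is a small sketch comparing the two ways of collecting results. The names `aliased` and `copied` are only for illustration, and `rand_perm_gen` is repeated so the snippet is runnable by itself.

```python
from random import shuffle


def rand_perm_gen(n, k):
    """Backtracking generator from the answer, repeated here for self-containment."""
    out = [-1] * n

    def f(i):
        if i == n:
            yield out
            return
        lo, hi = max(0, i - k + 1), min(n, i + k)
        avail = [j for j in range(lo, hi) if out[j] == -1]
        shuffle(avail)
        for j in avail:
            out[j] = i
            yield from f(i + 1)
            out[j] = -1

    yield from f(0)


# Without copying, every entry is a reference to the same list object.
aliased = list(rand_perm_gen(4, 2))
# With copying, each solution is snapshotted at the moment it is yielded.
copied = [seq.copy() for seq in rand_perm_gen(4, 2)]

print(all(seq is aliased[0] for seq in aliased))  # True: one shared list
print(aliased[0])  # [-1, -1, -1, -1]: backtracking has reset every slot
print(len(copied) == len({tuple(p) for p in copied}))  # True: all distinct
```

Once the generator is exhausted, backtracking has restored every slot of the shared list to -1, so the uncopied version holds no solutions at all.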