I am using the 'stars-and-bars' algorithm to select items from multiple lists, where the number of stars between bars k and k+1 is the index into the k-th list. The problem is that a partition (i.e. the number of stars between two bars) can be larger than the length of its list, which produces many invalid combinations.

For example: if I have two lists, each of length 8, (14,0) is a valid stars distribution for sum=14, but it of course exceeds the first list's capacity. (7,7) is the highest valid index, so I get a large number of invalid indices, especially if the lists are not of equal size.

For performance reasons I need a variant of the algorithm with limited partition size. How can I do this? The stars-and-bars implementation I'm using right now is this one, but I can easily change it. The lists are usually of similar length, but not necessarily of the same length. I'm fine with limiting the partition sizes to the length of the longest list, but individual restrictions would of course be nicer.

import itertools

def stars_and_bars(stars, partitions):
    for c in itertools.combinations(range(stars+partitions-1), partitions-1):
        yield tuple(right-left-1 for left,right in zip((-1,) + c, c + (stars+partitions-1,)))

def get_items(*args):
    hits = 0
    misses = 0
    tries = 0
    max_idx = sum(len(a) - 1 for a in args)
    for dist in range(max_idx):
        for indices in stars_and_bars(dist, len(args)):
            try:
                tries += 1
                [arg[i] for arg,i in zip(args,indices)]
                hits += 1
            except IndexError:
                misses += 1
                continue
    print('hits/misses/tries: {}/{}/{}'.format(hits, misses, tries))

# Generate 4 lists of length 1..4
lists = [[None]*(r+1) for r in range(4)]
get_items(*lists)
# hits/misses/tries: 23/103/126
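
To make the two-list example concrete, here is a minimal sketch that counts how many of the generated distributions actually fit. The helper name `count_valid` is hypothetical and only serves as an illustration; `stars_and_bars` is repeated so the snippet runs on its own.

```python
import itertools

def stars_and_bars(stars, partitions):
    # Same generator as above: all ways to split `stars` into `partitions` non-negative parts
    for c in itertools.combinations(range(stars + partitions - 1), partitions - 1):
        yield tuple(right - left - 1
                    for left, right in zip((-1,) + c, c + (stars + partitions - 1,)))

def count_valid(total, lengths):
    """Return (valid, generated) for index tuples summing to `total` (hypothetical helper)."""
    generated = 0
    valid = 0
    for dist in stars_and_bars(total, len(lengths)):
        generated += 1
        # A distribution is usable only if every index fits inside its list
        if all(i < n for i, n in zip(dist, lengths)):
            valid += 1
    return valid, generated

print(count_valid(14, [8, 8]))  # (1, 15): only (7, 7) fits
```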

Edit: I found two related questions on Math Stack Exchange, but I have not been able to translate them into code yet.

Managarm

1 Answer

Based on this post, here is some code to generate the solutions efficiently. The main differences from the other post are that the buckets now have different limits and that the number of buckets is fixed, so the number of solutions is finite.

def find_partitions(x, lims):
    # partition the number x in a list of buckets;
    # the number of elements of each bucket i is strictly smaller than lims[i];
    # the sum of all buckets is x;
    # output the lists of buckets one by one

    a = [x] + [0 for l in lims[1:]]  # create an output array of the same length as lims, set a[0] to x

    while True:

        # step 1: while a[i] is too large: redistribute to a[i+1]
        i = 0
        while a[i] >= lims[i] and i < len(lims) - 1:
            a[i + 1] += a[i] - (lims[i] - 1)
            a[i] = (lims[i] - 1)
            i += 1
        if a[-1] >= lims[-1]:
            return # the last bucket has too many elements: we've reached the last partition;
                   # this only happens when x is too large

        yield list(a)  # yield a copy, because a is mutated in place afterwards

        # step 2:  add one to group 1;
        #    while a group i is already full: set to 0 and increment group i+1;
        #    while the surplus is too large (because a[0] is too small): repeat incrementing
        i0 = 1
        surplus = 0
        while True:
            for i in range(i0, len(lims)):  # increment a[i] by 1, carrying into a[i+1] when a[i] is full
                if a[i] < lims[i]-1:
                    a[i] += 1
                    surplus += 1
                    break
                else:  # a[i] would become too full if 1 were added, therefore clear a[i] and increment a[i+1]
                    surplus -= a[i]
                    a[i] = 0
            else:  # the for-loop didn't find a small enough a[i]
                return

            if a[0] >= surplus:   # if a[0] is large enough to absorb the surplus, this step is done
                break
            else:  # a[0] would become negative when absorbing the surplus: set a[i0] to 0 and start incrementing at a[i0+1]
                surplus -= a[i0]
                a[i0] = 0
                i0 += 1
                if i0 == len(lims):
                    return

        # step 3: a[0] absorbs the surplus created in step 2; a[0] may now be too large, which step 1 fixes on the next iteration
        a[0] -= surplus


x = 10
lims = [5, 4, 3, 5]

for i, p in enumerate(find_partitions(x, lims)):
    print(f"partition {i+1}: {p} sums to {sum(p)}  lex: { ''.join([str(v) for v in p[::-1]]) }")

The 19 solutions for 0<=a[0]<5, 0<=a[1]<4, 0<=a[2]<3, 0<=a[3]<5, a[0]+a[1]+a[2]+a[3] == 10 (written from right to left they would be in increasing lexicographic order):

[4, 3, 2, 1]
[4, 3, 1, 2]
[4, 2, 2, 2]
[3, 3, 2, 2]
[4, 3, 0, 3]
[4, 2, 1, 3]
[3, 3, 1, 3]
[4, 1, 2, 3]
[3, 2, 2, 3]
[2, 3, 2, 3]
[4, 2, 0, 4]
[3, 3, 0, 4]
[4, 1, 1, 4]
[3, 2, 1, 4]
[2, 3, 1, 4]
[4, 0, 2, 4]
[3, 1, 2, 4]
[2, 2, 2, 4]
[1, 3, 2, 4]
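
As a sanity check on the list above (each row sums to 10), a brute-force filter over itertools.product can confirm the count. This is only a reference sketch for small inputs, and the helper name `brute_force_partitions` is hypothetical; it enumerates every index tuple, so it does not have the efficiency of the generator.

```python
import itertools

def brute_force_partitions(x, lims):
    """All tuples a with 0 <= a[i] < lims[i] and sum(a) == x (hypothetical reference helper)."""
    return [p for p in itertools.product(*(range(l) for l in lims)) if sum(p) == x]

solutions = brute_force_partitions(10, [5, 4, 3, 5])
print(len(solutions))  # 19
```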

In your test code, you can replace for indices in stars_and_bars(dist, len(args)): with for indices in find_partitions(dist, limits): where limits = [len(a) for a in args]. Then you'd get hits/misses/tries: 23/0/23. To get all 24 solutions, the for-loop over dist should also include the last value: for dist in range(max_idx+1):
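
For comparison, the same bounded search can probably also be written as a short recursive generator. The sketch below is a swapped-in alternative, not the iterative algorithm above: `bounded_compositions` is a hypothetical name, it yields the tuples in a different order, and `get_items` is changed to return its counters instead of only printing them.

```python
def bounded_compositions(total, limits):
    """Yield tuples a with 0 <= a[i] < limits[i] and sum(a) == total.

    Hypothetical helper: a simple recursive alternative to the iterative version.
    """
    if not limits:
        if total == 0:
            yield ()
        return
    # Largest sum the remaining buckets can still hold
    rest_max = sum(l - 1 for l in limits[1:])
    lo = max(0, total - rest_max)      # the first bucket must take at least this much
    hi = min(total, limits[0] - 1)     # and can take at most this much
    for first in range(lo, hi + 1):
        for rest in bounded_compositions(total - first, limits[1:]):
            yield (first,) + rest

def get_items(*args):
    limits = [len(a) for a in args]
    max_idx = sum(l - 1 for l in limits)
    hits = misses = 0
    for dist in range(max_idx + 1):  # +1 so the largest index sum is included
        for indices in bounded_compositions(dist, limits):
            try:
                [arg[i] for arg, i in zip(args, indices)]
                hits += 1
            except IndexError:
                misses += 1
    return hits, misses, hits + misses

lists = [[None] * (r + 1) for r in range(4)]
print('hits/misses/tries: {}/{}/{}'.format(*get_items(*lists)))
# hits/misses/tries: 24/0/24
```

The pruning with lo and hi means no index tuple is generated that would later fail, at the cost of recursion depth equal to the number of lists.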

PS: If you just want all possible combinations of elements from the lists, and you don't care about first getting the smallest indices, itertools.product generates them:

lists = [['a'], ['b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i', 'j']]
for i, p in enumerate(itertools.product(*lists)):
    print(i+1, p)
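
If the total number of combinations is small enough to hold in memory, a simpler route to the "smallest indices first" order might be to sort the index tuples from itertools.product by their sum. This is a sketch under that assumption, not a replacement for the generator above; the variable names are illustrative.

```python
import itertools

lists = [['a'], ['b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i', 'j']]

# All index tuples, sorted by their sum so that small indices come first
index_tuples = sorted(itertools.product(*(range(len(l)) for l in lists)), key=sum)

for indices in index_tuples:
    items = tuple(l[i] for l, i in zip(lists, indices))
    print(sum(indices), indices, items)
```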
JohanC
  • Thank you for your answer and effort! :D This approach looks really promising, I hope I get to try it out soon. Have a nice holiday season! – Managarm Dec 18 '19 at 15:44