
Let's take an example list of lists like this:

li=[[0.99, 0.002],
 [0.98, 0.0008, 0.0007],
 [0.97, 0.009, 0.001],
 [0.86, 0.001]]

Note that elements inside each sublist are sorted in descending order and their sum is always less than or equal to 1. Also, the sublists themselves are sorted in descending order of their first elements.
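Every pruning idea in this thread relies on these properties (values in [0, 1], each sublist sorted in descending order), so it may be worth asserting them before searching. A minimal sketch; `check_input` is a hypothetical helper, not part of any library:

```python
li = [[0.99, 0.002],
      [0.98, 0.0008, 0.0007],
      [0.97, 0.009, 0.001],
      [0.86, 0.001]]


def check_input(lists):
    """Hypothetical helper: True if `lists` has the properties pruning relies on."""
    for sub in lists:
        # every value lies in [0, 1]
        if not all(0.0 <= x <= 1.0 for x in sub):
            return False
        # descending order: no element is smaller than the one after it
        if any(a < b for a, b in zip(sub, sub[1:])):
            return False
        # the sublist sums to at most 1
        if sum(sub) > 1.0:
            return False
    return True


print(check_input(li))  # True
```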

I am interested in finding combinations, taking one element from each sublist, such that the product of the elements of the combination is above a certain threshold, say 1e-5. One way I found of doing this is by using itertools.product:

import itertools
import numpy as np

a = list(itertools.product(*li))
[item for item in a if np.prod(item) > 1e-5]

But this procedure is not feasible for me: my actual list has too many sublists, so the number of possible combinations to check is too big.

Instead of first finding all combinations and then checking the threshold condition, I need to do the opposite, i.e. only generate combinations that satisfy the condition. For example, since 0.002*0.0008*0.009 is already less than 1e-5, I can ignore all other combinations that start with (0.002, 0.0008, 0.009, ...).
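The pruning rule only needs the running product of the prefix: with every value in [0, 1], appending another element can never raise it. A small sketch with the standard-library `itertools.accumulate` on the prefix from the example:

```python
from itertools import accumulate
from operator import mul

threshold = 1e-5
prefix = (0.002, 0.0008, 0.009)

# running products: 0.002, 0.002*0.0008, 0.002*0.0008*0.009
partials = list(accumulate(prefix, mul))

# index of the first element at which the running product drops below the threshold
first_fail = next(i for i, p in enumerate(partials) if p < threshold)

print(first_fail)  # 1
```

On this prefix the running product is already below 1e-5 after the second element (0.002*0.0008 = 1.6e-6), so every combination starting with (0.002, 0.0008) could be skipped as well.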

I could not find an easy way to implement this. What I have in mind is a tree data structure in which each node keeps track of the product so far; as soon as a node's value drops below 1e-5, I stop building the tree further from that node and also from the nodes to its right (since the nodes on the right will be smaller than the current node).

A simple tree skeleton to get started:

class Tree(object):
    def __init__(self, node=None):
        self.node = node
        self.children = []

    def add_child(self, child):
        self.children.append(child)

Once the tree is built, I would then extract the combinations that reach depth = len(li).
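The skeleton above can be filled in to build exactly this pruned tree and then read off the paths of length len(li). A sketch under assumptions: each node's `node` attribute holds a `(value, running product)` pair, and `build` and `leaf_paths` are hypothetical helper names. The class is repeated so the block runs on its own:

```python
class Tree(object):
    def __init__(self, node=None):
        self.node = node        # here: (value, running product), or None for the root
        self.children = []

    def add_child(self, child):
        self.children.append(child)


def build(tree, lists, depth, prod, threshold):
    """Grow the children of `tree` from sublist `depth`, pruning below `threshold`."""
    if depth == len(lists):
        return
    for x in lists[depth]:
        p = prod * x
        if p < threshold:       # stop at this node and at all nodes to its right
            break
        child = Tree((x, p))
        tree.add_child(child)
        build(child, lists, depth + 1, p, threshold)


def leaf_paths(tree, depth, target, path=()):
    """Yield the value paths that reach depth == target (one element per sublist)."""
    if depth == target:
        yield path
        return
    for child in tree.children:
        yield from leaf_paths(child, depth + 1, target, path + (child.node[0],))


li = [[0.99, 0.002],
      [0.98, 0.0008, 0.0007],
      [0.97, 0.009, 0.001],
      [0.86, 0.001]]

root = Tree()
build(root, li, 0, 1.0, 1e-5)
combos = list(leaf_paths(root, 0, len(li)))
print(len(combos))  # 8
```

Branches that die before reaching the last sublist simply never yield a path, so only complete combinations are extracted.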


Any help to build such a tree or any other ideas towards solving the problem would be highly appreciated. Thanks!

user0000
I'm thinking you might be able to use recursion here to make a clever solution. Let C(t, n) represent the number of valid combinations of numbers from the last n sublists whose products are greater than t. As an example, if we have a threshold of 0.1 and the same 4 sublists, then the number of solutions C(0.1, 4) is equal to the number of solutions C(0.1/0.99, 3) + C(0.1/0.002, 3), and so on. You can implement C, and add in a short-circuit check before the recursive step to prevent unnecessary work. – StardustGogeta Jul 23 '19 at 15:14
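The recursion in this comment might look like the sketch below. It assumes every value is in (0, 1] (strictly positive, so dividing by it is safe) and each sublist is sorted in descending order; `count_valid` is a hypothetical name, "greater than t" is implemented literally as a strict `>`, and the short-circuit compares t with the largest achievable product (the first element of each remaining sublist):

```python
def count_valid(t, lists, start=0):
    """C(t, n) from the comment: the number of ways to pick one value from each of
    lists[start:] (the last n sublists) so that the product is greater than t.

    Assumes every value is in (0, 1] and each sublist is sorted in descending order.
    """
    if start == len(lists):
        # no sublists left: the empty product is 1
        return 1 if 1.0 > t else 0

    # short-circuit: the largest achievable product uses the first element of each sublist
    best = 1.0
    for sub in lists[start:]:
        best *= sub[0]
    if best <= t:
        return 0

    total = 0
    for x in lists[start]:
        # C(t, n) = sum over x in this sublist of C(t / x, n - 1)
        c = count_valid(t / x, lists, start + 1)
        if c == 0:
            # a smaller x only raises t / x, so later values cannot contribute either
            break
        total += c
    return total


li = [[0.99, 0.002],
      [0.98, 0.0008, 0.0007],
      [0.97, 0.009, 0.001],
      [0.86, 0.001]]

print(count_valid(1e-5, li))  # 8
print(count_valid(0.1, li))   # 1
```

To collect the combinations instead of counting them, the same recursion can carry the chosen prefix and yield it at the base case.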

1 Answer


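Because all of your values lie between 0 and 1 and each sublist is sorted in descending order, advancing any one position to its next element (keeping the others fixed) can never increase the product, and extending a prefix with another element can never increase it either. In particular, for a fixed prefix, the products produced by itertools.product are nonincreasing as the last element advances, although the sequence as a whole is not (it jumps back up whenever an earlier position changes). No surprise there, since you already pointed that out, but how do you take advantage of it?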
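A quick check on the example `li` of which ordering actually holds: products never increase while only the last element advances, but the full `itertools.product` sequence is not monotone overall. `math.prod` needs Python 3.8+:

```python
from itertools import groupby, product
from math import prod  # Python 3.8+

li = [[0.99, 0.002],
      [0.98, 0.0008, 0.0007],
      [0.97, 0.009, 0.001],
      [0.86, 0.001]]

combos = list(product(*li))
products = [prod(c) for c in combos]

# Runs of consecutive combinations that share a prefix (all but the last element):
# within each run the products never increase, because the last sublist is sorted descending.
for _, run in groupby(zip(combos, products), key=lambda cp: cp[0][:-1]):
    ps = [p for _, p in run]
    assert all(a >= b for a, b in zip(ps, ps[1:]))

# The sequence as a whole is not monotone: changing the third element from 0.97 to 0.009
# makes the product jump back up from about 0.00094 to about 0.0075.
print(products[1] < products[2])  # True
```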

I think what you want is a reimplementation of itertools.product with a shortcut that prunes a branch as soon as its product goes under the threshold. This lets you iterate efficiently through all possible matches without wasting time on products that you already know can't meet the threshold.
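The same pruning can also be written as a recursive generator that carries the running product down each branch, so nothing is recomputed and no tree is materialized. A minimal sketch, assuming each sublist is sorted in descending order with values in [0, 1]; `combos_above` is a hypothetical name, and like the cutoff below it keeps products that are not below the threshold. The `break` skips the current value and every value to its right, because they can only give smaller products:

```python
def combos_above(lists, threshold):
    """Yield (combination, product) pairs whose product is >= threshold.

    Assumes each sublist is sorted in descending order and all values are in [0, 1].
    """
    n = len(lists)

    def dfs(depth, prefix, prod):
        if depth == n:               # a complete combination (depth == len(lists))
            yield tuple(prefix), prod
            return
        for x in lists[depth]:
            p = prod * x
            if p < threshold:        # this value fails, and so does every value after it
                break
            prefix.append(x)
            yield from dfs(depth + 1, prefix, p)
            prefix.pop()

    yield from dfs(0, [], 1.0)


li = [[0.99, 0.002],
      [0.98, 0.0008, 0.0007],
      [0.97, 0.009, 0.001],
      [0.86, 0.001]]

for combo, p in combos_above(li, 1e-5):
    print(combo, p)
```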

I found an iterator implementation of itertools.product here: "how code a function similar to itertools.product in python 2.5" (I'm using Python 3, and it seems to work okay).

So I just copied it and inserted a threshold check inside the loops:

# cutoff function
from functools import reduce
from operator import mul

threshold = 1e-5

def cutoff(args):
    if args:
        return reduce(mul, args) < threshold
    return False

# alternative implementation of itertools.product with cutoff
def product(*args, **kwds):
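    def cycle(values, uplevel):
        for prefix in uplevel:       # cycle through all upper levels
            if cutoff(prefix):       # never fires: upper levels only yield prefixes that passed
                break
            for current in values:   # restart iteration of current level
                result = prefix + (current,)
                if cutoff(result):   # values are sorted descending, so every later value
                    break            # gives an even smaller product: prune them all
                yield result

    stack = iter(((),))              # level 0: a single empty prefix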
    for level in tuple(map(tuple, args)) * kwds.get('repeat', 1):
        stack = cycle(level, stack)  # build stack of iterators
    return stack

# your code here
li=[[0.99, 0.002],
    [0.98, 0.0008, 0.0007],
    [0.97, 0.009, 0.001],
    [0.86, 0.001]]

for a in product(*li):
    p = reduce(mul, a)
    print(a, p)

I get the same results if I leave out the cutoff and just check p > threshold afterwards:
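Comparing against the brute-force filter also makes a convenient regression test. The sketch below swaps in an iterative depth-first search with an explicit stack (carrying the running product instead of recomputing `reduce(mul, ...)` at every node) and checks it against the filtered `itertools.product` on random sorted input. `pruned_combos` and `random_sorted_lists` are hypothetical names, `math.prod` needs Python 3.8+, and `>=` is used to match `cutoff`, which only discards products strictly below the threshold:

```python
import itertools
import math
import random


def pruned_combos(lists, threshold):
    """Return combinations with product >= threshold, in itertools.product order.

    Depth-first search with an explicit stack. Assumes every sublist is sorted
    in descending order and all values are in [0, 1].
    """
    out = []
    stack = [(0, (), 1.0)]                  # (depth, prefix, running product)
    while stack:
        depth, prefix, p = stack.pop()
        if depth == len(lists):
            out.append(prefix)
            continue
        children = []
        for x in lists[depth]:
            q = p * x
            if q < threshold:               # sorted descending: later values are no larger
                break
            children.append((depth + 1, prefix + (x,), q))
        stack.extend(reversed(children))    # so the first child is popped first
    return out


def random_sorted_lists(n_lists, max_len, rng):
    """Hypothetical test data: descending sublists with values in [0, 1]."""
    lists = []
    for _ in range(n_lists):
        raw = [rng.random() for _ in range(rng.randint(1, max_len))]
        scale = rng.random() / (sum(raw) or 1.0)
        lists.append(sorted((x * scale for x in raw), reverse=True))
    return lists


rng = random.Random(0)
for _ in range(200):
    lists = random_sorted_lists(rng.randint(1, 5), 4, rng)
    t = 10 ** rng.uniform(-6, -1)
    brute = [c for c in itertools.product(*lists) if math.prod(c) >= t]
    assert pruned_combos(lists, t) == brute
print("pruned search matches brute force")
```

Unlike a recursive version, the explicit stack does not add one Python call frame per sublist, so it is not bounded by the default recursion limit (about 1000) if the real list has very many sublists.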

(0.99, 0.98, 0.97, 0.86) 0.8093408399999998
(0.99, 0.98, 0.97, 0.001) 0.0009410939999999998
(0.99, 0.98, 0.009, 0.86) 0.007509348
(0.99, 0.98, 0.001, 0.86) 0.0008343719999999999
(0.99, 0.0008, 0.97, 0.86) 0.0006606864
(0.99, 0.0007, 0.97, 0.86) 0.0005781006
(0.002, 0.98, 0.97, 0.86) 0.0016350319999999998
(0.002, 0.98, 0.009, 0.86) 1.5170399999999998e-05
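The matches above come out in lexicographic order, not sorted by product. If largest-first order is wanted, one alternative technique (best-first search with a heap, which is not what the answer above does) is to start from the combination of all first elements and repeatedly expand the largest product found so far. Since advancing any one index never increases the product, the search can stop at the first product below the threshold. A minimal sketch; `best_first` is a hypothetical name, and it assumes non-empty, descending sublists with values in [0, 1]:

```python
import heapq


def best_first(lists, threshold):
    """Yield (combination, product) in nonincreasing order of product,
    stopping once the product drops below threshold.

    Assumes every sublist is non-empty, sorted descending, with values in [0, 1].
    """
    n = len(lists)

    def prod_of(idx):
        p = 1.0
        for sub, i in zip(lists, idx):
            p *= sub[i]
        return p

    start = (0,) * n                         # index tuple: first element of every sublist
    heap = [(-prod_of(start), start)]        # heapq is a min-heap, so store -product
    seen = {start}
    while heap:
        neg_p, idx = heapq.heappop(heap)
        p = -neg_p
        if p < threshold:
            break                            # every unvisited combination has product <= p
        yield tuple(sub[i] for sub, i in zip(lists, idx)), p
        for k in range(n):                   # neighbours: advance one index by one step
            if idx[k] + 1 < len(lists[k]):
                nxt = idx[:k] + (idx[k] + 1,) + idx[k + 1:]
                if nxt not in seen:
                    seen.add(nxt)
                    heapq.heappush(heap, (-prod_of(nxt), nxt))


li = [[0.99, 0.002],
      [0.98, 0.0008, 0.0007],
      [0.97, 0.009, 0.001],
      [0.86, 0.001]]

for combo, p in best_first(li, 1e-5):
    print(combo, p)
```

The `seen` set and the heap grow with the number of visited combinations, so this trades memory for getting the results in sorted order.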

Kenny Ostrom