
I have a dictionary whose values are sets of integers.

{'A': {9, 203, 404, 481},
 'B': {9},
 'C': {110},
 'D': {9, 314, 426},
 'E': {59, 395, 405}
}

You can generate the data with this:

import string
import numpy as np

rng = np.random.default_rng()
data = {}
for i in string.ascii_uppercase:
    n = 25
    data[i] = set(rng.choice(100, size=n, replace=False))

I need to get a list of the intersection of a subset of the dictionary's values. For example, the intersection for ['A','B','D'] here would return [9].

I've figured out 2 different ways of doing this, but both are much too slow when the sets grow in size.

cols = ['A','B','D']

from functools import reduce

# method 1
lis = list(map(data.get, cols))
idx = list(set.intersection(*lis))

# method 2 (10x slower than method 1)
query_dict = dict((k, data[k]) for k in cols)
idx2 = list(reduce(set.intersection, (set(val) for val in query_dict.values())))

When the sets grow large (>10k ints per set), the runtime grows quickly.

I'm okay with using other datatypes than sets in the dict, like lists or numpy arrays etc.

Is there a faster way of accomplishing this?

EDIT:

The original problem I had was this dataframe:

    T       S       A   B   C   D
0   49.378  1.057   AA  AB  AA  AA
1   1.584   1.107   BC  BA  AA  AA
2   1.095   0.000   BB  BB  AD  
3   10.572  1.224   BA  AB  AA  AA
4   0.000   0.000   DC  BA  AB  

For each row I have to sum 'T' over all rows which have A, B, C, D in common; if a threshold is reached, continue, otherwise sum over rows with B, C, D in common, then C, D, and then only D if the threshold is still not reached.

However, this was really slow, so first I tried get_dummies and then taking the product of columns. That was also too slow, so I moved to numpy arrays with indices to sum over. That is the fastest option up till now; however, the intersect is the only thing which still takes up too much time to compute.

EDIT2:

It turned out I was making it too hard on myself; it is possible with pandas groupby, and that is very fast.

code:

parts = [['A','B','C','D'],['B','C','D'],['C','D'],['D']]
for part in parts:
    temp_df = df.groupby(part,as_index=False).sum()
    temp_df = temp_df[temp_df['T'] > 100]
    df = pd.merge(df,temp_df,on=part,how='left',suffixes=["","_" + "".join(part)])

df['T_sum'] = df[['T_ABCD','T_BCD','T_CD','T_D']].min(axis=1)
df['S_sum'] = df[['S_ABCD','S_BCD','S_CD','S_D']].min(axis=1)
df.drop(['T_ABCD','T_BCD','T_CD','T_D','S_ABCD','S_BCD','S_CD','S_D'], axis=1, inplace=True)

Probably the code can be a bit cleaner, but I don't know the best way to replace only NaN values in a merge (one possible sketch is below).
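
For reference, one possible way (an untested sketch reusing the parts list and column names above, and assuming df has a default RangeIndex so the merge result aligns with df) would be to build the result columns inside the loop and fill only the values that are still NaN:

df['T_sum'] = float('nan')
df['S_sum'] = float('nan')
for part in parts:
    temp_df = df.groupby(part, as_index=False)[['T', 'S']].sum()
    temp_df = temp_df[temp_df['T'] > 100]
    merged = df.merge(temp_df, on=part, how='left', suffixes=('', '_part'))
    # fillna only touches rows for which no more specific partition has passed the threshold yet
    df['T_sum'] = df['T_sum'].fillna(merged['T_part'])
    df['S_sum'] = df['S_sum'].fillna(merged['S_part'])

Since parts goes from the most specific partition to the least specific, each row ends up with the sum from the most specific partition that passed the threshold, which should match the min(axis=1) approach above as long as the sums are non-negative.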

user3605780
  • None of them seems to be returning 9? – yatu Feb 05 '20 at 14:09
  • @yatu I made a typo in the columns, 'C' should've been 'D'. I edited the post. – user3605780 Feb 05 '20 at 14:11
  • You might find this to be somewhat faster `list(reduce(set.intersection, itemgetter(*cols)(data)))` – yatu Feb 05 '20 at 14:35
  • @yatu, thanks this indeed speeds it up a bit, but it still grows exponentially in the number of elements in the set, which is a problem when the sets grow really large. – user3605780 Feb 05 '20 at 14:40
  • Are your dict values originally sets? Or can I assume they come from some other structure, such as lists? – yatu Feb 05 '20 at 14:48
  • @yatu, the sets are indices of values in a numpy array which I need to sum. Using indices was faster than using a boolean array. – user3605780 Feb 05 '20 at 14:58
  • Can you please explain your original setup in more detail? That is, what arrays do you have, and how could this be optimized by working from those arrays directly, if possible? – yatu Feb 05 '20 at 14:59
  • @yatu, I added the original problem in the question. – user3605780 Feb 05 '20 at 15:18
  • Neither method is exponential. Set intersection is linear time in the size of the smaller of the two sets being intersected, so the reduce method is O(mn) in the worst case where m is the size of each set and n is the number of intersections, but should be faster in practice because after a few intersections the accumulator set will usually be smaller. Your original solutions are also O(mn) but #2 is slower because it makes unnecessary copies of the sets first. – kaya3 Feb 05 '20 at 15:22
  • @kaya3, you're right, I wasn't accurate. It grows linearly inside the loop, but the loop also grows linearly, so the total is O(n^2). – user3605780 Feb 05 '20 at 15:35
  • Not O(n^2), it's O(mn). The "outer loop" is only linear in the number of sets you are intersecting, not their sizes. – kaya3 Feb 05 '20 at 15:39
  • What is your n and what is the maximum number possible in each set? – kaya3 Feb 05 '20 at 15:43
  • @kaya3, rough estimates since it changes. Max n is 10 million - 30 million and the columns a,b,c,d can be almost unique rows to 1 million in common. – user3605780 Feb 05 '20 at 15:46

1 Answer


The problem here is how to efficiently find the intersection of several sets. According to the comments: "Max n is 10 million - 30 million and the columns a,b,c,d can be almost unique rows to 1 million in common." So the sets are large, but not all the same size. Set intersection is an associative and commutative operation, so we can take the intersections in any order we like.

The time complexity of intersecting two sets is O(min(len(set1), len(set2))), so we should choose an order of intersections that minimises the sizes of the intermediate sets.


If we don't know in advance which pairs of sets have small intersections, the best we can do is intersect them in order of size. After the first intersection, the smallest set will always be the result of the last intersection, so we want to intersect that with the next-smallest input set. It's better to use set.intersection on all of the sets at once rather than reduce here, because that's implemented essentially the same way as reduce would do it, but in C.

def intersect_sets(sets):
    return set.intersection(*sorted(sets, key=len))
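
A quick usage example with the data dict and cols from the question:

cols = ['A', 'B', 'D']
idx = list(intersect_sets(data[c] for c in cols))  # [9] with the example dict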

In this case where we know nothing about the pairwise intersections, the only possible slowdown in the C implementation could be the unnecessary memory allocation for multiple intermediate sets. This can be avoided by e.g. { x for x in first_set if all(x in s for s in other_sets) }, but that turns out to be much slower.
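
For reference, here is a minimal sketch of that single-pass comprehension wrapped in a function (the name is just for illustration):

def intersect_one_pass(sets):
    # Iterate over the smallest set and keep only the elements that appear in
    # every other set; no intermediate sets are allocated, but the Python-level
    # loop makes this much slower in practice than chained C-level intersections.
    smallest, *others = sorted(sets, key=len)
    return {x for x in smallest if all(x in s for s in others)}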


I tested it with sets up to size 6 million, with about 10% pairwise overlaps. These are the times for intersecting four sets; after four, the accumulator is about 0.1% of the original size so any further intersections would take negligible time anyway. The orange line is for intersecting sets in the optimal order (smallest two first), and the blue line is for intersecting sets in the worst order (largest two first).

[Plot: intersection time vs. set size, optimal order (orange) vs. worst order (blue)]

As expected, both take roughly linear time in the set sizes, but with a lot of noise because I didn't average over multiple samples. The optimal order is consistently about 2-3 times as fast as the worst order, measured on the same data, presumably because that's the ratio between the smallest and second-largest set sizes.

On my machine, intersecting four sets of size 2-6 million takes about 100ms, so going up to 30 million should take about half a second; I think it's very unlikely that you can beat that, but half a second should be fine. If it consistently takes a lot longer than that on your real data, then the issue will be to do with your data not being uniformly random. If that's the case then there's probably not much Stack Overflow can do for you beyond this, because improving the efficiency will depend highly on the particular distribution of your real data (though see below about the case where you have to answer many queries on the same sets).

My timing code is below.

import random

def gen_sets(m, min_n, max_n):
    n_range = range(min_n, max_n)
    x_range = range(min_n * 10, max_n * 10)
    return [
        set(random.sample(x_range, n))
        for n in [min_n, max_n, *random.sample(n_range, m - 2)]
    ]

def intersect_best_order(sets):
    return set.intersection(*sorted(sets, key=len))

def intersect_worst_order(sets):
    return set.intersection(*sorted(sets, key=len, reverse=True))

from timeit import timeit

print('min_n', 'max_n', 'best order', 'worst order', sep='\t')
for min_n in range(100000, 2000001, 100000):
    max_n = min_n * 3
    data = gen_sets(4, min_n, max_n)
    t1 = timeit(lambda: intersect_best_order(data), number=1)
    t2 = timeit(lambda: intersect_worst_order(data), number=1)
    print(min_n, max_n, t1, t2, sep='\t')

If you need to do many queries, then it may be worth computing the pairwise intersections first:

from itertools import combinations

pairwise_intersections = {
    (a, b): set_a & set_b
    for ((a, set_a), (b, set_b)) in combinations(data.items(), 2)
}

If some intersections are much smaller than others, then the precomputed pairwise intersections can be used to choose a better order to do set.intersection in. Given some sets, you can choose the pair with the smallest precomputed intersection, then do set.intersection on that precomputed result along with the rest of the input sets. Especially in the non-uniform case where some pairwise intersections are nearly empty, this could be a big improvement.
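
A sketch of that idea (the helper name is hypothetical; it assumes the pairwise_intersections dict above and the original data dict):

def intersect_query(names, data, pairwise_intersections):
    # Start from the pair of queried names whose precomputed intersection is smallest...
    a, b = min(
        (pair for pair in pairwise_intersections
         if pair[0] in names and pair[1] in names),
        key=lambda pair: len(pairwise_intersections[pair]),
    )
    # ...then intersect that (hopefully small) set with the remaining queried sets.
    rest = [data[name] for name in names if name not in (a, b)]
    return pairwise_intersections[a, b].intersection(*rest)

For example, intersect_query(['A', 'B', 'D'], data, pairwise_intersections) starts from the smallest of the precomputed intersections A&B, A&D and B&D.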

kaya3
  • Although I figured out that I was making it much too complicated for myself and could just solve it with the groupby function in a loop over the partitions, thanks for your answer. I've accepted it because it's a really nice background explanation and useful to me and others. – user3605780 Feb 05 '20 at 18:55
  • OK - glad you were able to solve your problem by other means. – kaya3 Feb 05 '20 at 18:55