I have come up with a solution based on the trie data structure for sets described by Savnik (2013). Such tries make it relatively fast to determine whether one of the stored sets is a subset of another given set.
The solution then looks as follows:
- Create a trie
- Iterate through the given sets
- In each iteration, go through the sets in the trie and check whether each of them is disjoint from the new set.
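- If a stored set intersects the new set, keep it and continue. If it is disjoint from the new set, remove it and add each of its extensions by one element of the new set to the trie, unless that extension is a superset of a set already in the trie.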
The worst-case runtime is $O(n \cdot m \cdot c)$. Here $n$ is the number of input sets, $m$ is the maximal number of solutions when only $n' \le n$ of the input sets are considered, and $c$ is the time factor from the subset lookups.
The code is below. I have implemented the algorithm based on the Python package datrie, which is a wrapper around an efficient C implementation of a trie. The code below is in Cython but can easily be converted to pure Python by removing or exchanging the Cython-specific commands.
The extended trie implementation:
from datrie cimport BaseTrie, BaseState, BaseIterator
cdef bint has_subset_c(BaseTrie trie, BaseState trieState, str setarr,
                       int index, int size):
    # Return True if some key stored in ``trie`` can be completed by walking
    # from ``trieState`` through characters of ``setarr[index:size]`` taken
    # in order (any characters may be skipped). When stored keys and
    # ``setarr`` use the same character order, this means that some stored
    # set is a subset of the set encoded by ``setarr``.
    #
    # trie       -- the trie the states belong to (used to allocate a state)
    # trieState  -- starting state; it is only copied from, never advanced
    # setarr     -- candidate superset, one character per element
    # index/size -- half-open range of ``setarr`` still available to walk
    cdef BaseState trieState2 = BaseState(trie)
    cdef int i
    trieState.copy_to(trieState2)
    for i in range(index, size):
        if trieState2.walk(setarr[i]):
            # Either a stored key ends here, or one can be completed from the
            # remaining characters. Continue from i + 1: setarr[i] has just
            # been consumed and keys built from sets never repeat a character.
            if trieState2.is_terminal() or has_subset_c(trie, trieState2, setarr,
                                                        i + 1, size):
                return True
            # Dead end: reset the working state and try skipping setarr[i].
            trieState.copy_to(trieState2)
    return False
cdef class SetTrie:
    """A collection of integer sets stored as keys of a ``datrie.BaseTrie``.

    A set ``{a, b, ...}`` is stored as the key ``chr(a) + chr(b) + ...`` with
    the characters sorted by code point. Sorting gives every set a single
    canonical key. ``has_subset_c`` relies on this: it only finds a stored
    key whose characters occur in the same order in the query string.

    Elements must be integers whose ``chr`` lies in the trie's alphabet.
    """
    # Attributes of a cdef class must be declared, otherwise
    # ``self.trie = ...`` raises AttributeError.
    # NOTE(review): if a matching .pxd already declares these attributes,
    # drop these three lines.
    cdef public BaseTrie trie
    # True once the trie has been seeded with at least one set.
    cdef public bint touched
    # Upper bound on the size of any stored set; ``prune`` skips its scan
    # when this bound is already <= the requested maximum.
    cdef public object maxSizeBound

    def __init__(self, alphabet, initSet=()):
        """Create the trie.

        alphabet -- an int n (meaning ``range(n)``) or an iterable of ints
                    listing every element that may ever be stored.
        initSet  -- ints; each one is stored as a singleton set.
        """
        if not hasattr(alphabet, "__iter__"):
            alphabet = range(alphabet)
        self.trie = BaseTrie("".join(chr(i) for i in alphabet))
        self.touched = False
        self.maxSizeBound = 0
        for i in initSet:
            self.trie[chr(i)] = 0
            self.touched = True
            self.maxSizeBound = 1

    def has_subset(self, superset):
        """Return True if some stored set is a subset of ``superset`` (ints)."""
        cdef BaseState trieState = BaseState(self.trie)
        # Use the same canonical (sorted) order as the stored keys.
        setarr = "".join(sorted(chr(i) for i in superset))
        return bool(has_subset_c(self.trie, trieState, setarr, 0, len(setarr)))

    def extend(self, sets):
        """Store every set of ints in ``sets`` (no minimality check)."""
        cdef str key
        for s in sets:
            key = "".join(sorted(chr(i) for i in s))
            self.trie[key] = 0
            self.touched = True
            # Keep the bound honest so that prune() does not skip long keys.
            if len(key) > self.maxSizeBound:
                self.maxSizeBound = len(key)

    def delete_supersets(self):
        """Remove every stored set that contains another stored set."""
        cdef str elem
        cdef BaseState trieState = BaseState(self.trie)
        cdef BaseIterator trieIter = BaseIterator(BaseState(self.trie))
        # The iterator is advanced *before* the previous key is deleted, so
        # the key currently under the iterator is never the one removed.
        if trieIter.next():
            elem = trieIter.key()
            while trieIter.next():
                # Take elem out so it cannot count as its own subset, then
                # put it back only if no other stored set is contained in it.
                self.trie._delitem(elem)
                if not has_subset_c(self.trie, trieState, elem, 0, len(elem)):
                    self.trie._setitem(elem, 0)
                elem = trieIter.key()
            # Last key: same test, but its stored value is preserved.
            if has_subset_c(self.trie, trieState, elem, 0, len(elem)):
                val = self.trie.pop(elem)
                if not has_subset_c(self.trie, trieState, elem, 0, len(elem)):
                    self.trie._setitem(elem, val)

    def update_by_settrie(self, SetTrie setTrie, maxSize=float("inf"),
                          initialize=True):
        """Call ``update`` with every set stored in ``setTrie``.

        If ``initialize`` is true and this trie has not been seeded yet, the
        first set of ``setTrie`` seeds it with one singleton per element.
        """
        cdef BaseIterator trieIter = BaseIterator(BaseState(setTrie.trie))
        cdef str s
        if initialize and not self.touched and trieIter.next():
            for s in trieIter.key():
                self.trie._setitem(s, 0)
            self.touched = True
            if self.maxSizeBound < 1:
                self.maxSizeBound = 1
        while trieIter.next():
            self.update(set(trieIter.key()), maxSize, True)

    def update(self, otherSet, maxSize=float("inf"), isStrSet=False):
        """Make every stored set intersect ``otherSet``.

        Stored sets that already intersect ``otherSet`` are kept. Each
        disjoint one is removed and replaced by its one-element extensions
        with elements of ``otherSet``; an extension is skipped if a stored
        set is a subset of it. Disjoint sets with ``len >= maxSize`` are
        dropped without replacement.

        otherSet -- ints, or one-character strings if ``isStrSet`` is true.
        """
        cdef str subset, newSubset, elem
        cdef list disjointList = []
        cdef BaseTrie trie = self.trie
        cdef int l
        cdef BaseIterator trieIter
        cdef BaseState trieState
        if not isStrSet:
            otherSet = set(chr(i) for i in otherSet)
        # Phase 1: collect and delete the disjoint sets. The iterator is
        # advanced before the previous key is deleted.
        trieIter = BaseIterator(BaseState(trie))
        if trieIter.next():
            subset = trieIter.key()
            while trieIter.next():
                if otherSet.isdisjoint(subset):
                    disjointList.append(subset)
                    trie._delitem(subset)
                subset = trieIter.key()
            if otherSet.isdisjoint(subset):
                disjointList.append(subset)
                trie._delitem(subset)
        # Phase 2: add the non-redundant one-element extensions.
        trieState = BaseState(trie)
        for subset in disjointList:
            l = len(subset)
            if l < maxSize:
                if l + 1 > self.maxSizeBound:
                    self.maxSizeBound = l + 1
                for elem in otherSet:
                    # Sort so the key is canonical; plain ``subset + elem``
                    # would hide subsets from has_subset_c.
                    newSubset = "".join(sorted(subset + elem))
                    trieState.rewind()
                    if not has_subset_c(trie, trieState, newSubset, 0,
                                        len(newSubset)):
                        trie[newSubset] = 0

    def get_frozensets(self):
        """Return a lazy generator of the stored sets as frozensets of ints."""
        return (frozenset(ord(t) for t in subset) for subset in self.trie)

    def clear(self):
        """Remove all stored sets and mark the trie as not seeded."""
        self.touched = False
        self.maxSizeBound = 0
        self.trie.clear()

    def prune(self, maxSize):
        """Delete stored sets with more than ``maxSize`` elements.

        Returns True if anything was deleted.
        """
        cdef bint changed = False
        cdef BaseIterator trieIter
        cdef str k
        if self.maxSizeBound > maxSize:
            self.maxSizeBound = maxSize
            trieIter = BaseIterator(BaseState(self.trie))
            # Delete the previous key only after the iterator has moved on.
            k = ''
            while trieIter.next():
                if len(k) > maxSize:
                    self.trie._delitem(k)
                    changed = True
                k = trieIter.key()
            if len(k) > maxSize:
                self.trie._delitem(k)
                changed = True
        return changed

    def __bool__(self):
        # Truthiness reflects "has been seeded", not "is non-empty".
        return self.touched

    def __repr__(self):
        return str([set(ord(t) for t in subset) for subset in self.trie])
This can be used as follows:
def cover_sets(sets, alphabet=None):
    """Return the sets of ints that intersect every set in ``sets``.

    The first set seeds the trie with one singleton per element. Each
    further set is folded in with ``SetTrie.update``, which discards
    candidates that are supersets of sets already held.

    sets     -- a non-empty sequence of sets of ints (IndexError if empty).
    alphabet -- optional iterable of every int that may occur. It defaults
                to the union of ``sets``, so elements are not limited to 0..9.

    Returns a lazy generator of frozensets.
    """
    if alphabet is None:
        alphabet = set().union(*sets)
    # Pass the first set itself: SetTrie(alphabet, initSet) turns each of its
    # elements into a singleton. Unpacking with ``*`` here would pass one
    # positional argument per element and raise TypeError.
    strie = SetTrie(alphabet, sets[0])
    for s in sets[1:]:
        strie.update(s)
    return strie.get_frozensets()
Timing:
# Benchmark snippet for IPython/Jupyter: ``%timeit`` below is an IPython
# magic, not plain Python syntax. The imported ``timeit`` function is not
# used by this snippet.
from timeit import timeit
s1 = {1, 2, 3}
s2 = {3, 4, 5}
s3 = {5, 6}
# cover_sets builds and updates the trie eagerly but returns a lazy generator,
# so this measures the trie work, not the conversion to frozensets.
%timeit cover_sets([s1, s2, s3])
Result:
37.8 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Note that the trie implementation above works only with set elements strictly greater than 0. Otherwise, the integer-to-character mapping does not work properly. This problem can be solved with an index shift.