First, the powerset:
For n distinct elements, you can create 2ⁿ sets, which is easily shown by treating the question “is this element included in this particular set?” as a boolean value:
For n=3:
0: 0 0 0 none included
1: 0 0 1 first included
2: 0 1 0 second included
3: 0 1 1 first and second one included
4: 1 0 0 third included
5: 1 0 1 first and third included
6: 1 1 0 second and third included
7: 1 1 1 all included
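The table can be reproduced with a few lines of Java. This is only a minimal sketch: the class name BitPatternDemo and the helper includedIndices are made up for illustration, and bit 0 is taken to mean "first element", matching the rows above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class, not part of the actual solution.
class BitPatternDemo {
    // Returns the zero-based positions of the one-bits in mask, lowest first.
    static List<Integer> includedIndices(long mask) {
        List<Integer> indices = new ArrayList<>();
        // Clear the lowest set bit on each step until none are left.
        for (long tmp = mask; tmp != 0; tmp -= Long.lowestOneBit(tmp)) {
            indices.add(Long.numberOfTrailingZeros(tmp));
        }
        return indices;
    }

    public static void main(String[] args) {
        int n = 3;
        // One line per counter value 0 .. 2^n - 1.
        for (long mask = 0; mask < (1L << n); mask++) {
            System.out.println(mask + ": " + includedIndices(mask));
        }
    }
}
```

For the counter value 5 this lists the indices 0 and 2, i.e. the first and third element.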
So iterating over all combinations can be implemented by iterating over the integer numbers from 0 to 2ⁿ (exclusive) and using the bit pattern of each number to select the elements from the original set (we have to copy them into an ordered structure like a List for that):
import java.util.*;

// Despite its name, this returns the power set (all subsets), not permutations.
public static <T> Set<Set<T>> allPermutations(Set<T> input) {
    List<T> sequence = new ArrayList<>(input);
    // 2ⁿ subsets; clamp so the shift cannot overflow a long
    long count = sequence.size() > 62? Long.MAX_VALUE: 1L << sequence.size();
    HashSet<Set<T>> result = new HashSet<>((int)Math.min(Integer.MAX_VALUE, count));
    for(long l = 0; l >= 0 && l < count; l++) {
        if(l == 0) result.add(Collections.emptySet());
        else if(Long.lowestOneBit(l) == l) // exactly one bit set
            result.add(Collections.singleton(sequence.get(Long.numberOfTrailingZeros(l))));
        else {
            HashSet<T> next = new HashSet<>((int)(Long.bitCount(l)*1.5f));
            // visit each set bit of l, clearing the lowest one per step
            for(long tmp = l; tmp != 0; tmp-=Long.lowestOneBit(tmp)) {
                next.add(sequence.get(Long.numberOfTrailingZeros(tmp)));
            }
            result.add(next);
        }
    }
    return result;
}
Then
Set<String> input = new HashSet<>();
Collections.addAll(input, "1", "2", "3");
System.out.println(allPermutations(input));
gives us
[[], [1], [2], [1, 2], [3], [1, 3], [2, 3], [1, 2, 3]]
To utilize this for identifying the partitions, we have to expand the logic to use the counter’s bits to select bits from another mask, which identifies the actual elements to include. Then we can reuse the same operation to get the partitions of the elements not included so far, using a simple bitwise NOT operation:
public static <T> Set<Set<Set<T>>> allPartitions(Set<T> input) {
    List<T> sequence = new ArrayList<>(input);
    if(sequence.size() > 62) throw new OutOfMemoryError();
    // start with a mask in which every element is still unassigned
    return allPartitions(sequence, (1L << sequence.size()) - 1);
}
// bits: mask of the elements of input that still have to be partitioned
private static <T> Set<Set<Set<T>>> allPartitions(List<T> input, long bits) {
    long count = 1L << Long.bitCount(bits);
    if(count == 1) {
        // nothing left: the only partition is the empty one (mutable, the caller adds to it)
        return Collections.singleton(new HashSet<>());
    }
    Set<Set<Set<T>>> result = new HashSet<>();
    // every non-empty subset of the remaining elements is a candidate block
    for(long l = 1; l >= 0 && l < count; l++) {
        long select = selectBits(l, bits);
        final Set<T> first = get(input, select);
        // combine it with every partition of the elements not selected
        for(Set<Set<T>> all: allPartitions(input, bits&~select)) {
            all.add(first);
            result.add(all);
        }
    }
    return result;
}
// spreads the low bits of selected onto the positions of the one-bits of mask
private static long selectBits(long selected, long mask) {
    long result = 0;
    for(long bit; selected != 0; selected >>>= 1, mask -= bit) {
        bit = Long.lowestOneBit(mask);
        if((selected & 1) != 0) result |= bit;
    }
    return result;
}
// returns the elements whose indices are set in bits
private static <T> Set<T> get(List<T> elements, long bits) {
    if(bits == 0) return Collections.emptySet();
    else if(Long.lowestOneBit(bits) == bits)
        return Collections.singleton(elements.get(Long.numberOfTrailingZeros(bits)));
    else {
        HashSet<T> next = new HashSet<>();
        for(; bits != 0; bits-=Long.lowestOneBit(bits)) {
            next.add(elements.get(Long.numberOfTrailingZeros(bits)));
        }
        return next;
    }
}
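To see what selectBits does in isolation, here is a small sketch with one worked value. The class name SelectBitsDemo is hypothetical; the method body is a copy of the one above. With the mask 10110 (elements 1, 2 and 4 still available) and the counter 101, the first and third available elements are picked.

```java
// Hypothetical demo class wrapping a copy of selectBits from above.
class SelectBitsDemo {
    // Deposits the low bits of selected onto the one-bits of mask, lowest first.
    static long selectBits(long selected, long mask) {
        long result = 0;
        for (long bit; selected != 0; selected >>>= 1, mask -= bit) {
            bit = Long.lowestOneBit(mask);          // next available position
            if ((selected & 1) != 0) result |= bit; // counter bit set: take it
        }
        return result;
    }

    public static void main(String[] args) {
        long mask = 0b10110L;   // positions 1, 2 and 4 are available
        long counter = 0b101L;  // pick the 1st and 3rd available position
        long picked = selectBits(counter, mask);
        System.out.println(Long.toBinaryString(picked)); // prints 10010
    }
}
```

As far as I can tell, Long.expand(selected, mask), available since Java 19, performs the same bit deposit, so on a recent JDK the loop could probably be replaced by that call.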
Then,
Set<String> input = new HashSet<>();
Collections.addAll(input, "1", "2", "3");
System.out.println(allPartitions(input));
gives us
[[[1], [2], [3]], [[1], [2, 3]], [[2], [1, 3]], [[3], [1, 2]], [[1, 2, 3]]]
whereas
Set<String> input = new HashSet<>();
Collections.addAll(input, "1", "2", "3", "4");
for(Set<Set<String>> partition: allPartitions(input))
    System.out.println(partition);
yields
[[1], [3], [2, 4]]
[[1], [2], [3], [4]]
[[1], [2], [3, 4]]
[[1], [4], [2, 3]]
[[4], [1, 2, 3]]
[[1], [2, 3, 4]]
[[2], [3], [1, 4]]
[[2], [4], [1, 3]]
[[1, 2, 3, 4]]
[[1, 3], [2, 4]]
[[1, 4], [2, 3]]
[[2], [1, 3, 4]]
[[3], [1, 2], [4]]
[[1, 2], [3, 4]]
[[3], [1, 2, 4]]
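As a sanity check, the number of partitions of an n-element set is the Bell number B(n), which starts 1, 1, 2, 5, 15, 52 and so agrees with the 5 and 15 results above. The following is a minimal sketch that computes it with the Bell triangle; the class BellNumbers and the method bell are names invented for this illustration, and the code assumes n is small enough that the values fit into a long.

```java
// Hypothetical helper for cross-checking the size of allPartitions' result.
class BellNumbers {
    // Returns B(n), the number of partitions of a set with n elements.
    static long bell(int n) {
        long[] row = {1}; // row 0 of the Bell triangle
        for (int i = 1; i <= n; i++) {
            long[] next = new long[i + 1];
            next[0] = row[i - 1]; // each row starts with the last entry of the previous one
            for (int j = 1; j <= i; j++) {
                next[j] = next[j - 1] + row[j - 1]; // left neighbour + entry above it
            }
            row = next;
        }
        return row[0]; // first entry of row n is B(n)
    }

    public static void main(String[] args) {
        System.out.println(bell(3)); // prints 5
        System.out.println(bell(4)); // prints 15
    }
}
```

Comparing allPartitions(input).size() against bell(input.size()) is a cheap way to test the implementation for small inputs.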