I am relatively new to Python. I am studying Disjoint sets, and implemented it as follows:

class DisjointSet:
    def __init__(self, vertices, parent):
        self.vertices = vertices
        self.parent = parent

    def find(self, item):
        if self.parent[item] == item:
            return item
        else:
            return self.find(self.parent[item])

    def union(self, set1, set2):
        self.parent[set1] = set2

Now in the driver code:

def main():
    vertices = ['a', 'b', 'c', 'd', 'e', 'h', 'i']
    parent = {}

    for v in vertices:
        parent[v] = v

    ds = DisjointSet(vertices, parent)
    print("Print all vertices in genesis: ")
    ds.union('b', 'd')

    ds.union('h', 'b')
    print(ds.find('h')) # prints d (OK)
    ds.union('h', 'i')
    print(ds.find('i')) # prints i (expecting d)

main()

So, at first I initialized all nodes as individual disjoint sets. Then I unioned bd and hb, which makes the set hbd. Then hi is unioned, which should (as I assumed) give us the set ihbd. I understand that, because of how the parent is set in this line of union(set1, set2):

self.parent[set1] = set2

I am setting the parent of h as i and thus removing it from the set of bd. How can I achieve a set of ihbd where the order of the params in union() won't yield different results?

Abrar
  • You shouldn't take the `parent` argument in the constructor, since the caller doesn't have any choice about what to specify. You should populate it in __init__ instead of main() – Matt Timmermans Jan 04 '19 at 19:00
  • Here's a Python implementation: https://www.nayuki.io/res/disjoint-set-data-structure/disjointset.py. Another one: https://github.com/mrapacz/disjoint-set – Abhijit Sarkar Dec 23 '19 at 09:55

2 Answers

Your program is not working correctly because you have misunderstood the disjoint-set algorithm. Union must modify the parent of the root of the set containing the input node, not the parent of the input node itself. As you have already noticed, blindly reassigning the parent of whatever node you receive as input destroys previous unions.

Here's a correct implementation:

def union(self, set1, set2):
    root1 = self.find(set1)
    root2 = self.find(set2)
    self.parent[root1] = root2

I would also suggest reading Disjoint-set data structure for more info as well as possible optimizations.
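
For reference, here is a minimal self-contained sketch that replays the driver code from the question with the root-based union. It assumes the constructor builds the parent mapping itself (as suggested in the comments) and uses an iterative find; the names root_a and root_b are only illustrative.

```python
class DisjointSet:
    def __init__(self, vertices):
        # Every vertex starts as the root of its own singleton set.
        self.parent = {v: v for v in vertices}

    def find(self, item):
        # Follow parent links until reaching a root (a node that is its own parent).
        while self.parent[item] != item:
            item = self.parent[item]
        return item

    def union(self, a, b):
        # Link the two roots, not the nodes that were passed in.
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a != root_b:
            self.parent[root_a] = root_b


ds = DisjointSet(['a', 'b', 'c', 'd', 'e', 'h', 'i'])
ds.union('b', 'd')
ds.union('h', 'b')
ds.union('h', 'i')

# b, d, h and i now share a single representative.
print(len({ds.find(v) for v in 'bdhi'}))  # prints 1
# a was never unioned with anything, so it stays separate.
print(ds.find('a') == ds.find('b'))       # prints False
```

Note that which element ends up as the representative still depends on the argument order (here it is i rather than d), but set membership does not. To test whether two items are in the same set, compare find(x) == find(y) instead of expecting a particular root.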

merlyn

To make your implementation faster, you may want to update the parent as you find(), a technique known as path compression:

    def find(self, item):
        if self.parent[item] == item:
            return item
        else:
            res = self.find(self.parent[item])
            self.parent[item] = res
            return res
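
One caveat with the recursive version: on a very long parent chain it can exceed Python's recursion limit (typically 1000 by default). Below is a rough iterative sketch of the same idea. It is written as a standalone function taking the parent dict as an argument purely for illustration; inside the class it would use self.parent instead.

```python
def find(parent, item):
    # First pass: walk up the chain to locate the root.
    root = item
    while parent[root] != root:
        root = parent[root]
    # Second pass: re-point every node on the path directly at the root.
    while parent[item] != root:
        parent[item], item = root, parent[item]
    return root


# A chain 0 -> 1 -> ... -> 5000, deeper than the usual recursion limit.
parent = {i: i + 1 for i in range(5000)}
parent[5000] = 5000

print(find(parent, 0))          # prints 5000
print(parent[0], parent[2500])  # prints 5000 5000
```

After the first call, every node that was on the path points straight at the root, so later lookups for those nodes take a single step.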
HK Tong