You need to write an `__eq__` method to define how objects are compared for equality. If you also want sorting, you should have a `__cmp__` method, and it makes the most sense to implement `__eq__` in terms of `__cmp__`. (This applies to Python 2. Python 3 no longer calls `__cmp__` and removed the `cmp()` builtin.)
```python
def __eq__(self, other):
    return cmp(self, other) == 0
```
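Because `__cmp__` and `cmp()` are gone in Python 3, the same idea has to be spelled differently there. Below is a minimal sketch, assuming a hypothetical `Version` class with two fields. A private `_cmp` helper (a name chosen here for illustration) recreates the three-way result with the `(a > b) - (a < b)` idiom, `__eq__` and `__lt__` are written in terms of it, and `functools.total_ordering` derives the remaining comparisons.

```python
from functools import total_ordering


@total_ordering  # derives __le__, __gt__ and __ge__ from __eq__ and __lt__
class Version:
    """Hypothetical example class compared by (major, minor)."""

    def __init__(self, major, minor):
        self.major = major
        self.minor = minor

    def _cmp(self, other):
        # Python 3 stand-in for cmp(): returns -1, 0 or 1.
        a = (self.major, self.minor)
        b = (other.major, other.minor)
        return (a > b) - (a < b)

    def __eq__(self, other):
        if not isinstance(other, Version):
            return NotImplemented
        return self._cmp(other) == 0

    def __lt__(self, other):
        if not isinstance(other, Version):
            return NotImplemented
        return self._cmp(other) < 0


a, b = Version(1, 0), Version(1, 2)
print(a == Version(1, 0), a < b, b >= a)  # True True True
```

Note that this class defines `__eq__` without `__hash__`, so Python 3 makes its instances unhashable; add a `__hash__` if they need to go into sets or dictionaries.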
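You should probably also implement `__hash__`, and you definitely should if you plan to put your objects into a set or use them as dictionary keys. The default `__hash__` is based on the object's `id()`, which effectively makes every object unique: two objects with identical contents still count as different set members or dictionary keys. Whatever `__hash__` you write must agree with `__eq__`, meaning objects that compare equal must have equal hashes.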
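To see the difference, the sketch below compares two hypothetical classes: one that keeps the default identity-based hash, and one whose `__hash__` is computed from the same attributes its `__eq__` compares. It is written for Python 3, where a class that defines `__eq__` but not `__hash__` is made unhashable, so the identity-based class defines neither.

```python
class IdentityPoint:
    """Hypothetical class with no __eq__/__hash__: uses the id()-based defaults."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


class ValuePoint:
    """Hypothetical class whose __eq__ and __hash__ both use (x, y)."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if not isinstance(other, ValuePoint):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        # Hash the same tuple that __eq__ compares, so equal points hash equally.
        return hash((self.x, self.y))


print(len({IdentityPoint(1, 2), IdentityPoint(1, 2)}))  # 2: two distinct objects
print(len({ValuePoint(1, 2), ValuePoint(1, 2)}))        # 1: duplicates collapse
```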
I wrote a base class (effectively an interface) for classes that need this sort of equivalence comparison. You may find it useful:
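```python
class Comparable(object):
    """Base class for objects compared and hashed by the attributes in attrs()."""

    def attrs(self):
        # Subclasses return the names of the attributes that define identity.
        raise NotImplementedError("Must be implemented in concrete sub-class!")

    def __values(self):
        # Generator over the identity attribute values, in attrs() order.
        return (getattr(self, attr) for attr in self.attrs())

    def __hash__(self):
        # Fold the attribute hashes together, so equal objects hash equally.
        return reduce(lambda x, y: 37 * x + hash(y), self.__values(), 0)

    def __cmp__(self, other):
        # Lexicographic comparison: the first differing attribute decides.
        for s, o in zip(self.__values(), other.__values()):
            c = cmp(s, o)
            if c:
                return c
        return 0

    def __eq__(self, other):
        return cmp(self, other) == 0

    def __lt__(self, other):
        return cmp(self, other) < 0

    def __gt__(self, other):
        return cmp(self, other) > 0


if __name__ == '__main__':
    class Foo(Comparable):
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def attrs(self):
            return ('x', 'y')

        def __str__(self):
            return "Foo[%d,%d]" % (self.x, self.y)

    def foo_iter(x):
        # Yield every Foo(i, j) with 0 <= i, j < x.
        for i in range(x):
            for j in range(x):
                yield Foo(i, j)

    # Compare every pair of Foo objects and print their relationship.
    for a in foo_iter(4):
        for b in foo_iter(4):
            if a < b: print "%(a)s < %(b)s" % locals()
            if a == b: print "%(a)s == %(b)s" % locals()
            if a > b: print "%(a)s > %(b)s" % locals()
```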
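`Comparable` relies on Python 2's `__cmp__`, `cmp()` and builtin `reduce`, none of which are available the same way in Python 3. Here is a minimal Python 3 sketch of the same interface, assuming the same `attrs()` contract. It swaps the hand-written attribute loop for tuple comparison (which is already lexicographic) and swaps the `37 * x + hash(y)` fold for hashing the tuple of values; `functools.total_ordering` supplies the comparison methods that are not written out. The helper name `_values` is an illustrative choice, not part of the original.

```python
from functools import total_ordering


@total_ordering  # fills in __le__, __gt__ and __ge__ from __eq__ and __lt__
class Comparable:
    """Compare and hash objects by the attributes named in attrs()."""

    def attrs(self):
        raise NotImplementedError("Must be implemented in concrete sub-class!")

    def _values(self):
        # Tuple of identity attribute values, in attrs() order.
        return tuple(getattr(self, attr) for attr in self.attrs())

    def __hash__(self):
        # Hash the same values __eq__ compares, so equal objects hash equally.
        return hash(self._values())

    def __eq__(self, other):
        if not isinstance(other, Comparable):
            return NotImplemented
        return self._values() == other._values()

    def __lt__(self, other):
        if not isinstance(other, Comparable):
            return NotImplemented
        # Tuple comparison is lexicographic: the first differing value decides.
        return self._values() < other._values()


class Foo(Comparable):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def attrs(self):
        return ('x', 'y')

    def __str__(self):
        return "Foo[%d,%d]" % (self.x, self.y)


if __name__ == '__main__':
    print(Foo(0, 1) < Foo(1, 0), Foo(2, 3) == Foo(2, 3), Foo(2, 3) >= Foo(2, 4))
    # True True False
```

Because `__hash__` is defined in the same class body as `__eq__`, Python 3 keeps it, and subclasses such as `Foo` inherit both.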
The derived class must implement `attrs()`, which returns a tuple or list of the names of the attributes that contribute to the object's identity (i.e. the unchanging attributes that make it what it is). Most importantly, the code handles equivalence correctly when several attributes are involved; this is tedious boilerplate that is often written incorrectly.
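The "unchanging" requirement matters because the hash is computed from those same attributes. A minimal Python 3 sketch, using a hypothetical `Point` class that hashes its `(x, y)` pair: if an identity attribute changes while the object is stored in a set, the set can no longer find it, because the lookup uses the new hash while the entry was filed under the old one.

```python
class Point:
    """Hypothetical class whose identity is (x, y), but whose fields are mutable."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        return hash((self.x, self.y))


p = Point(1, 2)
points = {p}
print(p in points)                  # True

p.x = 5  # mutate an identity attribute after insertion
print(p in points)                  # False: looked up under the new hash
print(any(q is p for q in points))  # True: the object is still stored
```

One common way to enforce immutability in Python 3 is `@dataclass(frozen=True)`, which rejects assignment to fields after construction.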