Is there an easy way to handle NaN attributes when testing dataclass objects for equality? Here is my minimal example:

import pickle
from dataclasses import dataclass


@dataclass
class MyClass:
    a: float


mc = MyClass(float('nan'))

# Serialize and deserialize
mc2 = pickle.loads(pickle.dumps(mc))

assert mc2 == mc  # E   assert MyClass(a=nan) == MyClass(a=nan)

It currently fails as follows:

Traceback (most recent call last):
  File "???.py", line 15, in <module>
    assert mc2 == mc  # E   assert MyClass(a=nan) == MyClass(a=nan)
AssertionError
rv.kvetch
GlaceCelery

3 Answers

If you need to override the default equality logic for a float then you're starting to move away from what dataclasses were designed for. It's only a small move, so dataclasses can still give you a bit of flexibility.

Dataclasses will honour any custom logic you write for your class. So you are completely free to write your own implementation of __eq__ that does what you want. For instance:

from dataclasses import dataclass
from math import isnan

@dataclass
class MyClass:
    a: float
    b: float

    def __eq__(self, other):
        return (
            self.__class__ is other.__class__
            and self.a == other.a
            and (
                self.b == other.b
                or (isnan(self.b) and isnan(other.b))
            )
        )

same_1 = MyClass(1, 2.5)
same_2 = MyClass(1, 2.5)
different_a = MyClass(2, 2.5)
different_b = MyClass(1, 3.0)
nan_1 = MyClass(1, float('nan'))
nan_2 = MyClass(1, float('nan'))

assert nan_1 is not nan_2 and nan_1.b is not nan_2.b, \
    "check equality operator cannot be short-circuited due to object identity"
assert nan_1 == nan_2, "equal when both bs are nan"

assert same_1 == same_1, "same object"
assert same_1 == same_2, "different object, but equal attributes"
assert same_1 != different_a, "different a attribute"
assert same_1 != different_b, "different b attribute"
assert same_1 != nan_1, "preserve nan inequality with other numbers"

You could play about with metaprogramming to reduce the amount of code in your class. However, if you provide a custom equality operator then your code will remain clear. It will be easy to see why you have implemented a custom equality operator, and how it works. Plus it will be marginally more efficient -- no transient object creation or additional instance checks required.
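For comparison, the metaprogramming route mentioned above might look roughly like the sketch below. The names NanEqMixin, _same and Point are made up for illustration and are not part of the dataclasses module. The sketch assumes the dataclass is declared with eq=False so that the generated __eq__ does not replace the inherited one.

```python
from dataclasses import dataclass, fields
from math import isnan


def _same(x, y):
    """Return True if x == y, or if both are float NaN (hypothetical helper)."""
    if isinstance(x, float) and isinstance(y, float) and isnan(x) and isnan(y):
        return True
    return x == y


class NanEqMixin:
    """Hypothetical mixin: field-by-field equality that treats NaN as equal to NaN."""

    def __eq__(self, other):
        if self.__class__ is not other.__class__:
            return NotImplemented
        # Only look at fields that take part in comparisons (compare=True).
        return all(
            _same(getattr(self, f.name), getattr(other, f.name))
            for f in fields(self)
            if f.compare
        )


# eq=False stops @dataclass from generating its own __eq__,
# so the mixin's version is the one that gets used.
@dataclass(eq=False)
class Point(NanEqMixin):
    x: float
    y: float


assert Point(1.0, float("nan")) == Point(1.0, float("nan"))
assert Point(1.0, float("nan")) != Point(1.0, 2.0)
```

This trades the explicitness of a hand-written __eq__ for reuse across several classes, at the cost of the extra per-field checks noted above.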

Dunes

I honestly don't know if this is the easiest or most straightforward way, but I'd start by creating a custom subclass of the float type that overrides __eq__ with the desired logic. So, something along the lines of this:

from math import isnan

class FloatWithEq(float):

    def __eq__(self, other, __super_eq__=float.__eq__):
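        # Two NaNs compare equal; the isinstance guard avoids a TypeError
        # from isnan() when `other` is not a number (e.g. a str).
        if isinstance(other, float) and isnan(other) and isnan(self):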
            return True

        # noinspection PyArgumentList
        return __super_eq__(self, other)

Then, this can be used as the field type of a in the class as normal. Note that I've also gone ahead and turned a into a field property, just so that it's a little easier to pass in to the constructor: for example, you can pass in a FloatWithEq, a float, or a str, and it will automatically be converted to our custom type, as shown below.

from __future__ import annotations  # not needed in PY 3.10+

import pickle
from dataclasses import dataclass, field


@dataclass
class MyClass:
    a: FloatWithEq | float | str
    _a: FloatWithEq = field(init=False, repr=False)

    @property
    def a(self) -> FloatWithEq:
        return self._a

    @a.setter
    def a(self, a: FloatWithEq | float | str) -> None:
        self._a = a if isinstance(a, FloatWithEq) else FloatWithEq(a)


mc = MyClass(float('nan'))

# Serialize and deserialize
mc2 = pickle.loads(pickle.dumps(mc))

assert mc2 == mc  # it works!

An even better approach, for code reusability and to simplify the above usage, is to use descriptors (relying on __set_name__, available since Python 3.6), and define a descriptor class as follows:

class FloatWithEqDescriptor:

    def __set_name__(self, owner, name):
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        return getattr(obj, self.private_name)

    def __set__(self, obj, val):
        setattr(
            obj, self.private_name,
            val if isinstance(val, FloatWithEq) else FloatWithEq(val)
        )

And then this can be used in a similar fashion, such as below:

from __future__ import annotations  # not needed in PY 3.10+

import pickle
from dataclasses import dataclass

@dataclass
class MyClass:
    a: FloatWithEq | float | int | str = FloatWithEqDescriptor()
    b: FloatWithEq | float | int | str = FloatWithEqDescriptor()


mc = MyClass(float('nan'), 1.23)

# Serialize and deserialize
mc2 = pickle.loads(pickle.dumps(mc))

assert mc2 == mc  # it works!
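
As a quick sanity check of the pieces above, here is a condensed, standalone sketch. It repeats simplified versions of FloatWithEq and FloatWithEqDescriptor so that it runs on its own, annotates the fields as plain float to avoid the __future__ import, and builds two separate instances instead of pickling. Treat it as an illustration under those assumptions rather than a drop-in replacement.

```python
from dataclasses import dataclass
from math import isnan


class FloatWithEq(float):
    def __eq__(self, other):
        # Two NaNs are considered equal; everything else defers to float.
        if isinstance(other, float) and isnan(other) and isnan(self):
            return True
        return float.__eq__(self, other)


class FloatWithEqDescriptor:
    def __set_name__(self, owner, name):
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        # Class-level access (obj is None) raises AttributeError here,
        # which dataclasses interprets as "this field has no default".
        return getattr(obj, self.private_name)

    def __set__(self, obj, val):
        setattr(obj, self.private_name,
                val if isinstance(val, FloatWithEq) else FloatWithEq(val))


@dataclass
class MyClass:
    a: float = FloatWithEqDescriptor()
    b: float = FloatWithEqDescriptor()


mc = MyClass(float('nan'), 1.23)
mc2 = MyClass('nan', 1.23)  # the str is converted by the descriptor

assert type(mc.a) is FloatWithEq and type(mc2.a) is FloatWithEq
assert mc == mc2
assert mc != MyClass(float('nan'), 4.56)
```

One thing to keep in mind with this design: because FloatWithEq defines __eq__ without __hash__, its instances are unhashable, so they cannot be used as dict keys or set members unless a __hash__ is added.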
rv.kvetch

From the Python docs:

The not-a-number values float('NaN') and decimal.Decimal('NaN') are special. Any ordered comparison of a number to a not-a-number value is false. A counter-intuitive implication is that not-a-number values are not equal to themselves. For example, if x = float('NaN'), 3 < x, x < 3 and x == x are all false, while x != x is true. This behavior is compliant with IEEE 754.
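
The behaviour described in that passage can be checked directly. This is only a small illustration using the standard library; the variable names x and y are arbitrary.

```python
from math import isnan

x = float('nan')
y = float('nan')

# NaN is not equal to anything, including itself.
print(x == x)         # False
print(x != x)         # True
print(x == y)         # False

# Ordered comparisons against NaN are always false.
print(3 < x, x < 3)   # False False

# isnan() is the reliable way to test for NaN.
print(isnan(x) and isnan(y))  # True
```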

That is a design feature of Python. And, indeed, two not-a-number values are not necessarily the same thing. I would use something like

from math import isnan
assert all((isnan(mc2.a), isnan(mc.a)))

Though, by doing so, you're not comparing the rest of the class data. Otherwise, you could implement a custom data type such as the one suggested by rv.kvetch, or use None instead of float('nan') where needed:

myNan = None # or whatever you want
mc = MyClass(myNan)

# Serialize and deserialize
mc2 = pickle.loads(pickle.dumps(mc))

assert mc == mc2
Buzz