I honestly don't know if this is the easiest or most straightforward way, but I'd start by creating a custom subclass of the `float` type that implements an equality (`__eq__`) method with the desired logic, namely that two NaN values compare equal. So, something along the lines of this:
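```python
from math import isnan

class FloatWithEq(float):
    def __eq__(self, other):  # NaN == NaN; guard stops isnan() raising on non-numbers
        if isinstance(other, (int, float)) and isnan(self) and isnan(other):
            return True
        return super().__eq__(other)

    def __ne__(self, other):  # float has its own __ne__, so mirror __eq__ explicitly
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):  # defining __eq__ disables hashing; give all NaNs one hash
        return 0 if isnan(self) else super().__hash__()
```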
Then, `FloatWithEq` can be used as the field type of `a` in the dataclass as normal. Note that I've also gone ahead and turned `a` into a property, just so that it's a little easier to pass in to the constructor: for example, you can pass in a `FloatWithEq`, a plain `float`, or even a `str`, and the setter will automatically convert it to our custom type, as shown below.
```python
from __future__ import annotations  # not needed in PY 3.10+

import pickle
from dataclasses import dataclass, field


@dataclass
class MyClass:
    # dataclasses picks up the `a` property below as this field's default,
    # so always pass `a` explicitly
    a: FloatWithEq | float | str
    _a: FloatWithEq = field(init=False, repr=False)

    @property
    def a(self) -> FloatWithEq:
        return self._a

    @a.setter
    def a(self, a: FloatWithEq | float | str) -> None:
        # Coerce a float, str, etc. into our NaN-aware subclass
        self._a = a if isinstance(a, FloatWithEq) else FloatWithEq(a)


mc = MyClass(float('nan'))
# Serialize and deserialize
mc2 = pickle.loads(pickle.dumps(mc))
assert mc2 == mc  # it works!
```
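For contrast, here is a minimal sketch of the same round trip with a plain `float` field (`PlainClass` is an illustrative name, not part of the approach above). Because NaN never compares equal to another NaN, and unpickling produces a new float object, the generated `__eq__` reports the two instances as different:

```python
import math
import pickle
from dataclasses import dataclass


@dataclass
class PlainClass:
    a: float  # an ordinary float field, no custom equality


nan = float('nan')
print(math.isnan(nan), nan == nan)  # True False

original = PlainClass(nan)
restored = pickle.loads(pickle.dumps(original))

# Unpickling creates a new NaN object, and NaN never equals another NaN,
# so the generated __eq__ treats the round-tripped instance as different
print(restored == original)  # False
```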
An even better approach, for code reusability and to simplify the above usage, is to make use of descriptors (the `__set_name__` hook used here needs Python 3.6+) and define a descriptor class as follows:
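```python
class FloatWithEqDescriptor:
    def __set_name__(self, owner, name):
        # Called at class creation; values are stored on the instance as e.g. `_a`
        self.public_name = name
        self.private_name = '_' + name

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Class-level access: raising AttributeError tells dataclasses
            # that this field has no default, so it stays required
            raise AttributeError(self.public_name)
        return getattr(obj, self.private_name)

    def __set__(self, obj, val):
        # Coerce a float, int, str, etc. into our NaN-aware subclass
        setattr(
            obj, self.private_name,
            val if isinstance(val, FloatWithEq) else FloatWithEq(val)
        )
```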
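To make the descriptor hooks a bit more concrete, here is a small tracing sketch (`Traced` and `Demo` are hypothetical names used only for illustration) that logs when Python calls `__set_name__`, `__get__` and `__set__`:

```python
class Traced:
    """Hypothetical descriptor that records each hook Python calls on it."""

    def __set_name__(self, owner, name):
        # Called once, when the owning class body has finished executing
        self.private_name = '_' + name
        self.log = [f'__set_name__({owner.__name__}, {name!r})']

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Class-level access such as `Demo.x` hands back the descriptor itself
            self.log.append('__get__(None)')
            return self
        self.log.append('__get__(instance)')
        return getattr(obj, self.private_name)

    def __set__(self, obj, value):
        # Instance assignment such as `d.x = 42` lands here, not in d.__dict__['x']
        self.log.append(f'__set__({value!r})')
        setattr(obj, self.private_name, value)


class Demo:
    x = Traced()


d = Demo()
d.x = 42
print(d.x)      # 42
print(vars(d))  # {'_x': 42}
print(Demo.__dict__['x'].log)
# ["__set_name__(Demo, 'x')", '__set__(42)', '__get__(instance)']
```

Returning the descriptor itself on class access, as `Traced` does, is the usual convention. Inside a dataclass, though, that returned object would be taken as the field's default, which is why `FloatWithEqDescriptor.__get__` ends up raising `AttributeError` when `obj` is `None`: per the `dataclasses` documentation, that tells the decorator the field has no default.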
And then `FloatWithEqDescriptor` can be used in a similar way, as below:
```python
from __future__ import annotations  # not needed in PY 3.10+

import pickle
from dataclasses import dataclass


@dataclass
class MyClass:
    a: FloatWithEq | float | int | str = FloatWithEqDescriptor()
    b: FloatWithEq | float | int | str = FloatWithEqDescriptor()


mc = MyClass(float('nan'), 1.23)
# Serialize and deserialize
mc2 = pickle.loads(pickle.dumps(mc))
assert mc2 == mc  # it works!
```
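If you find yourself writing several descriptors like this, one possible generalisation (a sketch, not something the approach above requires) is to pass the conversion function into the descriptor. `Coerced` and `Reading` are hypothetical names; with the `FloatWithEq` class from above you could then declare a field as `a: float = Coerced(FloatWithEq)`.

```python
from __future__ import annotations

from dataclasses import MISSING, dataclass, fields
from typing import Any, Callable


class Coerced:
    """Hypothetical reusable descriptor that runs `converter` on every assignment."""

    def __init__(self, converter: Callable[[Any], Any]) -> None:
        self.converter = converter

    def __set_name__(self, owner: type, name: str) -> None:
        self.public_name = name
        self.private_name = '_' + name

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        if obj is None:
            # Class-level access: tells dataclasses the field has no default
            raise AttributeError(self.public_name)
        return getattr(obj, self.private_name)

    def __set__(self, obj: Any, value: Any) -> None:
        # Convert before storing under the private name
        setattr(obj, self.private_name, self.converter(value))


@dataclass
class Reading:
    sensor: str = Coerced(str)
    value: float = Coerced(float)


r = Reading(17, '2.5')
print(r)  # Reading(sensor='17', value=2.5)
print(all(f.default is MISSING for f in fields(Reading)))  # True
```

Because the generated `__init__` assigns through the descriptor, the conversion also applies at construction time, and both fields remain required arguments.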