You can use a `@cached_class_property` approach with `dataclasses.fields()`.
For example:
from dataclasses import fields, dataclass
class cached_class_property:
    """Non-data descriptor that computes a class-level value lazily.

    The wrapped function is called with the *class* (never an instance).
    On first access the result is stored on that class under the same
    attribute name via ``setattr``, which shadows this descriptor, so
    subsequent lookups are ordinary class-attribute reads.

    Caveats:
      * No ``__set__`` is defined, so the attribute is not read-only; it
        can be reassigned on the class.
      * The cached value replaces the descriptor on whichever class it was
        first accessed through. If a base class is accessed first,
        subclasses that have not yet accessed it inherit the base class's
        cached value.

    Credits: https://stackoverflow.com/a/4037979/10237506
    """

    def __init__(self, func):
        self.__func__ = func
        self.__attr_name__ = func.__name__

    def __get__(self, instance, cls=None):
        """Compute the value for the owner class, cache it there, return it."""
        owner = type(instance) if cls is None else cls
        value = self.__func__(owner)
        # Shadow this descriptor on `owner` so later lookups skip __get__.
        setattr(owner, self.__attr_name__, value)
        return value
@dataclass(frozen=True)
class Pos:
    """An immutable pair of integer ``start`` / ``end`` values."""

    start: int
    end: int

    @cached_class_property
    def init_fields(cls):
        """Names of the fields accepted by ``__init__``, in declaration order."""
        return tuple(f.name for f in fields(cls) if f.init)

    def adjust_start(self, delta: int) -> 'Pos':
        """Return a copy of this object with ``start`` shifted by *delta*.

        Every other ``__init__`` field is carried over unchanged.
        """
        cls = type(self)
        # Pass values by keyword so the call does not depend on positional
        # order (and still works for keyword-only fields); use type(self)
        # so a subclass gets an instance of its own type back.
        kwargs = {name: getattr(self, name) for name in cls.init_fields}
        kwargs['start'] += delta
        return cls(**kwargs)
# Build an original position and an adjusted copy, then show both.
p1 = Pos(1, 2)
p2 = Pos(1, 2).adjust_start(4)
for shown in (p1, p2):
    print(shown)
Since you are using a `frozen=True` dataclass without slots (`slots=False`, the default), each instance keeps its field values in its `__dict__`. You could therefore simplify this approach and drop `@cached_class_property` entirely:
def adjust_start(self, delta: int) -> 'Pos':
    """Return a copy with ``start`` shifted by *delta*, without ``fields()``.

    This works because a non-slotted dataclass instance stores its field
    values in ``__dict__``. It assumes every entry there is an ``__init__``
    parameter (no ``init=False`` fields or extra instance attributes).
    """
    values = self.__dict__.copy()
    values['start'] += delta
    # Construct by keyword rather than position, so the result does not
    # depend on __dict__ insertion order. type(self) keeps subclasses intact.
    return type(self)(**values)
Output:
Pos(start=1, end=2)
Pos(start=5, end=2)
Benchmark results show it is ever so slightly faster than `dataclasses.replace()`:
from dataclasses import fields, dataclass, replace
from timeit import timeit
@dataclass(frozen=True)
class Pos:
    """An immutable pair of integer ``start`` / ``end`` values, with three
    interchangeable ways of shifting ``start`` for benchmarking."""

    start: int
    end: int

    # `cached_class_property` defined from above
    @cached_class_property
    def init_fields(cls):
        """Names of the fields accepted by ``__init__``, in declaration order."""
        return tuple(f.name for f in fields(cls) if f.init)

    def adjust_via_copy(self, delta: int) -> 'Pos':
        """Shift ``start`` by copying the instance ``__dict__``.

        Assumes every ``__dict__`` entry is an ``__init__`` parameter
        (true here: no slots, no ``init=False`` fields).
        """
        values = self.__dict__.copy()
        values['start'] += delta
        # Keyword construction avoids relying on __dict__ insertion order.
        return type(self)(**values)

    def adjust_via_fields(self, delta: int) -> 'Pos':
        """Shift ``start`` by rebuilding from the cached ``init_fields``."""
        cls = type(self)
        kwargs = {name: getattr(self, name) for name in cls.init_fields}
        kwargs['start'] += delta
        return cls(**kwargs)

    def adjust_via_replace(self, delta: int) -> 'Pos':
        """Shift ``start`` using ``dataclasses.replace``."""
        return replace(
            self,
            start=self.start + delta,
        )
p = Pos(1, 2)

# Sanity-check that all three strategies agree *before* spending time on
# the benchmarks. (The previous check compared adjust_via_replace with
# itself and never exercised adjust_via_copy.)
assert Pos(-2, 2) == p.adjust_via_copy(-3) == p.adjust_via_fields(-3) == p.adjust_via_replace(-3)

print('o.__dict__.copy: ', round(timeit('p.adjust_via_copy(4)', globals=globals()), 3))
print('dataclasses.fields: ', round(timeit('p.adjust_via_fields(4)', globals=globals()), 3))
print('dataclasses.replace: ', round(timeit('p.adjust_via_replace(4)', globals=globals()), 3))
o.__dict__.copy: 0.408
dataclasses.fields: 0.499
dataclasses.replace: 0.659