0

I'd like to copy an instance of a frozen dataclass, changing just one field ("functional update").

Here's what I tried

from dataclasses import dataclass, asdict    
@dataclass(frozen=True)
class Pos:
    """Immutable record holding an integer ``start`` and ``end``."""

    start: int
    end: int
     
def adjust_start(pos: Pos, delta: int) -> Pos:
   # Intended to return a copy of pos with start shifted by delta.
   # TypeError: type object got multiple values for keyword argument 'start'
   # (asdict(pos) already contains a 'start' key, so the explicit start=
   # keyword below supplies the same argument a second time).
   return Pos(**asdict(pos), start = pos.start + delta)
     
# This call triggers the TypeError noted inside adjust_start.
adjust_start(Pos(1, 2), 4)

What I'm looking for:

  • Is there a more straightforward way than converting to/from dicts?
  • How can I get around the TypeError? If there is a way to functionally update kwargs, that could work.

In Scala, a functional update of a case class (Scala dataclass) can be done like this: pos.copy(start = pos.start + delta).

Max Heiber
  • 14,346
  • 12
  • 59
  • 97

2 Answers

5

dataclasses.replace() to the rescue.

dataclasses.replace(obj, /, **changes) creates a new object of the same type as obj, replacing fields with values from changes.

import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    """Frozen dataclass with two integer fields, ``start`` and ``end``."""

    start: int
    end: int


def adjust_start(pos: Pos, delta: int) -> Pos:
    """Return a copy of ``pos`` whose ``start`` is shifted by ``delta``.

    ``dataclasses.replace`` builds the new instance; every field other than
    ``start`` keeps the value it has on ``pos``.
    """
    shifted_start = pos.start + delta
    return dataclasses.replace(pos, start=shifted_start)


# start becomes 1 + 4, so p == Pos(start=5, end=2).
p = adjust_start(Pos(1, 2), 4)

Personally, I might put adjust on the dataclass itself:

import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    """Immutable position with integer ``start`` and ``end`` fields."""

    start: int
    end: int

    def adjust(self, *, start: int = 0, end: int = 0) -> "Pos":
        """Return a copy with ``start`` and ``end`` shifted by the given deltas.

        Both deltas are keyword-only and default to ``0``, so a caller can
        shift a single field (``pos.adjust(start=4)``) and leave the other
        one unchanged. ``self`` is never modified.
        """
        return dataclasses.replace(
            self,
            start=self.start + start,
            end=self.end + end,
        )


# Only ``start`` is shifted here; ``end`` keeps its value thanks to the default.
p = Pos(1, 2).adjust(start=4)
AKX
  • 152,115
  • 15
  • 115
  • 172
0

You can use a @cached_class_property approach with dataclasses.fields().

For example:

from dataclasses import fields, dataclass


class cached_class_property:
    """
    Read-only, class-level property implemented as a descriptor.

    On first access the wrapped function is called with the class the
    attribute was looked up through, and the result is stored on that class
    under the function's name so later lookups return the stored value.

    Credits: https://stackoverflow.com/a/4037979/10237506
    """

    def __init__(self, func):
        self.__func__ = func
        self.__attr_name__ = func.__name__

    def __get__(self, instance, cls=None):
        """Compute the value, store it on the class, and return it."""
        # Fall back to the instance's type when no owner class is supplied.
        owner = type(instance) if cls is None else cls

        value = self.__func__(owner)

        # Store the result under the wrapped function's name on the owner.
        setattr(owner, self.__attr_name__, value)
        return value


@dataclass(frozen=True)
class Pos:
    """Frozen two-field position that can produce shifted copies of itself."""

    start: int
    end: int

    @cached_class_property
    def init_fields(cls):
        """Tuple of the names of the fields that take part in ``__init__``."""
        init_only = (field for field in fields(cls) if field.init)
        return tuple(field.name for field in init_only)

    def adjust_start(self, delta: int) -> 'Pos':
        """Return a new ``Pos`` whose ``start`` is increased by ``delta``."""
        values = []
        for name in Pos.init_fields:
            value = getattr(self, name)
            if name == 'start':
                value = value + delta
            values.append(value)

        # Values are collected in ``init_fields`` order and passed positionally.
        return Pos(*values)


# Build a position and print its dataclass-generated repr.
p1 = Pos(1, 2)
print(p1)

# adjust_start returns a new instance with start shifted by 4.
p2 = Pos(1, 2).adjust_start(4)
print(p2)

Since you are using a frozen=True dataclass without slots (slots=False), you can also simplify this approach and drop @cached_class_property altogether:

def adjust_start(self, delta: int) -> 'Pos':
    """Return a new ``Pos`` with ``start`` increased by ``delta``.

    Operates on a copy of the instance ``__dict__`` and passes the copied
    values positionally to ``Pos``.
    """
    state = dict(self.__dict__)
    state['start'] += delta

    values = state.values()
    return Pos(*values)

Output:

Pos(start=1, end=2)
Pos(start=5, end=2)

Timing results show this is ever so slightly faster than dataclasses.replace():

from dataclasses import fields, dataclass, replace
from timeit import timeit


@dataclass(frozen=True)
class Pos:
    # Benchmark subject: a frozen dataclass offering three ways of producing
    # a copy whose ``start`` is shifted by ``delta``.
    start: int
    end: int

    # `cached_class_property` defined from above
    @cached_class_property
    def init_fields(cls):
        """Return the names of the fields whose ``init`` flag is set."""
        return tuple(f.name for f in fields(cls) if f.init)

    def adjust_via_copy(self, delta: int) -> 'Pos':
        """Shift ``start`` by editing a copy of the instance ``__dict__``."""
        _dict = self.__dict__.copy()
        _dict['start'] += delta

        # Values are passed positionally.
        # NOTE(review): assumes __dict__ order matches the field order - confirm.
        return Pos(*_dict.values())

    def adjust_via_fields(self, delta: int) -> 'Pos':
        """Shift ``start`` by rebuilding the arguments from ``init_fields``."""
        attrs = [getattr(self, f) + delta if f == 'start' else getattr(self, f)
                 for f in Pos.init_fields]

        return Pos(*attrs)

    def adjust_via_replace(self, delta: int) -> 'Pos':
        """Shift ``start`` using ``dataclasses.replace``."""
        return replace(
            self,
            start=self.start + delta,
        )


p = Pos(1, 2)

# Time each strategy with timeit's default number of executions and print
# the elapsed seconds rounded to three decimals.
print('o.__dict__.copy:      ', round(timeit('p.adjust_via_copy(4)', globals=globals()), 3))
print('dataclasses.fields:   ', round(timeit('p.adjust_via_fields(4)', globals=globals()), 3))
print('dataclasses.replace:  ', round(timeit('p.adjust_via_replace(4)', globals=globals()), 3))

# Sanity check: all three strategies (replace, fields and copy) must produce
# the same result.
assert Pos(-2, 2) == p.adjust_via_replace(-3) == p.adjust_via_fields(-3) == p.adjust_via_copy(-3)
o.__dict__.copy:       0.408
dataclasses.fields:    0.499
dataclasses.replace:   0.659
rv.kvetch
  • 9,940
  • 3
  • 24
  • 53