67

How can I define algebraic data types in Python (2 or 3)?

desertnaut
user1234299

3 Answers

44

Python 3.10 version

Here is a Python 3.10 version of Brent's answer with pattern-matching and prettier union type syntax:

from dataclasses import dataclass

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Circle:
    x: float
    y: float
    r: float

@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float

Shape = Point | Circle | Rectangle

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case Rectangle(x, y, w, h):
            print(f"Rectangle {x} {y} {w} {h}")

print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
print_shape(4)  # mypy type error

You can even do recursive types:

from __future__ import annotations
from dataclasses import dataclass

@dataclass
class Branch:
    value: int
    left: Tree
    right: Tree

Tree = Branch | None

def contains(tree: Tree, value: int):
    match tree:
        case None:
            return False
        case Branch(x, left, right):
            return x == value or contains(left, value) or contains(right, value)

tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))

assert contains(tree, 1)
assert contains(tree, 2)
assert contains(tree, 3)
assert contains(tree, 4)
assert not contains(tree, 5)

Note that from __future__ import annotations is needed here: the fields of Branch are annotated with Tree, a type that has not been defined yet at that point in the module.
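As an alternative sketch (not part of the original answer), the forward reference can be written as a quoted string, which avoids the __future__ import. The depth helper below is a hypothetical name introduced only for illustration, and Optional[Branch] is used in place of Branch | None so that the snippet also runs on versions older than 3.10.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Branch:
    value: int
    left: "Tree"   # quoted: Tree is not defined yet at this point
    right: "Tree"


# Optional[Branch] means the same as Branch | None.
Tree = Optional[Branch]


def depth(tree: Tree) -> int:
    """Hypothetical helper: number of levels in the tree (0 for an empty tree)."""
    if tree is None:
        return 0
    return 1 + max(depth(tree.left), depth(tree.right))


tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))
print(depth(tree))
```

Either spelling should type-check the same way under mypy; the quoted form just keeps the deferral local to the two fields that need it.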

Exhaustiveness checking for ADTs can be enforced with mypy using assert_never(), which is available as typing.assert_never in Python 3.11+ and through the typing-extensions backport on older versions of Python.

from typing import assert_never  # Python 3.11+; older versions: from typing_extensions import assert_never

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case _ as unreachable:
            # mypy reports a type error here
            # because Rectangle is not covered in the match.
            assert_never(unreachable)
phoenix
tehziyang
  • Great answer! Note that [python does not optimize tail recursion](https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion) and the `contains` function will use one stack frame per tree level. Given the logarithmic height of balanced binary trees, this should be fine, just something to be aware of. – Simon Kohlmeyer Oct 24 '22 at 16:34
  • 1
    @SimonKohlmeyer No language, even one with tail call optimization/elimination, should be able to eliminate the tail call in this case, as far as I can see. `contains` branches *twice*, so only `contains(right, value)` is in tail position, while the earlier call, `contains(left, value)`, is not. The `or` operation follows it in any case, so that call cannot be eliminated (the locals are still required after returning from it). So this always uses a stack frame per level, regardless of whether Python supported TCO. – Alex Povel Jun 15 '23 at 09:54
24

The typing module provides Union which, unlike a C union, is a sum type. You'll need to use mypy to do static type checking, and before Python 3.10 there is no pattern matching, but combined with tuples (product types), that gives you the two common kinds of algebraic types.

from dataclasses import dataclass
from typing import Union


@dataclass
class Point:
    x: float
    y: float


@dataclass
class Circle:
    x: float
    y: float
    r: float


@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float


Shape = Union[Point, Circle, Rectangle]


def print_shape(shape: Shape):
    if isinstance(shape, Point):
        print(f"Point {shape.x} {shape.y}")
    elif isinstance(shape, Circle):
        print(f"Circle {shape.x} {shape.y} {shape.r}")
    elif isinstance(shape, Rectangle):
        print(f"Rectangle {shape.x} {shape.y} {shape.w} {shape.h}")


print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
# print_shape(4)  # mypy type error
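The isinstance chain above can also be checked for exhaustiveness. The following is a minimal sketch of the hand-rolled assert_never idiom built on typing.NoReturn, for interpreters where typing.assert_never is not available. The reduced two-variant Shape and the describe function are assumptions made for brevity, not part of the answer.

```python
from dataclasses import dataclass
from typing import NoReturn, Union


@dataclass
class Point:
    x: float
    y: float


@dataclass
class Circle:
    x: float
    y: float
    r: float


Shape = Union[Point, Circle]


def assert_never(value: NoReturn) -> NoReturn:
    # Only reachable at runtime if a variant was not handled.
    raise AssertionError(f"Unhandled value: {value!r}")


def describe(shape: Shape) -> str:
    if isinstance(shape, Point):
        return f"Point {shape.x} {shape.y}"
    elif isinstance(shape, Circle):
        return f"Circle {shape.x} {shape.y} {shape.r}"
    else:
        # If a variant is added to Shape but not handled above, mypy
        # should flag this call, since shape no longer narrows to NoReturn.
        assert_never(shape)


print(describe(Point(1, 2)))
print(describe(Circle(3, 5, 7)))
```

At runtime the else branch is only a safety net; the real benefit is the static error mypy reports when a new variant is added to Shape without a matching branch.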
Brent
  • 10
    Hi! If you are using mypy, you can check the exhaustiveness of your "pattern matching" using the `assert_never` idiom: https://github.com/python/typing/issues/735 – Noé Rubinstein Nov 05 '20 at 10:40
  • 5
    And PEP 622 (pattern matching) makes using sum types even more similar to functional languages. – Martin Stancsics Dec 09 '20 at 10:52
  • 4
    This isn't really sum types since it just relies on RTTI to identify the type. – Timmmm May 12 '21 at 13:56
  • 5
    As of today (October 15 2021) Python 3.10 is released, adding support for structural pattern matching. https://docs.python.org/3.10/whatsnew/3.10.html#summary-release-highlights – Brent Oct 15 '21 at 19:17
-3

Here's an implementation of sum types in a relatively Pythonic way.

import attr


@attr.s(frozen=True)
class CombineMode(object):
    kind = attr.ib(type=str)
    params = attr.ib(factory=list)

    def match(self, expected_kind, f):
        if self.kind == expected_kind:
            return f(*self.params)
        else:
            return None

    @classmethod
    def join(cls):
        return cls("join")

    @classmethod
    def select(cls, column: str):
        return cls("select", params=[column])

Crack open an interpreter and you'll see familiar behavior:

>>> CombineMode.join()
CombineMode(kind='join', params=[])

>>> CombineMode.select('a') == CombineMode.select('b')
False

>>> CombineMode.select('a') == CombineMode.select('a')
True

>>> CombineMode.select('foo').match('select', print)
foo

Note: The @attr.s decorator comes from the attrs library. It implements __init__, __repr__, and __eq__, and with frozen=True it also freezes the object. I included it because it cuts down on the implementation size, and it's widely available and quite stable.

Sum types are sometimes called tagged unions. Here I used the kind member to implement the tag. Additional per-variant parameters are implemented via a list. In true Pythonic fashion, this is duck-typed on the input & output sides but not strictly enforced internally.

I also included a match function that does basic pattern matching. Type safety is also implemented via duck typing: a TypeError will be raised if the signature of the callable you pass doesn't align with the parameters of the variant you're trying to match on.

These sum types can be combined with product types (list or tuple) and still retain a lot of the critical functionality required for algebraic data types.

Problems

This doesn't strictly constrain the set of variants.
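One possible way to close that gap, sketched here under assumptions rather than taken from the answer: make the tag an Enum so that only declared variants can be constructed. To keep the snippet free of third-party dependencies it swaps attrs for the standard-library dataclasses module; the Kind enum and the __post_init__ check are hypothetical additions.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Tuple


class Kind(Enum):
    # The closed set of variants (hypothetical name).
    JOIN = "join"
    SELECT = "select"


@dataclass(frozen=True)
class CombineMode:
    kind: Kind
    params: Tuple[Any, ...] = ()

    def __post_init__(self) -> None:
        # Reject any tag that is not a member of the Kind enum.
        if not isinstance(self.kind, Kind):
            raise ValueError(f"unknown variant: {self.kind!r}")

    def match(self, expected_kind: Kind, f: Callable[..., Any]) -> Any:
        # Same behavior as the original match: call f on a hit, else None.
        if self.kind is expected_kind:
            return f(*self.params)
        return None

    @classmethod
    def join(cls) -> "CombineMode":
        return cls(Kind.JOIN)

    @classmethod
    def select(cls, column: str) -> "CombineMode":
        return cls(Kind.SELECT, (column,))


CombineMode.select("foo").match(Kind.SELECT, print)
```

This only constrains the tag at construction time; the number and types of params per variant are still left to duck typing, as in the original.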

kelloti