How can I define algebraic data types in Python (2 or 3)?
-
Python has a very loose type system in the first place; what exactly are you trying to get out of this? – Amber Apr 28 '13 at 01:18
-
@Amber not loose. very strong, but duck. – Elazar Apr 28 '13 at 01:19
-
@Elazar When I say "loose", I mean things like functions not having particular type signatures. But you're right, Python is not weakly typed. – Amber Apr 28 '13 at 01:25
-
@Amber I see what you mean. – Elazar Apr 28 '13 at 01:28
-
@Amber Rather than "loose" maybe saying it's dynamically typed rather than statically typed would have been clearer. – Dwayne Crooks Jan 08 '17 at 16:23
-
@Elazar But there are also some implicit type conversions. – Ekrem Dinçel Jan 08 '21 at 18:09
-
Type conversions such as? – Elazar Jan 09 '21 at 16:01
3 Answers
Python 3.10 version
Here is a Python 3.10 version of Brent's answer with pattern-matching and prettier union type syntax:
from dataclasses import dataclass


@dataclass
class Point:
    x: float
    y: float


@dataclass
class Circle:
    x: float
    y: float
    r: float


@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float


Shape = Point | Circle | Rectangle


def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case Rectangle(x, y, w, h):
            print(f"Rectangle {x} {y} {w} {h}")


print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
print_shape(4)  # mypy type error
You can even do recursive types:
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Branch:
    value: int
    left: Tree
    right: Tree


Tree = Branch | None


def contains(tree: Tree, value: int):
    match tree:
        case None:
            return False
        case Branch(x, left, right):
            return x == value or contains(left, value) or contains(right, value)


tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))

assert contains(tree, 1)
assert contains(tree, 2)
assert contains(tree, 3)
assert contains(tree, 4)
assert not contains(tree, 5)
Note the need for `from __future__ import annotations` in order to annotate with a type that hasn't been defined yet.
Exhaustiveness checking for ADTs can be enforced with mypy using `typing.assert_never()` in Python 3.11+ or as part of the typing-extensions backport for older versions of Python.
# Python 3.11+; on older versions import it from typing_extensions instead.
from typing import assert_never


def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case _ as unreachable:
            # mypy will throw a type checking error
            # because Rectangle is not covered in the match.
            assert_never(unreachable)
-
Great answer! Note that [python does not optimize tail recursion](https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion) and the `contains` function will use one stack frame per tree level. Given the logarithmic height of binary trees, this should be fine, just something to be aware of – Simon Kohlmeyer Oct 24 '22 at 16:34
-
@SimonKohlmeyer No language, even ones with tail call optimization/elimination, should be able to eliminate the tail in this case, as far as I can see. The `contains` function branches *twice*, so only `contains(right, value)` is in tail position, while the prior one, `contains(left, value)`, is not. The `or` operation follows it in any case and no tail can be eliminated (the locals are still required after returning from it). So this always uses a stack frame per level, regardless of whether Python supported TCO. – Alex Povel Jun 15 '23 at 09:54
The `typing` module provides `Union` which, unlike a C union, is a sum type. You'll need to use mypy to do static type checking, and before Python 3.10 there is no pattern matching, but combined with tuples (product types), that covers the two common algebraic types.
from dataclasses import dataclass
from typing import Union


@dataclass
class Point:
    x: float
    y: float


@dataclass
class Circle:
    x: float
    y: float
    r: float


@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float


Shape = Union[Point, Circle, Rectangle]


def print_shape(shape: Shape):
    if isinstance(shape, Point):
        print(f"Point {shape.x} {shape.y}")
    elif isinstance(shape, Circle):
        print(f"Circle {shape.x} {shape.y} {shape.r}")
    elif isinstance(shape, Rectangle):
        print(f"Rectangle {shape.x} {shape.y} {shape.w} {shape.h}")


print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
# print_shape(4)  # mypy type error
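The `isinstance` chain above can also be checked for exhaustiveness without `match`. The following is a minimal sketch, assuming mypy is used for static checking. The `assert_never` helper is defined locally here as a stand-in for `typing.assert_never` (Python 3.11+), and `describe` is a hypothetical function that returns a string instead of printing.

```python
from dataclasses import dataclass
from typing import NoReturn, Union


@dataclass
class Point:
    x: float
    y: float


@dataclass
class Circle:
    x: float
    y: float
    r: float


@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float


Shape = Union[Point, Circle, Rectangle]


def assert_never(value: NoReturn) -> NoReturn:
    # Local stand-in for typing.assert_never: a type checker only accepts
    # this call if it has narrowed `value` down to "no possible type".
    raise AssertionError(f"Unhandled value: {value!r}")


def describe(shape: Shape) -> str:
    if isinstance(shape, Point):
        return f"Point {shape.x} {shape.y}"
    elif isinstance(shape, Circle):
        return f"Circle {shape.x} {shape.y} {shape.r}"
    elif isinstance(shape, Rectangle):
        return f"Rectangle {shape.x} {shape.y} {shape.w} {shape.h}"
    else:
        # If one of the branches above is removed, mypy should report
        # a type error on this line.
        assert_never(shape)
```

At runtime the helper only fires if a value outside the union slips through, for example `describe(4)`, which raises `AssertionError` in this sketch.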

-
Hi! If you are using mypy, you can check the exhaustiveness of your "pattern matching" using the `assert_never` idiom: https://github.com/python/typing/issues/735 – Noé Rubinstein Nov 05 '20 at 10:40
-
And PEP 622 (pattern matching) makes using sum types even more similar to functional languages. – Martin Stancsics Dec 09 '20 at 10:52
-
This isn't really sum types since it just relies on RTTI to identify the type. – Timmmm May 12 '21 at 13:56
-
As of today (October 15 2021) Python 3.10 is released, adding support for structural pattern matching. https://docs.python.org/3.10/whatsnew/3.10.html#summary-release-highlights – Brent Oct 15 '21 at 19:17
Here's an implementation of sum types in a relatively Pythonic way.
import attr


@attr.s(frozen=True)
class CombineMode(object):
    kind = attr.ib(type=str)
    params = attr.ib(factory=list)

    def match(self, expected_kind, f):
        if self.kind == expected_kind:
            return f(*self.params)
        else:
            return None

    @classmethod
    def join(cls):
        return cls("join")

    @classmethod
    def select(cls, column: str):
        return cls("select", params=[column])
Crack open an interpreter and you'll see familiar behavior:
>>> CombineMode.join()
CombineMode(kind='join', params=[])
>>> CombineMode.select('a') == CombineMode.select('b')
False
>>> CombineMode.select('a') == CombineMode.select('a')
True
>>> CombineMode.select('foo').match('select', print)
foo
Note: The `@attr.s` decorator comes from the attrs library. It implements `__init__`, `__repr__`, and `__eq__`, and with `frozen=True` it also freezes the object. I included it because it cuts down on the implementation size, and it's also widely available and quite stable.
Sum types are sometimes called tagged unions. Here I used the `kind` member to implement the tag. Additional per-variant parameters are implemented via a list. In true Pythonic fashion, this is duck-typed on the input & output sides but not strictly enforced internally.
I also included a `match` method that does basic pattern matching. Type safety is also implemented via duck typing: a `TypeError` will be raised if the passed function's signature doesn't align with the parameters of the variant you're trying to match on.
These sum types can be combined with product types (`list` or `tuple`) and still retain a lot of the critical functionality required for algebraic data types.
Problems
This doesn't strictly constrain the set of variants: nothing stops a caller from constructing `CombineMode("anything")` directly.
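One possible way to narrow that gap is to make the tag an `Enum` rather than a free-form string. This is a rough sketch using only the standard library (`dataclasses` instead of attrs), not the implementation from this answer. The `Kind` enum and the `__post_init__` check are assumptions introduced for illustration.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Tuple


class Kind(Enum):
    # Hypothetical closed set of tags for this sketch.
    JOIN = "join"
    SELECT = "select"


@dataclass(frozen=True)
class CombineMode:
    kind: Kind
    params: Tuple[Any, ...] = ()

    def __post_init__(self) -> None:
        # Reject any tag that is not one of the declared variants.
        if not isinstance(self.kind, Kind):
            raise ValueError(f"unknown variant: {self.kind!r}")

    def match(self, expected_kind: Kind, f: Callable[..., Any]) -> Any:
        # Call f with the variant's parameters if the tag matches, else None.
        if self.kind is expected_kind:
            return f(*self.params)
        return None

    @classmethod
    def join(cls) -> "CombineMode":
        return cls(Kind.JOIN)

    @classmethod
    def select(cls, column: str) -> "CombineMode":
        return cls(Kind.SELECT, (column,))


upper = CombineMode.select("foo").match(Kind.SELECT, str.upper)
```

The set of tags is now closed, but the number and types of `params` per variant are still only checked when `match` calls the function, so a mismatched callback raises `TypeError` just as in the original.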
