Discriminated union
You can define a discriminated union of two or more Pydantic models. This will allow you to instantiate the correct model based on the value of the discriminator field provided in the data.
There are a few slightly different approaches you can take.
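To make the idea concrete before involving Pydantic, here is a minimal sketch of what a discriminator does conceptually: look at one field, pick the matching class, and build it. This is not how Pydantic implements it internally; the dataclasses, the MODEL_BY_VERSION table and the build_message helper are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Any, Union


@dataclass
class MessageV1:
    version: int
    bar: str


@dataclass
class MessageV2:
    version: int
    foo: str


# Hypothetical lookup table: discriminator value -> class to build.
MODEL_BY_VERSION = {1: MessageV1, 2: MessageV2}


def build_message(data: dict[str, Any]) -> Union[MessageV1, MessageV2]:
    """Pick the class from the "version" key, then construct it."""
    try:
        model_cls = MODEL_BY_VERSION[data["version"]]
    except KeyError:
        # Covers both a missing "version" key and an unknown value.
        raise ValueError(f"missing or unknown version in {data!r}") from None
    return model_cls(**data)


print(build_message({"version": 1, "bar": "a"}))  # MessageV1(version=1, bar='a')
print(build_message({"version": 2, "foo": "b"}))  # MessageV2(version=2, foo='b')
```

The options below get you the same dispatch, plus field validation and readable error messages, without maintaining such a table by hand.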
Option A: Annotated type alias
Instead of defining a new model that "combines" your existing ones, you define a type alias for the union of those models and use typing.Annotated to add the discriminator information.
# Pydantic v1
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, parse_obj_as


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageModel = Annotated[
    Union[MessageModelV1, MessageModelV2],
    Field(discriminator="version"),
]

data1 = {"version": 1, "bar": "a"}
data2 = {"version": 2, "foo": "b"}
obj1 = parse_obj_as(MessageModel, data1)
obj2 = parse_obj_as(MessageModel, data2)
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModelV2'>
Pros
- Minimal additional code, just a type definition.
- You get an actual instance of one of the original models.
- Static type checkers can infer the type as a union of the original models. (This gives IDEs the ability to provide some useful auto-suggestions.)
Cons
- You cannot directly instantiate MessageModel because it is a typing construct and not a model class.
- To construct an instance you must use the additional parse_obj_as function.
- There is no way to statically determine which exact model will come out after parsing because it depends on the data.
UPDATE: Pydantic v2
Very similar, but using TypeAdapter instead of parse_obj_as:
# Pydantic v2
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageModel = TypeAdapter(Annotated[
    Union[MessageModelV1, MessageModelV2],
    Field(discriminator="version"),
])

data1 = {"version": 1, "bar": "a"}
data2 = {"version": 2, "foo": "b"}
obj1 = MessageModel.validate_python(data1)
obj2 = MessageModel.validate_python(data2)
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModelV2'>
Unfortunately, as of now, Mypy seems unable to infer the type of the resulting model instances correctly, while Pyright does.
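One way to sidestep the inference problem and the "which model did I get?" question is to wrap the adapter in a small function with an explicit return annotation, and then narrow with isinstance. The following is a sketch under that assumption, not something prescribed by Pydantic; parse_message, describe and message_adapter are names I made up for illustration.

```python
# Pydantic v2
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]
message_adapter = TypeAdapter(Annotated[MessageType, Field(discriminator="version")])


def parse_message(data: dict[str, Any]) -> MessageType:
    # The explicit return annotation tells the type checker what comes out.
    return message_adapter.validate_python(data)


def describe(message: MessageType) -> str:
    # isinstance narrows the union, so each branch sees the right fields.
    if isinstance(message, MessageModelV1):
        return f"v1 with bar={message.bar}"
    return f"v2 with foo={message.foo}"


print(describe(parse_message({"version": 1, "bar": "a"})))  # v1 with bar=a
print(describe(parse_message({"version": 2, "foo": "b"})))  # v2 with foo=b

try:
    # No model declares version 3, so validation fails.
    parse_message({"version": 3, "bar": "a"})
except ValidationError as exc:
    print(f"rejected with {len(exc.errors())} error(s)")
```

The narrowing still happens at runtime, of course; the annotation only makes sure the checker knows it is looking at the union.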
Option B: Custom root type
You define a new model and set its __root__ type to the discriminated union of the original models.
Then you can customize it to the degree you see fit, in order to make instances of it "feel" like any of the original underlying models.
In the following example I overrode the __init__ method so that you can initialize it like you described in your question. I also made it iterable via the underlying root model, and I added custom __str__ and __repr__ methods so that instances are displayed as though they are of the underlying root type.
# Pydantic v1
from collections.abc import Iterator
from typing import Any, Literal, Union

from pydantic import BaseModel, Field


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class MessageModel(BaseModel):
    __root__: MessageType = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(__root__=kwargs)

    def __iter__(self) -> Iterator[tuple[str, Any]]:  # type: ignore[override]
        yield from self.__root__

    def __str__(self) -> str:
        return str(self.__root__)

    def __repr__(self) -> str:
        return repr(self.__root__)
With the same demo data, you now get slightly different results:
# Pydantic v1
obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1)) # version=1 bar='a' <class '__main__.MessageModel'>
print(obj2, type(obj2)) # version=2 foo='b' <class '__main__.MessageModel'>
print(hasattr(obj1, "__root__")) # True
print(hasattr(obj1, "bar")) # False
print(hasattr(obj1, "version")) # False
You'll notice that the type of the object is of course the new MessageModel now and not one of the underlying original models.
Also, without doing more customization, instances will only have the __root__ field pointing to an instance of the underlying model and not have that model's actual fields. If you want to pass on attribute access to the root model, you would have to override __getattr__/__setattr__:
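# Pydantic v1
...
class MessageModel(BaseModel):
    __root__: MessageType = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(__root__=kwargs)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails. Never delegate "__root__"
        # itself, otherwise a missing root would recurse forever.
        if name == "__root__":
            raise AttributeError(name)
        return getattr(self.__root__, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "__root__":
            # Let BaseModel set the root; `self.__root__ = value` here
            # would call this method again and recurse.
            super().__setattr__(name, value)
        else:
            setattr(self.__root__, name, value)
...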
# Pydantic v1
obj1 = MessageModel(version=1, bar="a")
print(obj1.version) # 1
print(obj1.bar) # a
Pros
- Direct instantiation is possible because you are dealing with an actual model class.
- Lots of options for customization and additional validation.
Cons
- Without any customization you will always deal with a "proxy" model and its __root__ field will point to the "actual" model.
- Making the outer model feel/behave like one of the underlying models takes a lot of additional boilerplate. You could probably factor all/most of it out into a separate mix-in class, but the code needs to be written nonetheless.
- You might "forget" that you are dealing with the proxy model, not with the root model, and then run into unexpected results. (For example, try to call obj1.dict() with the last example and look at the output.)
- The type inferred by static analysis will always be MessageModel.
PS
You could hack around the fact that the outer MessageModel is not actually a subtype of the underlying models with typing.TYPE_CHECKING a bit. Example:
# Pydantic v1
from typing import Any, Literal, Union, TYPE_CHECKING

from pydantic import BaseModel, Field


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]

if TYPE_CHECKING:
    class MessageModel(MessageModelV1, MessageModelV2):
        pass
else:
    class MessageModel(BaseModel):
        __root__: MessageType = Field(discriminator="version")

        def __init__(self, **kwargs: Any) -> None:
            super().__init__(__root__=kwargs)

        ...

obj1 = MessageModel(version=1, bar="a")
From a static analysis point of view, MessageModel is now a subtype of the other two, which means your IDE will give you insights accordingly. Type obj1. and PyCharm, for example, will suggest both the attributes foo and bar. Typing-wise this makes it essentially equivalent to the Union alias approach (Option A).
But of course this is a lie, since the actual runtime MessageModel is not a subclass of those two. Still, if you know exactly what interface you need to make it "feel" like a subclass, you can make it work.
I would not recommend going that far with the obfuscation.
UPDATE: Pydantic v2
Very similar, but using RootModel instead:
# Pydantic v2
from collections.abc import Iterator
from typing import Any, Literal, Union

from pydantic import BaseModel, Field, RootModel


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class MessageModel(RootModel):
    root: MessageType = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(root=kwargs)

    def __iter__(self) -> Iterator[tuple[str, Any]]:  # type: ignore[override]
        yield from self.root

    def __str__(self) -> str:
        return str(self.root)

    def __repr__(self) -> str:
        return repr(self.root)

obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1)) # version=1 bar='a' <class '__main__.MessageModel'>
print(obj2, type(obj2)) # version=2 foo='b' <class '__main__.MessageModel'>
print(hasattr(obj1, "root")) # True
print(hasattr(obj1, "bar")) # False
print(hasattr(obj1, "version")) # False
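The "proxy" caveat from the cons above applies to the v2 version as well. Here is a short sketch of what that looks like in practice, using a stripped-down MessageModel (only the custom __init__, none of the other overrides): the outer object is never an instance of the underlying models, so you go through .root whenever you need the real thing.

```python
# Pydantic v2
from typing import Any, Literal, Union

from pydantic import BaseModel, Field, RootModel


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


class MessageModel(RootModel):
    root: Union[MessageModelV1, MessageModelV2] = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(root=kwargs)


obj1 = MessageModel(version=1, bar="a")

# The proxy is not one of the underlying models...
print(isinstance(obj1, MessageModelV1))  # False
# ...but its root is.
print(isinstance(obj1.root, MessageModelV1))  # True
print(obj1.root.bar)  # a
# Dumping a RootModel gives the dumped root value, without a wrapper key.
print(obj1.model_dump())  # {'version': 1, 'bar': 'a'}
```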
Option C: Union field + proxy constructor
Essentially based on Option B, but instead of going through all that trouble of defining a proxy interface for the nested model, you just write a custom constructor function that looks like a class, and make it instantiate the outer model, but return just its __root__ value:
# Pydantic v1
from typing import Any, Literal, Union

from pydantic import BaseModel, Field


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class OuterMessageModel(BaseModel):
    __root__: MessageType = Field(discriminator="version")


def MessageModel(**kwargs: Any) -> MessageType:
    return OuterMessageModel.parse_obj(kwargs).__root__


obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1)) # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2)) # version=2 foo='b' <class '__main__.MessageModelV2'>
Pros
- Fairly simple; not a lot of code.
- Direct instantiation is possible (albeit via a proxy-constructor).
- You get an actual instance of one of the original models.
- Static type checkers can infer the type as a union of the original models. (This gives IDEs the ability to provide some useful auto-suggestions.)
Cons
- The constructor function looks like a "fake" class. (Though this is a fairly common approach; see for example the Pydantic Field function.)
- There is still no way to statically determine which exact model will come out after parsing because it depends on the data.
UPDATE: Pydantic v2
Very similar, but using RootModel instead:
# Pydantic v2
from typing import Any, Literal, Union

from pydantic import BaseModel, Field, RootModel


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class OuterMessageModel(RootModel):
    root: MessageType = Field(discriminator="version")


def MessageModel(**kwargs: Any) -> MessageType:
    return OuterMessageModel.model_validate(kwargs).root


obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1)) # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2)) # version=2 foo='b' <class '__main__.MessageModelV2'>
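In v2 you can probably skip the outer model entirely and let the proxy constructor call a TypeAdapter, which effectively combines Options A and C. This is a sketch of that variation, not something the sections above depend on; _message_adapter is a name I picked for the example.

```python
# Pydantic v2
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]

# Built once at import time and reused by every call.
_message_adapter = TypeAdapter(Annotated[MessageType, Field(discriminator="version")])


def MessageModel(**kwargs: Any) -> MessageType:
    # Validate the keyword arguments directly against the discriminated union.
    return _message_adapter.validate_python(kwargs)


obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(type(obj1).__name__)  # MessageModelV1
print(type(obj2).__name__)  # MessageModelV2
```

The pros and cons are the same as listed for Option C; the only difference is that there is no OuterMessageModel class left over.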