
I have a Pydantic model class MessageModel with a version number declared as a Literal. Our requirements have changed, and we now need another MessageModel with a higher version number because the attributes of the MessageModel have changed. I want a class to which I can pass the version number as a constructor argument. Does anyone have an idea?

Here are the models:

from typing import Literal
from pydantic import BaseModel


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str
        
class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str

What I want is a class which initializes the right MessageModel version:

model = MessageModel(version=2, ...)
Asked by Phil997 (edited by Daniil Fajnberg)
  • Does this answer your question? [Using different Pydantic models depending on the value of fields](https://stackoverflow.com/questions/71539448/using-different-pydantic-models-depending-on-the-value-of-fields) – Daniil Fajnberg Jun 23 '23 at 08:06
  • Not really, in this example fastapi is used, which tries both and one succeeds. But what I need is a class that, when I initialise it, automatically selects the right model. – Phil997 Jun 23 '23 at 08:10
  • Fair enough. Though the [second answer](https://stackoverflow.com/a/76449131) provides a solution independent from FastAPI. But the question itself was specifically about FastAPI routes, so yours is not exactly a duplicate. I provided a more in-depth answer with two different approaches just for Pydantic below. – Daniil Fajnberg Jun 23 '23 at 09:09

1 Answer


Discriminated union

You can define a discriminated union of two or more Pydantic models. This will allow you to instantiate the correct model based on the value of the discriminator field provided in the data.

There are a few slightly different approaches you can take.


Option A: Annotated type alias

Instead of defining a new model that "combines" your existing ones, you define a type alias for the union of those models and use typing.Annotated to add the discriminator information.

# Pydantic v1

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, parse_obj_as


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageModel = Annotated[
    Union[MessageModelV1, MessageModelV2],
    Field(discriminator="version"),
]


data1 = {"version": 1, "bar": "a"}
data2 = {"version": 2, "foo": "b"}
obj1 = parse_obj_as(MessageModel, data1)
obj2 = parse_obj_as(MessageModel, data2)
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModelV2'>

Pros

  • Minimal additional code, just a type definition.
  • You get an actual instance of one of the original models.
  • Static type checkers can infer the type as a union of the original models. (This gives IDEs the ability to provide some useful auto-suggestions.)

Cons

  • You cannot directly instantiate MessageModel because it is a typing construct and not a model class.
  • To construct an instance you must use the additional parse_obj_as function.
  • There is no way to statically determine which exact model will come out of parsing, because it depends on the data.
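In practice this means the caller narrows the parsed result before touching version-specific fields. A minimal sketch, assuming Pydantic v2 (TypeAdapter in place of parse_obj_as); the helper describe is hypothetical and only illustrates the isinstance check, which static type checkers also use to narrow the union:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]

# Pydantic v2 counterpart of parse_obj_as(MessageModel, ...)
adapter = TypeAdapter(Annotated[MessageType, Field(discriminator="version")])


def describe(msg: MessageType) -> str:
    # Hypothetical helper: isinstance() narrows the union for type checkers
    # too, so msg.bar / msg.foo are known attributes inside each branch.
    if isinstance(msg, MessageModelV1):
        return f"v1 message with bar={msg.bar}"
    return f"v2 message with foo={msg.foo}"


for data in ({"version": 1, "bar": "a"}, {"version": 2, "foo": "b"}):
    print(describe(adapter.validate_python(data)))
```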

UPDATE: Pydantic v2

Very similar, but using TypeAdapter instead of parse_obj_as:

# Pydantic v2

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageModel = TypeAdapter(Annotated[
    Union[MessageModelV1, MessageModelV2],
    Field(discriminator="version"),
])


data1 = {"version": 1, "bar": "a"}
data2 = {"version": 2, "foo": "b"}
obj1 = MessageModel.validate_python(data1)
obj2 = MessageModel.validate_python(data2)
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModelV2'>

Unfortunately, as of now, Mypy seems unable to infer the type of the resulting model instances correctly, while Pyright does.
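The same adapter can also parse raw JSON via validate_json, and because the union is discriminated, a version outside the declared Literal values is rejected with a ValidationError rather than being fitted into one of the models. A short sketch, assuming Pydantic v2 and the same two models:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageModel = TypeAdapter(Annotated[
    Union[MessageModelV1, MessageModelV2],
    Field(discriminator="version"),
])

# validate_json() parses a raw JSON string (or bytes) and dispatches on "version".
obj = MessageModel.validate_json('{"version": 2, "foo": "b"}')
print(type(obj).__name__, obj.foo)

# An unknown version is rejected instead of being coerced into a model.
try:
    MessageModel.validate_python({"version": 3, "foo": "b"})
except ValidationError as exc:
    print(exc)
```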


Option B: Custom root type

You define a new model and set its __root__ type to the discriminated union of the original models.

Then you can customize it to whatever degree you see fit, in order to make instances of it "feel" like any of the original underlying models.

In the following example I overrode the __init__ method so that you can initialize it as described in your question. I also made it iterable via the underlying root model and added custom __str__ and __repr__ methods so that instances are displayed as though they were of the underlying root type.

# Pydantic v1

from collections.abc import Iterator
from typing import Any, Literal, Union
from pydantic import BaseModel, Field


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class MessageModel(BaseModel):
    __root__: MessageType = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(__root__=kwargs)

    def __iter__(self) -> Iterator[tuple[str, Any]]:  # type: ignore[override]
        yield from self.__root__

    def __str__(self) -> str:
        return str(self.__root__)

    def __repr__(self) -> str:
        return repr(self.__root__)

With the same demo data, you now get slightly different results:

# Pydantic v1

obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModel'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModel'>

print(hasattr(obj1, "__root__"))  # True
print(hasattr(obj1, "bar"))       # False
print(hasattr(obj1, "version"))   # False

You'll notice that the type of the object is of course the new MessageModel now and not one of the underlying original models.

Also, without doing more customization, instances will only have the __root__ field pointing to an instance of the underlying model and not have that model's actual fields. If you want to pass on attribute access to the root model, you would have to override __getattr__/__setattr__:

# Pydantic v1

...

class MessageModel(BaseModel):
    __root__: MessageType = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(__root__=kwargs)

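    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails. Never forward __root__
        # itself, otherwise a missing __root__ would recurse forever.
        if name == "__root__":
            raise AttributeError(name)
        return getattr(self.__root__, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "__root__":
            # Replace the wrapped model via BaseModel's own __setattr__
            # (assigning self.__root__ here would recurse).
            super().__setattr__(name, value)
        else:
            # Forward every other assignment to the wrapped model.
            setattr(self.__root__, name, value)

    ...

# Pydantic v1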

obj1 = MessageModel(version=1, bar="a")

print(obj1.version)  # 1
print(obj1.bar)      # a

Pros

  • Direct instantiation is possible because you are dealing with an actual model class.
  • Lots of options for customization and additional validation.

Cons

  • Without any customization you will always deal with a "proxy" model and its __root__ field will point to the "actual" model.
  • Making the outer model feel/behave like one of the underlying models takes a lot of additional boilerplate. You could probably factor all/most of it out into a separate mix-in class, but the code needs to be written nonetheless.
  • You might "forget" that you are dealing with the proxy model, not with the root model, and then run into unexpected results. (For example, try to call obj1.dict() with the last example and look at the output.)
  • The type inferred by static analysis will always be MessageModel.

PS

You could hack around the fact that the outer MessageModel is not actually a subtype of the underlying models by using typing.TYPE_CHECKING. Example:

# Pydantic v1

from typing import Any, Literal, Union, TYPE_CHECKING
from pydantic import BaseModel, Field


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]

if TYPE_CHECKING:
    class MessageModel(MessageModelV1, MessageModelV2):
        pass
else:
    class MessageModel(BaseModel):
        __root__: MessageType = Field(discriminator="version")

        def __init__(self, **kwargs: Any) -> None:
            super().__init__(__root__=kwargs)

        ...

obj1 = MessageModel(version=1, bar="a")

From a static analysis point of view, MessageModel is now a subtype of the other two, which means your IDE will give you insights accordingly. Type obj1. and PyCharm, for example, will suggest both the foo and bar attributes. Typing-wise, this makes it essentially equivalent to the Union alias approach (Option A).

But of course this is a lie since the actual runtime MessageModel is not a subclass of those two. But if you know exactly what interface you need to make it "feel" like a subclass, you can make it work.

I would not recommend going that far with the obfuscation.

UPDATE: Pydantic v2

Very similar, but using RootModel instead:

# Pydantic v2

from collections.abc import Iterator
from typing import Any, Literal, Union
from pydantic import BaseModel, Field, RootModel


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class MessageModel(RootModel):
    root: MessageType = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(root=kwargs)

    def __iter__(self) -> Iterator[tuple[str, Any]]:  # type: ignore[override]
        yield from self.root

    def __str__(self) -> str:
        return str(self.root)

    def __repr__(self) -> str:
        return repr(self.root)


obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModel'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModel'>

print(hasattr(obj1, "root"))     # True
print(hasattr(obj1, "bar"))      # False
print(hasattr(obj1, "version"))  # False
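The attribute forwarding shown for v1 above could look roughly like this on a v2 RootModel. This is a sketch under the assumption that only attribute reads need forwarding; assignments such as obj1.bar = "x" are not forwarded here. It defers private/dunder names and root itself to pydantic's own __getattr__, so nothing is forwarded before root exists:

```python
from typing import Any, Literal, Union

from pydantic import BaseModel, Field, RootModel


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class MessageModel(RootModel):
    root: MessageType = Field(discriminator="version")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(root=kwargs)

    def __getattr__(self, name: str) -> Any:
        # Python only calls __getattr__ when normal lookup fails.
        # Private/dunder names and "root" go to pydantic's own handler,
        # which avoids infinite recursion while the model initialises.
        if name.startswith("_") or name == "root":
            return super().__getattr__(name)
        # Everything else (version, bar, foo, ...) is read from the wrapped model.
        return getattr(self.root, name)


obj1 = MessageModel(version=1, bar="a")
print(obj1.version, obj1.bar)
print(obj1.model_dump())
```

Note that in v2, model_dump() on a RootModel returns the dump of the root value itself, so the proxy no longer shows up in the output.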

Option C: Union field + proxy constructor

Essentially based on Option B, but instead of going to the trouble of defining a proxy interface for the nested model, you write a custom constructor function that looks like a class: it instantiates the outer model but returns just its __root__ value:

# Pydantic v1

from typing import Any, Literal, Union
from pydantic import BaseModel, Field


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class OuterMessageModel(BaseModel):
    __root__: MessageType = Field(discriminator="version")


def MessageModel(**kwargs: Any) -> MessageType:
    return OuterMessageModel.parse_obj(kwargs).__root__


obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModelV2'>

Pros

  • Fairly simple; not a lot of code.
  • Direct instantiation is possible (albeit via a proxy-constructor).
  • You get an actual instance of one of the original models.
  • Static type checkers can infer the type as a union of the original models. (This gives IDEs the ability to provide some useful auto-suggestions.)

Cons

  • The constructor function looks like a "fake" class. (Though this is a fairly common approach; see for example the Pydantic Field function.)
  • There is still no way to statically determine which exact model will come out of parsing, because it depends on the data.

UPDATE: Pydantic v2

Very similar, but using RootModel instead:

# Pydantic v2

from typing import Any, Literal, Union
from pydantic import BaseModel, Field, RootModel


class MessageModelV1(BaseModel):
    version: Literal[1]
    bar: str


class MessageModelV2(BaseModel):
    version: Literal[2]
    foo: str


MessageType = Union[MessageModelV1, MessageModelV2]


class OuterMessageModel(RootModel):
    root: MessageType = Field(discriminator="version")


def MessageModel(**kwargs: Any) -> MessageType:
    return OuterMessageModel.model_validate(kwargs).root


obj1 = MessageModel(version=1, bar="a")
obj2 = MessageModel(version=2, foo="b")
print(obj1, type(obj1))  # version=1 bar='a' <class '__main__.MessageModelV1'>
print(obj2, type(obj2))  # version=2 foo='b' <class '__main__.MessageModelV2'>
Answered by Daniil Fajnberg
  • First of all, thank you for your detailed answer, I wasn't expecting something so detailed. Your second suggestion with the root type looks like something I am looking for. But as you mentioned, the programmer has to use the __root__ as a proxy. It would be nicer if the MessageModel had the attributes foo or bar, with annotation for the interpreter. I had already found the first possibility parse_obj_as, but I don't find it so pretty. – Phil997 Jun 23 '23 at 11:52
  • @Phil997 I added another option/variation to my answer. Maybe you'll prefer that one. I don't think there is a "perfect" way to accomplish this without any drawbacks. – Daniil Fajnberg Jun 23 '23 at 12:27