I am building a FastAPI application that has a lot of Pydantic models. The application works as expected, but the OpenAPI (Swagger UI) docs do not show the schemas for all of these models under the Schemas section.

Here are the contents of my Pydantic schemas.py:

import socket
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel, Field, validator
from typing_extensions import Literal

ResponseData = Union[List[Any], Dict[str, Any], BaseModel]


# Not visible in Swagger UI
class PageIn(BaseModel):
    page_size: int = Field(default=100, gt=0)
    num_pages: int = Field(default=1, gt=0, exclude=True)
    start_page: int = Field(default=1, gt=0, exclude=True)

# visible under schemas on Swagger UI
class PageOut(PageIn):
    total_records: int = 0
    total_pages: int = 0
    current_page: int = 1

    class Config:  # pragma: no cover
        @staticmethod
        def schema_extra(schema, model) -> None:
            schema.get("properties").pop("num_pages")
            schema.get("properties").pop("start_page")


# Not visible in Swagger UI
class BaseResponse(BaseModel):
    host_: str = Field(default_factory=socket.gethostname)
    message: Optional[str]


# Not visible in Swagger UI
class APIResponse(BaseResponse):
    count: int = 0
    location: Optional[str]
    page: Optional[PageOut]
    data: ResponseData


# Not visible in Swagger UI
class ErrorResponse(BaseResponse):
    error: str


# visible under schemas on Swagger UI
class BaseFaultMap(BaseModel):
    detection_system: Optional[str] = Field("", example="obhc")
    fault_type: Optional[str] = Field("", example="disk")
    team: Optional[str] = Field("", example="dctechs")
    description: Optional[str] = Field(
        "",
        example="Hardware raid controller disk failure found. "
        "Operation can continue normally, "
        "but risk of data loss exists",
    )



# Not visible in Swagger UI
class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
    hostname: Optional[str]
    status: Literal["open", "closed", "all"] = Field("open")
    created_by: Optional[str]
    environment: Optional[str]
    team: Optional[str]
    fault_type: Optional[str]
    detection_system: Optional[str]
    inops_filters: Optional[str] = Field(None)
    date_filter: Optional[str] = Field("",)
    sort_by: Optional[str] = Field("created",)
    sort_order: Literal["asc", "desc"] = Field("desc")

All of these models are used in FastAPI path operations to validate the request body. FaultQueryParams is a custom model that I use to validate the request query params, as shown below:

query_args: FaultQueryParams = Depends()

The rest of the models are being used in conjunction with Body field. I am not able to figure out why only some of the models are not visible in the Schemas section while others are.

I also noticed that the description and examples defined on the FaultQueryParams fields do not show up for the endpoint in the docs.

Edit 1:

I looked into this further and realized that all of the models missing from Swagger UI are the ones not used directly in path operations, i.e., they are not used as response_model or Body types; they are helper models that are only used indirectly. So it seems FastAPI does not generate schemas for these models.

One exception to the above is query_args: FaultQueryParams = Depends(), which is used directly in a path operation to map the endpoint's query params onto a custom model. This is a problem because Swagger does not pick up metadata such as title, description, and example from the fields of this model and does not show it in the UI, which matters to the users of this endpoint.

Is there a way to trick FastAPI into generating a schema for the custom model FaultQueryParams, just as it does for Body, Query, etc.?

Chris
Rohit

2 Answers

FastAPI will generate schemas for models that are used either as a Request Body or Response Model. When declaring query_args: FaultQueryParams = Depends() (using Depends), your endpoint would not expect a request body, but rather query parameters; hence, FaultQueryParams would not be included in the schemas of the OpenAPI docs.

To add additional schemas, you could extend/modify the OpenAPI schema. An example is given below (make sure the code that modifies the schema comes after all routes have been defined, i.e., at the end of your code).

class FaultQueryParams(BaseModel):
    f_id: Optional[int] = Field(None, description="id for the host", example=12345, title="Fault ID")
    hostname: Optional[str]
    status: Literal["open", "closed", "all"] = Field("open")
    ...
    
@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args

def get_extra_schemas():
    return {
        "FaultQueryParams": {
            "title": "FaultQueryParams",
            "type": "object",
            "properties": {
                "f_id": {
                    "title": "Fault ID",
                    "type": "integer",
                    "description": "id for the host",
                    "example": 12345,
                },
                "hostname": {"title": "Hostname", "type": "string"},
                "status": {
                    "title": "Status",
                    "enum": ["open", "closed", "all"],
                    "type": "string",
                    "default": "open",
                },
                # ... (remaining properties elided)
            },
        }
    }

from fastapi.openapi.utils import get_openapi

def custom_openapi():
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title="FastAPI",
        version="1.0.0",
        description="This is a custom OpenAPI schema",
        routes=app.routes,
    )
    new_schemas = openapi_schema["components"]["schemas"]
    new_schemas.update(get_extra_schemas())
    openapi_schema["components"]["schemas"] = new_schemas
    
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi
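
The update step in custom_openapi() indexes openapi_schema["components"]["schemas"] directly. As far as I can tell, that key can be absent when no route contributes a schema, so a slightly more defensive merge may be worth having. A minimal sketch, where merge_extra_schemas is a hypothetical helper (not part of FastAPI) and a plain dict stands in for the result of get_openapi():

```python
from typing import Any, Dict


def merge_extra_schemas(
    openapi_schema: Dict[str, Any], extra_schemas: Dict[str, Any]
) -> Dict[str, Any]:
    """Add extra_schemas under components/schemas, creating the keys if absent."""
    components = openapi_schema.setdefault("components", {})
    schemas = components.setdefault("schemas", {})
    for name, schema in extra_schemas.items():
        # Keep whatever FastAPI generated; only fill in names it did not emit.
        schemas.setdefault(name, schema)
    return openapi_schema


# Stand-in for the dict returned by get_openapi(); it has no "components" key.
generated = {"openapi": "3.0.2", "paths": {}}
extra = {"FaultQueryParams": {"title": "FaultQueryParams", "type": "object"}}
merged = merge_extra_schemas(generated, extra)
```

Inside custom_openapi(), the three lines that build new_schemas could then be replaced by a single call such as merge_extra_schemas(openapi_schema, get_extra_schemas()).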

Some Helpful Notes

Note 1

Instead of typing the schema for the extra models by hand, you can have FastAPI generate it: temporarily add an endpoint that uses the model as a request body or response model (and remove it once you have the schema), for example:

@app.post('/predict') 
def predict(query_args: FaultQueryParams):
    return query_args

Then, you can get the generated JSON schema at http://127.0.0.1:8000/openapi.json, as described in the documentation. From there, you can either copy and paste the schema of the model into your code and use it directly (as shown in the get_extra_schemas() function above) or save it to a file and load the JSON data from the file, as demonstrated below:

import json
...

new_schemas = openapi_schema["components"]["schemas"]

with open('extra_schemas.json') as f:    
    extra_schemas = json.load(f)
    
new_schemas.update(extra_schemas)   
openapi_schema["components"]["schemas"] = new_schemas

...
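
A third option, which avoids both the temporary endpoint and the JSON file, might be to ask Pydantic for the schema directly. This is a minimal sketch under some assumptions: it relies on the schema() class method in Pydantic v1 or model_json_schema() in Pydantic v2, the model_schema helper is hypothetical, and the model is a trimmed copy of the one in the question.

```python
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, Field


class FaultQueryParams(BaseModel):
    # Trimmed copy of the model from the question, for illustration only.
    f_id: Optional[int] = Field(None, title="Fault ID", description="id for the host")
    hostname: Optional[str] = None


def model_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """Return the JSON schema of a model on either Pydantic major version."""
    if hasattr(model, "model_json_schema"):  # Pydantic v2
        return model.model_json_schema()
    return model.schema()  # Pydantic v1


def get_extra_schemas() -> Dict[str, Dict[str, Any]]:
    # Keyed by model name, the same shape as components/schemas in openapi.json.
    return {model.__name__: model_schema(model) for model in (FaultQueryParams,)}
```

The returned dict can be passed to new_schemas.update(...) in the same way as the hand-written one. The sketch sticks to flat models: nested models produce references into a local definitions section, which would need extra handling before they resolve under components/schemas.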

Note 2

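To declare metadata such as description and example for a query parameter, you should define the parameter with Query instead of Field. Since that could not be done with Pydantic models when this answer was written (FastAPI 0.115.0 and later do accept a Pydantic model as a query-parameter model), you could declare a custom dependency class, as described in the FastAPI documentation on classes as dependencies and as shown below: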

from fastapi import FastAPI, Query, Depends
from typing import Optional

class FaultQueryParams:
    def __init__(
        self,
        f_id: Optional[int] = Query(None, description="id for the host", example=12345)

    ):
        self.f_id = f_id

app = FastAPI()

@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args

The above can be re-written using the @dataclass decorator, as shown below:

from fastapi import FastAPI, Query, Depends
from typing import Optional
from dataclasses import dataclass

@dataclass
class FaultQueryParams:
    f_id: Optional[int] = Query(None, description="id for the host", example=12345)

app = FastAPI()

@app.post('/predict')
def predict(query_args: FaultQueryParams = Depends()):
    return query_args
Chris
Thanks to @Chris for the pointers, which ultimately led me to use dataclasses to define the query params in bulk. It worked fine.

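from dataclasses import dataclass
from typing import Optional

from fastapi import Query
from typing_extensions import Literal

@dataclass
class FaultQueryParams1:
    f_id: Optional[int] = Query(None, description="id for the host", example=55555)
    hostname: Optional[str] = Query(None, example="test-host1.domain.com")
    # Optional[...] because the default is None, which is not one of the literals.
    status: Optional[Literal["open", "closed", "all"]] = Query(
        None, description="fetch open/closed or all records", example="all"
    )
    created_by: Optional[str] = Query(
        None,
        description="fetch records created by particular user",
        example="user-id",
    )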
Rohit