I have the same need, and since my files were relatively large, I wanted to get the error message before uploading the file to the backend (or at least before uploading the whole file), as mentioned by rezan21.
This is the approach that made it work. Please note that it relies on multiple workarounds for some Starlette limitations, such as (1) this one for reading the request body's async generator and (2) this issue dealing with this exact need.
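For context on workaround (1): an ASGI app receives the body as a series of `http.request` messages from `receive()`. Once something has consumed them, a later call to `receive()` has no body left to return and, on a real server, waits, which is roughly where the stall comes from. The following is a minimal stdlib-only simulation of that, not Starlette code; `make_receive`, `read_body` and `replay` are hypothetical names, with `replay` playing the role of the `set_body` helper further down:

```python
import asyncio


def make_receive(chunks: list[bytes]):
    """Hypothetical stand-in for an ASGI receive() that streams the body in chunks."""
    messages = [
        {"type": "http.request", "body": c, "more_body": i < len(chunks) - 1}
        for i, c in enumerate(chunks)
    ]

    async def receive():
        # A real server would block (or eventually send http.disconnect) once
        # the body is consumed; this stand-in raises IndexError instead.
        return messages.pop(0)

    return receive


async def read_body(receive) -> bytes:
    # Conceptually what request.body() does: call receive() until more_body is False.
    body = b""
    while True:
        message = await receive()
        body += message.get("body", b"")
        if not message.get("more_body", False):
            return body


def replay(body: bytes):
    # The set_body() trick: a fresh receive() that hands the buffered body downstream.
    async def receive():
        return {"type": "http.request", "body": body, "more_body": False}

    return receive


async def main():
    receive = make_receive([b"--x\r\nContent-Type: text/csv\r\n\r\n", b"a,b\r\n--x--\r\n"])
    body = await read_body(receive)  # the middleware consumes the stream
    try:
        await receive()  # a second reader finds nothing left
        exhausted = False
    except IndexError:
        exhausted = True
    downstream = replay(body)  # the endpoint gets a working receive() again
    return body, exhausted, await downstream()


body, exhausted, message = asyncio.run(main())
```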
First, I'm reading the file directly from the Swagger UI "Choose File" input, so no additional headers are passed to indicate the file extension or MIME type that could be read by a frontend or an API consumer.
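In that situation the only type information for the file is the per-part header block that the browser writes into the `multipart/form-data` body itself. A minimal stdlib-only sketch of that structure, assuming a well-formed body; the helpers `boundary_from_header` and `part_headers` and the sample boundary are hypothetical, not part of FastAPI or Starlette:

```python
def boundary_from_header(content_type: str) -> bytes:
    """Pull the boundary parameter out of the request's Content-Type header."""
    for param in content_type.split(";")[1:]:
        key, _, value = param.strip().partition("=")
        if key.lower() == "boundary":
            return value.strip('"').encode("latin-1")
    raise ValueError("no boundary parameter in Content-Type")


def part_headers(body: bytes, boundary: bytes) -> list[dict[str, str]]:
    """Return the header block of every part in a multipart/form-data body."""
    parts = []
    for chunk in body.split(b"--" + boundary):
        chunk = chunk.strip(b"\r\n")
        # Skip the empty preamble and the closing "--" marker.
        if not chunk or chunk == b"--":
            continue
        # A blank line separates the part's headers from its content.
        raw_headers, _, _content = chunk.partition(b"\r\n\r\n")
        headers = {}
        for line in raw_headers.split(b"\r\n"):
            name, _, value = line.partition(b":")
            headers[name.decode("latin-1").strip().lower()] = value.decode("latin-1").strip()
        parts.append(headers)
    return parts


# Roughly what a browser sends for a single file field (boundary made up).
header = "multipart/form-data; boundary=----WebKitFormBoundaryABC123"
body = (
    b"------WebKitFormBoundaryABC123\r\n"
    b'Content-Disposition: form-data; name="file"; filename="data.csv"\r\n'
    b"Content-Type: text/csv\r\n"
    b"\r\n"
    b"a,b\r\n1,2\r\n"
    b"------WebKitFormBoundaryABC123--\r\n"
)
boundary = boundary_from_header(header)
parts = part_headers(body, boundary)
print(parts[0]["content-type"])  # text/csv
```

Each part keeps its own headers here, whereas a single regex over the whole body picks up whichever `Content-Type:` line comes first, which matters if a form has several file fields.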
Then, I wanted to set the allowed file types directly in the route definition, just like any other dependency. A dependency alone would not work here, though, because it is only called after the whole file has been uploaded.
So my current working solution for my CSV and Excel files is to use a custom `BaseHTTPMiddleware`, read the request body asynchronously, and get the "headers" from the file itself.
From what I deduced, the information about the file being uploaded (its part headers) is in the first chunk of the body's asynchronous generator. To prevent the program from stalling, the `get_body` function is implemented as per (1).
```python
import re

from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Match

from dependencies import ContentTypeChecker


def get_content_type_from_body(body: bytes) -> str:
    # The first "Content-Type:" line in a multipart body belongs to the first part's headers
    content_type_match = re.search(rb'Content-Type: ([^\r\n]+)', body)
    if content_type_match:
        return content_type_match.group(1).decode("utf-8")
    return ""


async def set_body(request: Request, body: bytes):
    # Replace the already-consumed receive channel so downstream handlers get the body again
    async def receive():
        return {"type": "http.request", "body": body}
    request._receive = receive


async def get_body(request: Request) -> bytes:
    # Note: request.body() reads the entire body into memory
    body = await request.body()
    await set_body(request, body)
    return body


class ValidateContentTypeMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        content_type = request.headers.get("Content-Type", "")
        file_content_type = ""
        if content_type.startswith("multipart/form-data"):
            bd = await get_body(request)
            file_content_type = get_content_type_from_body(bd)
        if file_content_type:
            for route in request.app.routes:
                # Only check the route this request is actually going to;
                # routes without a dependant (docs, mounts) are skipped
                match, _ = route.matches(request.scope)
                dependant = getattr(route, "dependant", None)
                if match != Match.FULL or dependant is None:
                    continue
                for dependency in dependant.dependencies:
                    checker = dependency.call
                    if not isinstance(checker, ContentTypeChecker):
                        continue
                    if not checker(content_type=file_content_type):
                        return JSONResponse(
                            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
                            content={'message': f'File of type {file_content_type} not in {checker.content_types}'},
                        )
        return await call_next(request)

# Register it with: app.add_middleware(ValidateContentTypeMiddleware)
```
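Note that `request.body()` collects the whole stream before returning, so the middleware above still receives the entire upload before it can answer. If you only want to look at the first chunk, one option is to swap `BaseHTTPMiddleware` for a plain ASGI middleware that peeks at the first `http.request` message and then replays it, followed by the rest of the stream. This is a sketch under assumptions, not the approach above: `PeekContentTypeMiddleware` and its `allowed` path-to-types mapping are hypothetical, it matches on the exact path instead of FastAPI routes, and it assumes the part headers fit in the first chunk:

```python
import asyncio
import json
import re

CONTENT_TYPE_RE = re.compile(rb"Content-Type: ([^\r\n]+)")


class PeekContentTypeMiddleware:
    """Hypothetical pure-ASGI middleware that inspects only the first body chunk."""

    def __init__(self, app, allowed: dict[str, list[str]]):
        self.app = app
        self.allowed = allowed  # hypothetical mapping: request path -> allowed part types

    async def __call__(self, scope, receive, send):
        headers = dict(scope.get("headers", []))
        allowed = self.allowed.get(scope.get("path", "")) if scope["type"] == "http" else None
        if allowed is None or not headers.get(b"content-type", b"").startswith(b"multipart/form-data"):
            await self.app(scope, receive, send)
            return

        first = await receive()  # only the first http.request message is read here
        match = CONTENT_TYPE_RE.search(first.get("body", b""))
        part_type = match.group(1).decode("latin-1") if match else ""
        if part_type and part_type not in allowed:
            payload = json.dumps({"message": f"File of type {part_type} not in {allowed}"}).encode()
            await send({"type": "http.response.start", "status": 415,
                        "headers": [(b"content-type", b"application/json")]})
            await send({"type": "http.response.body", "body": payload})
            return

        replayed = False

        async def replay_receive():
            # Hand back the peeked message once, then pass the rest of the stream through.
            nonlocal replayed
            if not replayed:
                replayed = True
                return first
            return await receive()

        await self.app(scope, replay_receive, send)


async def demo(path: str, chunks: list[bytes]) -> list[dict]:
    """Drive the middleware with a fake receive/send and an echoing stand-in endpoint."""
    messages = [{"type": "http.request", "body": c, "more_body": i < len(chunks) - 1}
                for i, c in enumerate(chunks)]
    sent = []

    async def receive():
        return messages.pop(0)

    async def send(message):
        sent.append(message)

    async def endpoint(scope, receive, send):
        body = b""
        while True:
            message = await receive()
            body += message["body"]
            if not message["more_body"]:
                break
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": body})

    scope = {"type": "http", "path": path,
             "headers": [(b"content-type", b"multipart/form-data; boundary=x")]}
    await PeekContentTypeMiddleware(endpoint, {"/upload": ["text/csv"]})(scope, receive, send)
    return sent


ok = asyncio.run(demo("/upload", [b"--x\r\nContent-Type: text/csv\r\n\r\n", b"a,b\r\n--x--\r\n"]))
rejected = asyncio.run(demo("/upload", [b"--x\r\nContent-Type: image/png\r\n\r\n", b"\x89PNG..."]))
```

Whether the client actually stops sending after an early 415 depends on the server and client, so treat this as a way to avoid buffering the file in the app rather than a guaranteed bandwidth saving.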
Then, for it to work, the content type checker is a simple class that is instantiated with the list of allowed content types and has a `__call__` method that receives the content type from the middleware:
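```python
from typing import List


class ContentTypeChecker:
    def __init__(self, content_types: List[str]) -> None:
        self.content_types = content_types

    def __call__(self, content_type: str = ''):
        # The middleware passes the real type; FastAPI's own call gets '' and always passes
        if content_type and content_type not in self.content_types:
            return False
        return True
```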
One caveat of this approach is that FastAPI will call the checker again when the content type is allowed and the middleware forwards the request. FastAPI treats `content_type` as an optional query parameter of the dependency, so its own call normally receives the default. That is why `content_type` defaults to `''` in `__call__`, and why the method returns True whenever FastAPI makes the check itself.
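If relying on the `''` default feels fragile, one alternative (a hypothetical variant, not the code above) is to give the middleware its own `check()` method and make `__call__` take no arguments, so FastAPI has nothing to resolve and no `content_type` query parameter appears in the OpenAPI docs. The sketch also drops parameters such as `; charset=utf-8` and lower-cases the type, since MIME types are case-insensitive. In the middleware, the call would become `checker.check(file_content_type)`:

```python
class ContentTypeChecker:
    """Hypothetical variant: FastAPI sees a no-argument __call__, the middleware calls check()."""

    def __init__(self, content_types: list[str]) -> None:
        # MIME types are case-insensitive, so store them lower-cased.
        self.content_types = [ct.lower() for ct in content_types]

    def check(self, content_type: str) -> bool:
        # Drop parameters such as "; charset=utf-8" before comparing.
        base = content_type.split(";", 1)[0].strip().lower()
        return base in self.content_types

    def __call__(self) -> None:
        # Nothing to validate here: the middleware already did it, and with
        # no parameters FastAPI adds nothing to the request schema.
        return None


checker = ContentTypeChecker(["text/csv", "application/vnd.ms-excel"])
print(checker.check("text/csv"))                 # True
print(checker.check("Text/CSV; charset=utf-8"))  # True
print(checker.check("image/png"))                # False
```

Listing `application/vnd.ms-excel` next to `text/csv`, as in the usage above, may be needed for CSV uploads, since some systems report `.csv` files with that type.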
Finally, this is the route definition:
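```python
from fastapi import APIRouter, Depends, File, UploadFile

from dependencies import ContentTypeChecker

router = APIRouter()


@router.post('/upload', dependencies=[Depends(ContentTypeChecker(['text/csv']))])
async def upload(file: UploadFile = File(...)):
    ...
```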
I'm not sure whether there is a better way to call the dependant during the validation process.