I had something similar. Here is how I did it (not saying it's the best or even a good solution, but it's working so far):
The route endpoint:
# The client makes a POST request and gets the saved model back immediately, while a background task is started to process the image.
@app.post("/analyse", response_model=schemas.ImageAnalysis, tags=["Image Analysis"])
async def create_image_analysis(
    img: schemas.ImageAnalysisCreate,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
):
    """Store a new image analysis and schedule its processing.

    The record is created first via ``crud.create_analysis``; only then is
    ``analyse_image`` queued as a background task, so the stored record is
    returned to the client without waiting for the analysis itself.

    Args:
        img: Payload describing the analysis to create.
        background_tasks: FastAPI task queue for this request.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        Whatever ``crud.create_analysis`` returned for the new record.
    """
    analysis_record = crud.create_analysis(db=db, img=img)

    # NOTE(review): the request-scoped session is handed to the background
    # task. Depending on how get_db is written and on the FastAPI version,
    # the session may already be closed when the task runs -- confirm.
    background_tasks.add_task(analyse_image, db=db, img=img)

    # The model includes a ws_token (some random string) that the client can
    # use to connect to the websocket endpoint right away.
    return analysis_record
The websocket endpoint:
@app.websocket("/ws/{ws_token}")
async def websocket_endpoint(websocket: WebSocket, ws_token: str):
    """Keep a client websocket registered under ``ws_token`` until it closes.

    The socket is added to the shared ``socket_connections`` registry and
    then held open by repeatedly waiting for incoming text. It is removed
    from the registry however the loop ends, so no stale socket is left
    behind for later broadcasts.

    Args:
        websocket: The incoming websocket connection.
        ws_token: Token from the URL path used as the registry key.
    """
    # Add the websocket to the connections dict (keyed by ws_token).
    await socket_connections.connect(websocket, ws_token=ws_token)
    try:
        while True:
            print(socket_connections)
            # The received text is ignored; waiting here is what raises
            # WebSocketDisconnect once the client goes away.
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client disconnect: nothing to do beyond the cleanup below.
        pass
    finally:
        # Runs for WebSocketDisconnect and for any other exit (cancellation,
        # unexpected errors), which would otherwise leave the socket registered.
        socket_connections.disconnect(websocket, ws_token=ws_token)
The analyse_image function:
#notice - the function is not async, as it does not work with background tasks otherwise!!
def analyse_image(db: Session, img: ImageAnalysis):
    """Process every round of ``img`` and report progress per round.

    After each round a status message is sent through
    ``socket_connections`` to the sockets registered under
    ``img.ws_token``. The function is deliberately synchronous.

    Args:
        db: Database session (not used by the code visible here).
        img: Analysis object providing ``rounds`` and ``ws_token``.
    """
    print('analyse_image started')
    for step_number, _round in enumerate(img.rounds, start=1):
        # some heavy workload etc

        # Send a progress update to the user.
        progress = {
            "status": EstimationStatus.RUNNING,
            "current_step": step_number,
            "total_steps": len(img.rounds),
        }
        socket_connections.send_message(progress, ws_token=img.ws_token)
    print("analysis finished")
The connection Manager:
import asyncio
from typing import Dict, List
from fastapi import WebSocket
#notice: active_connections is changed to a dict (key= ws_token), so we know which user listens to which model
class ConnectionManager:
    """Registry of open websockets grouped by ``ws_token``.

    Several sockets may be registered under the same token; a message sent
    for a token is delivered to every socket currently registered for it.
    """

    def __init__(self):
        # ws_token -> sockets currently registered under that token
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, ws_token: str):
        """Accept ``websocket`` and register it under ``ws_token``."""
        await websocket.accept()
        self.active_connections.setdefault(ws_token, []).append(websocket)

    def disconnect(self, websocket: WebSocket, ws_token: str):
        """Remove ``websocket`` from the sockets registered under ``ws_token``.

        Does nothing if the token is unknown or the socket is not registered,
        so it is safe to call more than once. A token whose last socket is
        removed is dropped from the registry.
        """
        sockets = self.active_connections.get(ws_token)
        if sockets is None:
            return
        if websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            self.active_connections.pop(ws_token, None)

    def send_message(self, data: dict, ws_token: str):
        """Send ``data`` as JSON to every socket registered under ``ws_token``.

        This method is synchronous so it can be called from non-async code.
        It is a no-op when nothing is registered for the token.

        Args:
            data: JSON-serialisable payload passed to each socket's
                ``send_json``.
            ws_token: Registry key selecting the recipients.

        Raises:
            RuntimeError: If called from a thread that is already running an
                event loop (raised by ``asyncio.run``).
        """
        sockets = self.active_connections.get(ws_token)
        if not sockets:
            return

        # Iterate over a snapshot: disconnect() may mutate the registered
        # list while the sends are in progress.
        recipients = list(sockets)

        async def _broadcast():
            for ws in recipients:
                await ws.send_json(data)

        # The socket send is async. asyncio.run creates a temporary event
        # loop for the sends and always closes it afterwards, so no loop is
        # leaked per call.
        # NOTE(review): the sockets were accepted on another event loop;
        # sending from a separate loop relies on the server tolerating that.
        # Confirm, or schedule the send on the server loop instead.
        asyncio.run(_broadcast())
# Module-level ConnectionManager instance; the websocket endpoint and
# analyse_image both refer to it by this name.
socket_connections = ConnectionManager()