
I have the following code:

from fastapi import FastAPI, WebSocket, BackgroundTasks
import uvicorn
import time

app = FastAPI()


def run_model():
    ...
    ## code of the model
    answer = [1, 2, 3]
    ...
    results = {"message": "the model has been executed successfully!!", "results": answer}
    return results


@app.post("/execute-model")
async def ping(background_tasks: BackgroundTasks):
    background_tasks.add_task(run_model)
    return {"message": "the model is executing"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        ## Here I want the results of run_model
        await websocket.send_text("1")  # send_text expects a str, not an int

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8001)

I need to send a POST request (fetch) to /execute-model. This endpoint runs the run_model function as a background task. When run_model() finishes, I need to return its answer to the frontend. I thought of using websockets, but I don't know how to do it. Please help.

Sebastian Jose
  • Your use of BackgroundTasks makes no sense here. Maybe what you want to do is run your synchronous function in a different thread? For this you can use, e.g., `await asyncio.get_event_loop().run_in_executor(None, run_model)`. See https://stackoverflow.com/questions/43241221/how-can-i-wrap-a-synchronous-function-in-an-async-coroutine or https://stackoverflow.com/questions/54685210/calling-sync-functions-from-async-function – ypnos Sep 13 '22 at 18:01
  • This is not the real function, it's an example. The real run_model() will be a long-running model. – Sebastian Jose Sep 13 '22 at 18:03
  • If you don't want to return as a reply to the endpoint, you can simply use asyncio's create_task() and get a future you can store somewhere else and send a websocket message when it is finished. – ypnos Sep 13 '22 at 18:08
  • The problem is that the response from /execute-model cannot be made to wait 2 minutes. I understand; I don't know how to do that yet, but I'll investigate. Thank you – Sebastian Jose Sep 13 '22 at 18:19
  • I don't have a cut-out solution for you right now but the key thing you have to remember is to use the `run_in_executor` so that the long processing task is done in another thread. Also the task needs to be suited for Python threading (the work should not be done in Python code but e.g. underlying libraries). – ypnos Sep 13 '22 at 19:15
  • Run `execute-model` from a websocket connection, reply as you start the task, await the task, then post on the websocket again? Mapping the post to the correct websocket connection otherwise will be hard. – MatsLindh Sep 13 '22 at 20:16
  • I'm facing the same problem. As far as I can tell, the issue is that websocket.sendXY is async, but background tasks, or fastapi.concurrency -> run_in_threadpool, must be sync. At least that's where I got stuck so far... I'll post a reply if I figure it out, lol – Chris Sep 30 '22 at 17:31
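A minimal sketch of what the comments above suggest, using only the standard library: the blocking function runs in a thread via `run_in_executor`, wrapped in a task started with `asyncio.create_task()`, and the result is pushed to the listener once it is ready. It assumes the listener can be modelled as an `asyncio.Queue`; in the FastAPI app that would be the client's websocket connection, and `run_and_notify` is a hypothetical name.

```python
import asyncio
import time


def run_model() -> dict:
    """Stand-in for the long-running, blocking model."""
    time.sleep(0.1)  # simulate heavy work
    return {"message": "the model has been executed successfully!!", "results": [1, 2, 3]}


async def run_and_notify(outbox: asyncio.Queue) -> None:
    loop = asyncio.get_running_loop()
    # Run the blocking function in the default thread pool so the event loop stays free.
    result = await loop.run_in_executor(None, run_model)
    # Back on the event loop: safe to await async sends (e.g. websocket.send_json).
    await outbox.put(result)


async def main() -> dict:
    outbox: asyncio.Queue = asyncio.Queue()  # hypothetical stand-in for a websocket
    # Like the POST handler: start the job and return immediately.
    task = asyncio.create_task(run_and_notify(outbox))
    print("the model is executing")
    # Like the websocket handler: wait for the result and forward it to the client.
    result = await outbox.get()
    await task
    return result


if __name__ == "__main__":
    print(asyncio.run(main()))
```

In the real app, `await outbox.put(result)` would roughly become `await websocket.send_json(result)` on the matching connection. Keep a reference to the task (here `task`): the event loop only holds weak references to tasks, so an unreferenced one can be garbage-collected before it finishes.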

1 Answer


I had something similar. Here is how I did it (not saying it's the best or even a good solution, but it's working so far):

The route endpoint:

# The client makes a POST request and gets the saved model immediately,
# while a background task is started to process the image
@app.post("/analyse", response_model=schemas.ImageAnalysis, tags=["Image Analysis"])
async def create_image_analysis(
    img: schemas.ImageAnalysisCreate,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
):
    saved = crud.create_analysis(db=db, img=img)
    # caution: depending on the FastAPI version (the behaviour changed in 0.106.0),
    # a `yield` dependency such as get_db may already be closed when the task runs
    background_tasks.add_task(analyse_image, db=db, img=img)

    # the saved model includes a ws_token (some random string) the client can connect to right away
    return saved
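The answer doesn't show where `ws_token` comes from. One option, assumed here, is the standard library's `secrets.token_urlsafe()`. `JobRegistry` is a hypothetical helper that records issued tokens, so the websocket endpoint could check a token before accepting the connection:

```python
import secrets
from typing import Dict, Optional


class JobRegistry:
    """Hypothetical registry mapping ws_token -> job status."""

    def __init__(self) -> None:
        self._jobs: Dict[str, str] = {}

    def create(self) -> str:
        # Unguessable, URL-safe token the client can use in /ws/{ws_token}.
        token = secrets.token_urlsafe(16)
        self._jobs[token] = "PENDING"
        return token

    def is_known(self, token: str) -> bool:
        return token in self._jobs

    def set_status(self, token: str, status: str) -> None:
        self._jobs[token] = status

    def status(self, token: str) -> Optional[str]:
        return self._jobs.get(token)


if __name__ == "__main__":
    registry = JobRegistry()
    token = registry.create()
    print(token, registry.status(token))
```

A random token rather than, say, the database id keeps other clients from guessing which websocket channel to listen on.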

The websocket endpoint:

@app.websocket("/ws/{ws_token}")
async def websocket_endpoint(websocket: WebSocket, ws_token: str):
    # add the websocket to the connections dict (keyed by ws_token)
    await socket_connections.connect(websocket, ws_token=ws_token)
    try:
        while True:
            print(socket_connections)
            # incoming messages are ignored; awaiting them is what raises
            # WebSocketDisconnect when the client goes away
            await websocket.receive_text()

    except WebSocketDisconnect:
        socket_connections.disconnect(websocket, ws_token=ws_token)

The analyse_image function:

# notice: the function is deliberately not async. BackgroundTasks runs plain (sync)
# functions in a thread pool, so the heavy work does not block the event loop;
# an async def would run on the event loop itself.
def analyse_image(db: Session, img: ImageAnalysis):

    print("analyse_image started")
    for index, current_round in enumerate(img.rounds):

        # some heavy workload etc.

        # send a progress update to the user
        socket_connections.send_message({
            "status": EstimationStatus.RUNNING,
            "current_step": index + 1,
            "total_steps": len(img.rounds),
        }, ws_token=img.ws_token)

    print("analysis finished")
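The tricky part, raised in the comments, is that `WebSocket.send_json` is a coroutine bound to the server's event loop, while `analyse_image` runs in a worker thread. A standard-library sketch of the bridge, assuming the loop object is available to the worker: `asyncio.run_coroutine_threadsafe` schedules the coroutine on that loop and returns a `concurrent.futures.Future` the thread can wait on. `send_json` and `analyse` are hypothetical stand-ins, and `run_in_executor` plays the role of Starlette's thread pool.

```python
import asyncio
import time
from typing import List


async def send_json(sent: List[dict], data: dict) -> None:
    """Hypothetical stand-in for WebSocket.send_json; must run on the event loop."""
    sent.append(data)


def analyse(loop: asyncio.AbstractEventLoop, sent: List[dict], total_steps: int) -> None:
    """Sync worker playing the role of analyse_image; runs in a thread, not on the loop."""
    for index in range(total_steps):
        time.sleep(0.01)  # some heavy workload
        # plain string instead of EstimationStatus.RUNNING
        update = {"status": "RUNNING", "current_step": index + 1, "total_steps": total_steps}
        # Schedule the coroutine on the event loop and block this worker thread
        # until it has run (re-raises any exception from the coroutine).
        future = asyncio.run_coroutine_threadsafe(send_json(sent, update), loop)
        future.result()


async def main() -> List[dict]:
    sent: List[dict] = []
    loop = asyncio.get_running_loop()
    # Starlette runs sync background tasks in a thread pool; run_in_executor mimics that.
    await loop.run_in_executor(None, analyse, loop, sent, 3)
    return sent


if __name__ == "__main__":
    for message in asyncio.run(main()):
        print(message)
```

Within a real FastAPI/Starlette sync background task, `anyio.from_thread.run(websocket.send_json, data)` should do the same job, since Starlette runs sync background tasks in AnyIO worker threads; treat that as an assumption to verify against your Starlette version.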

The connection Manager:

import asyncio
from typing import Dict, List, Optional

from fastapi import WebSocket


# notice: active_connections is changed to a dict (key = ws_token), so we know which client listens to which model
class ConnectionManager:

    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # the server's event loop, captured on the first connect() call
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    async def connect(self, websocket: WebSocket, ws_token: str):
        await websocket.accept()
        self.loop = asyncio.get_running_loop()
        self.active_connections.setdefault(ws_token, []).append(websocket)

    def disconnect(self, websocket: WebSocket, ws_token: str):
        sockets = self.active_connections.get(ws_token, [])
        if websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            self.active_connections.pop(ws_token, None)

    # notice: kept synchronous so the sync background task (which Starlette runs
    # in a worker thread) can call it directly
    def send_message(self, data: dict, ws_token: str):
        sockets = self.active_connections.get(ws_token)
        if not sockets or self.loop is None:
            return
        # WebSocket.send_json is a coroutine tied to the server's event loop, so hand it
        # back to that loop instead of creating a new loop in this thread.
        # Do not call this from the event loop thread itself: future.result() would block it.
        for socket in list(sockets):
            future = asyncio.run_coroutine_threadsafe(socket.send_json(data), self.loop)
            future.result()  # wait until sent; re-raises any send error


socket_connections = ConnectionManager()
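The token-keyed bookkeeping in `ConnectionManager` can be exercised without a server. This sketch swaps `WebSocket` for a hypothetical `FakeSocket` that just records messages, and uses `collections.defaultdict(list)` in place of a manual membership check in `connect`; only sockets registered under a token receive that token's updates:

```python
import asyncio
from collections import defaultdict
from typing import DefaultDict, List, Tuple


class FakeSocket:
    """Hypothetical stand-in for fastapi.WebSocket that records what it is sent."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.received: List[dict] = []

    async def send_json(self, data: dict) -> None:
        self.received.append(data)


class TokenConnections:
    """Same idea as ConnectionManager: ws_token -> sockets following that job."""

    def __init__(self) -> None:
        self.active: DefaultDict[str, List[FakeSocket]] = defaultdict(list)

    def connect(self, socket: FakeSocket, ws_token: str) -> None:
        # defaultdict creates the empty list on first use
        self.active[ws_token].append(socket)

    def disconnect(self, socket: FakeSocket, ws_token: str) -> None:
        sockets = self.active.get(ws_token, [])
        if socket in sockets:
            sockets.remove(socket)
        if not sockets:
            self.active.pop(ws_token, None)  # drop tokens nobody listens to

    async def broadcast(self, data: dict, ws_token: str) -> None:
        # .get() avoids creating an entry for unknown tokens
        for socket in list(self.active.get(ws_token, [])):
            await socket.send_json(data)


async def demo() -> Tuple[TokenConnections, FakeSocket, FakeSocket, FakeSocket]:
    manager = TokenConnections()
    a, b, c = FakeSocket("a"), FakeSocket("b"), FakeSocket("c")
    manager.connect(a, "job-1")
    manager.connect(b, "job-1")
    manager.connect(c, "job-2")
    await manager.broadcast({"current_step": 1}, "job-1")
    manager.disconnect(c, "job-2")
    return manager, a, b, c


if __name__ == "__main__":
    manager, a, b, c = asyncio.run(demo())
    print(a.received, b.received, c.received, dict(manager.active))
```

Keeping a list per token lets several browser tabs follow the same analysis, which is why the original uses `Dict[str, List[WebSocket]]` rather than one socket per token.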
Chris