
Hi, I'm currently developing a web app in which my Flask backend handles the prediction models and my SvelteKit frontend retrieves the predicted data from the backend. Because the backend uses data that is constantly updated, I have scheduled a job to retrain the model every week. When the backend starts, it first trains the model on a separate thread, and the weekly retraining also takes place on that thread. The main thread handles the API calls. Initially the variable model is set to None; only after the model has been initialised and trained should it point to a model instance. At least, that's what I'm trying to achieve. However, even after the model is trained, model is still None: my frontend gets the response "Error: Model not setup" when I call the http://127.0.0.1:5000/predict endpoint. I tried using threading.Lock() to prevent race conditions, but I'm still getting the same error. I've reduced my code to be as minimal as possible (with simple print statements). These are my Flask backend files:

My app.py:

import threading

from flask import Flask
from flask_cors import CORS

import models
from routes import routes

app = Flask(__name__)
CORS(app)
app.register_blueprint(routes)


def setup_model():
    models.setup()


# set the flag for setting up the model
setup_model_flag = True

if setup_model_flag:
    training_thread = threading.Thread(target=setup_model)
    training_thread.start()

if __name__ == '__main__':
    app.run(debug=False)

My models.py:

# setup
import threading
from datetime import datetime
from time import sleep

from schedule import every
from schedule import repeat
from schedule import run_pending

model_lock = threading.Lock()
model = None


def prepare_data():
    return "preparing data"


class LSTMModel:
    def __init__(self):
        self.is_trained = False

        self.df = prepare_data()

        self.train()

    def train(self):
        print("Training model...")

        self.is_trained = True

    def predict(self, num_months):
        print("Predicting...")


def setup():
    with model_lock:
        print("Initializing...")
        global model
        model = LSTMModel()
        print(f"Model trained on {datetime.now()}")

    # schedule the job to run every sunday
    @repeat(every().sunday)
    def job():
        with model_lock:
            # update model and retrain data
            print("Initializing new model...")
            global model
            # initialize new instance of model
            model = LSTMModel()
            print(f"Model trained on {datetime.now()}")

    while True:
        # print(idle_seconds())
        run_pending()
        sleep(1)

My routes.py:

from flask import Blueprint
from flask import jsonify
from flask import request

from models import model, model_lock

routes = Blueprint('routes', __name__)


@routes.route('/predict', methods=['POST'])
def predict():
    try:
        data = request.get_json()  # Parse JSON data from the request body
        print('Received data:', data)

        input_data = data['value']
        selected_option = data['type']

        # Log a message to indicate that the endpoint is called
        print('Prediction endpoint called. Input:', input_data, 'Type:', selected_option)

    except Exception as e:
        # An error occurred, return an error response with 400 status code
        return jsonify({'error': 'Invalid JSON format'}), 400

    if selected_option == 'Years':
        num_months = int(input_data) * 12
    else:
        num_months = int(input_data)

    with model_lock:
        if model is None:
            # Model is not setup, return an error response with 500 status code
            return jsonify({'error': 'Model not setup'}), 500

        if not model.is_trained:
            # Model is not trained, return an error response with 500 status code
            return jsonify({'error': 'Model not trained'}), 500

        else:
            try:
                # Perform prediction using the model
                prediction = model.predict(num_months)
            except Exception as e:
                # An error occurred during prediction, return an error response with 500 status code
                return jsonify({'error': 'Prediction error: {}'.format(str(e))}), 500

    # Return the prediction as a JSON response
    return jsonify({'prediction': prediction})
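A minimal, stdlib-only sketch of one likely cause, using a hypothetical in-memory stand-in for models.py so it runs on its own. `from models import model, model_lock` creates a new name in routes.py that is bound to whatever `models.model` referred to at import time, which is `None`. When `setup()` later runs `global model; model = LSTMModel()` on the training thread, it rebinds the attribute on the `models` module, not the name routes.py already holds.

```python
import sys
import threading
import types

# Hypothetical in-memory stand-in for models.py, so the sketch is self-contained
models = types.ModuleType("models")
models.model = None
models.model_lock = threading.Lock()
sys.modules["models"] = models

# What routes.py does: bind new names to the objects the module holds *right now*
from models import model, model_lock


def train():
    # What setup() does via `global model`: rebind the attribute on the models module
    with models.model_lock:
        models.model = "trained model"


training_thread = threading.Thread(target=train)
training_thread.start()
training_thread.join()

print(model)         # None: the name imported at startup was never rebound
print(models.model)  # trained model: reading through the module sees the update
```

If this is the cause, one option consistent with the code above is to write `import models` in routes.py and read `models.model` inside `predict()` while holding `models.model_lock`. The imported `model_lock` name keeps working either way, because the lock object itself is never rebound.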

The API can be called from Postman with a POST request whose JSON body contains the value and type keys that routes.py reads.
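Outside Postman, the same request can be built and sent with only the Python standard library. This is a sketch: the helper names are hypothetical, and the values `"2"` and `"Years"` in the usage line are illustrative; only the `value` and `type` keys come from routes.py.

```python
import json
import urllib.request

PREDICT_URL = "http://127.0.0.1:5000/predict"


def build_predict_request(value, option, url=PREDICT_URL):
    """Build (but do not send) a POST request with the JSON body routes.py expects."""
    body = json.dumps({"value": value, "type": option}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def call_predict(value, option, url=PREDICT_URL):
    """Send the request; the Flask app must be running at the given URL.

    A 4xx/5xx reply (e.g. the 500 "Model not setup" response) raises
    urllib.error.HTTPError instead of returning.
    """
    with urllib.request.urlopen(build_predict_request(value, option, url)) as resp:
        return json.load(resp)
```

Usage, with the backend running: `call_predict("2", "Years")`.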

1 Answer


It looks like you're trying to share model between threads. This works against how Flask is designed, and I wouldn't recommend it. If you need to share state between requests or threads, consider using an external store such as a database to save and retrieve that data.
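A minimal sketch of the external-store idea using the standard library's `sqlite3`, assuming the trained state can be pickled. The table name `model_store` and the helper functions are hypothetical. In a real app, the training thread and Flask's request threads should each open their own connection to the same database file, because `sqlite3` connections are by default restricted to the thread that created them.

```python
import pickle
import sqlite3
from datetime import datetime


def init_store(conn):
    # One row per training run; the newest row is the current model
    conn.execute(
        "CREATE TABLE IF NOT EXISTS model_store ("
        "id INTEGER PRIMARY KEY AUTOINCREMENT, trained_at TEXT, state BLOB)"
    )


def save_state(conn, state):
    # Called by the training job after each (re)train
    with conn:  # commits on success, rolls back if an exception is raised
        conn.execute(
            "INSERT INTO model_store (trained_at, state) VALUES (?, ?)",
            (datetime.now().isoformat(), pickle.dumps(state)),
        )


def load_latest_state(conn):
    # Called by the /predict handler; None until the first training run finishes
    row = conn.execute(
        "SELECT state FROM model_store ORDER BY id DESC LIMIT 1"
    ).fetchone()
    return None if row is None else pickle.loads(row[0])
```

Because every request reads the latest row, a retrain simply inserts a new row, and no Python object has to be shared between threads.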

If you want a quick-and-dirty solution, you could also try using a queue, or the pickle module to pickle your state out and load it back in. Flask caching (for example, the Flask-Caching extension) is another potentially good alternative, depending on your use case.
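A rough sketch of the pickle approach, assuming the training thread writes the state to a file and the route reads it back. The path and helper names are hypothetical. Writing to a temporary file and swapping it in with `os.replace` keeps readers from opening a half-written pickle. Only unpickle files your own app wrote, since `pickle.load` can execute arbitrary code from untrusted data.

```python
import os
import pickle
import tempfile

# Hypothetical location; it only needs to be reachable by the trainer and the routes
MODEL_PATH = os.path.join(tempfile.gettempdir(), "model_state.pkl")


def save_model_state(state, path=MODEL_PATH):
    # Write to a temporary file first, then swap it into place in one step
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)


def load_model_state(path=MODEL_PATH):
    # Returns None until the first training run has written the file
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None
```

In the question's setup, `setup()` and the weekly `job()` would call `save_model_state(...)` after training, and `predict()` would call `load_model_state()` and treat `None` as "Model not setup".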

A better explanation than I've given

Matt