0

Does anyone have an idea for the following problem? Using the frontend, the user should be able to send a request to retrain a machine learning model. While the retraining is taking place, the user should still be able to get predictions from the "old" model by requesting it via the frontend.

How can I request new predictions while the retraining process takes places?

Please find my approach so far below - at the moment, while the retraining takes place (triggered by 'train_btn'), unfortunately no predictions can be made. Thank you in advance!

main.py (FASTAPI)

from fastapi import FastAPI
from fastapi_socketio import SocketManager
from fastapi.responses import HTMLResponse
import joblib
from sklearn.neural_network import MLPClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import pandas as pd
from datetime import date
import re

app = FastAPI()
socket_manager = SocketManager(app=app)
model = joblib.load('spam_classifier.joblib')

html = ""
with open('index.html', 'r') as f:
    html = f.read()

   
async def preprocessor(text):
    text = re.sub('<[^>]*>', '', text)
    emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)', text)
    text = re.sub('[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
    return text

async def classify_message(model, message):
    message = await preprocessor(message)
    label = model.predict([message])[0]
    spam_prob = model.predict_proba([message])
    await app.sio.emit('server_antwort02', {'label': label, 'spam_probability': spam_prob[0][1] })
    

async def trainModel():
    print('train model startet')
    data = pd.read_csv('spam_data.csv')

    # Text Preprocessing
    import re # regex library
    print('before preprocessor')
    def preprocessor(text):
        text = re.sub('<[^>]*>', '', text) # Effectively removes HTML markup tags
        emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)', text)
        text = re.sub('[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
        return text
    print('after preprocessor')

    # Train, Test Split
    print('start train, test split')
    from sklearn.model_selection import train_test_split
    X = data['Message'].apply(preprocessor)
    y = data['Category']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    
    # Training a Neural Network Pipeline
    from sklearn.pipeline import Pipeline
    from sklearn.linear_model import LogisticRegression
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics import accuracy_score, classification_report
    from sklearn.neural_network import MLPClassifier
    from sklearn.model_selection import cross_val_score
    print('start training a NN Pipeline')
    
    tfidf = TfidfVectorizer(strip_accents=None, lowercase=False, 
                            max_features=700, 
                            ngram_range=(1,1))
    neural_net_pipeline = Pipeline([('vectorizer', tfidf), 
                                    ('nn', MLPClassifier(hidden_layer_sizes=(700, 700)))])

    neural_net_pipeline.fit(X_train, y_train)
    print ('after fit')
 
    print('model is trained')
    #await app.sio.emit('server_antwort01',{'data': 'model has been trained'})
    # Saving the Pipeline
    from joblib import dump
    dump(neural_net_pipeline, 'spam_classifier.joblib')
    return model  

@app.get("/")
async def get():
    return HTMLResponse(html)

@app.sio.on('client_connect_event')
async def handle_client_connect_event(sid, *args, **kwargs):
    print('server says: connection successful')
    await app.sio.emit('server_antwort01', {'data': 'connection was successful'})    
 

@app.sio.on('client_train_event')
async def handle_client_train_event(sid, *args, **kwargs):
    print('Server says: train_event worked')
    model = await trainModel()
    await app.sio.emit('server_antwort01',{'data': 'model has been trained'})

@app.sio.on('client_predict_event')
async def handle_client_predict_event(sid, data, **kwargs):
    inputValue = data['data']
    await classify_message(model, inputValue)

index.html (JS)

<!DOCTYPE html>
<html>

<head>
  <script src="//code.jquery.com/jquery-3.3.1.min.js"></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/3.1.3/socket.io.js"
    integrity="sha512-2RDFHqfLZW8IhPRvQYmK9bTLfj/hddxGXQAred2wNZGkrKQkLGj8RCkXfRJPHlDerdHHIzTFaahq4s/P4V6Qig=="
    crossorigin="anonymous"></script>
  <link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css">
</head>

<body>
  <h1>Websocket Demo</h1>
  <h2> Press below to display something send from Server</h2>
  <form id="train" method="post" action="#">
    <input type="submit" value="train">
  </form>
<form id="predict" method="post" action="#">
    <input id="input" type="text">
    <input type="submit" value="predict">
  </form>
  <h3> Log </h3>
  <div id="log0"></div>  
</body>

<script type="text/javascript">

  $(document).ready(function () {
    
    var socket = io('ws://localhost:8000', {
            path: '/ws/socket.io'
        });

    socket.on('connect', function (event) {
      socket.emit('client_connect_event', { data: 'User connected' });
    });

    socket.on('server_antwort02', function (msg) {
      console.log('received message from server');
      $('#log0').append('<br>' + $('<div/>').text('logs #' + ': ' + msg.label + msg.spam_probability).html());
    });

    socket.on('server_antwort01', function (msg) {
      console.log('received message from server');
      $('#log0').append('<br>' + $('<div/>').text('logs #' + ': ' + msg.data).html());
    });


    $('form#train').submit(function (event) {
      socket.emit('client_train_event', { data: 'Start training' });
      return false;
    });

    $('form#predict').submit(function (msg) {
      var input = document.getElementById('input');
      socket.emit('client_predict_event', { data: input.value });
      return false;
    });
  });
</script>
</html>
yellow
  • 41
  • 4

1 Answers1

0

I haven't worked with fastapi_socketio, but it seems to be similar to FastApi WebSockets. And your problem sounds similar to mine. Maybe the solution I found works for you as well: Make an CPU-bound task asynchronous for FastAPI WebSockets

bechtold
  • 446
  • 3
  • 14