Does anyone have an idea for the following problem? Using the frontend, the user should be able to send a request to retrain a machine learning model. While the retraining is taking place, the user should still be able to get predictions from the "old" model by requesting it via the frontend.
How can I request new predictions while the retraining process takes places?
Please find my approach so far below - at the moment, while the retraining takes place (triggered by 'train_btn'), unfortunately no predictions can be made. Thank you in advance!
main.py (FASTAPI)
from fastapi import FastAPI
from fastapi_socketio import SocketManager
from fastapi.responses import HTMLResponse
import joblib
from sklearn.neural_network import MLPClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import pandas as pd
from datetime import date
import re
app = FastAPI()
socket_manager = SocketManager(app=app)
model = joblib.load('spam_classifier.joblib')
html = ""
with open('index.html', 'r') as f:
html = f.read()
async def preprocessor(text):
text = re.sub('<[^>]*>', '', text)
emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)', text)
text = re.sub('[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
return text
async def classify_message(model, message):
message = await preprocessor(message)
label = model.predict([message])[0]
spam_prob = model.predict_proba([message])
await app.sio.emit('server_antwort02', {'label': label, 'spam_probability': spam_prob[0][1] })
async def trainModel():
print('train model startet')
data = pd.read_csv('spam_data.csv')
# Text Preprocessing
import re # regex library
print('before preprocessor')
def preprocessor(text):
text = re.sub('<[^>]*>', '', text) # Effectively removes HTML markup tags
emoticons = re.findall('(?::|;|=)(?:-)?(?:\)|\(|D|P)', text)
text = re.sub('[\W]+', ' ', text.lower()) + ' '.join(emoticons).replace('-', '')
return text
print('after preprocessor')
# Train, Test Split
print('start train, test split')
from sklearn.model_selection import train_test_split
X = data['Message'].apply(preprocessor)
y = data['Category']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Training a Neural Network Pipeline
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_val_score
print('start training a NN Pipeline')
tfidf = TfidfVectorizer(strip_accents=None, lowercase=False,
max_features=700,
ngram_range=(1,1))
neural_net_pipeline = Pipeline([('vectorizer', tfidf),
('nn', MLPClassifier(hidden_layer_sizes=(700, 700)))])
neural_net_pipeline.fit(X_train, y_train)
print ('after fit')
print('model is trained')
#await app.sio.emit('server_antwort01',{'data': 'model has been trained'})
# Saving the Pipeline
from joblib import dump
dump(neural_net_pipeline, 'spam_classifier.joblib')
return model
@app.get("/")
async def get():
return HTMLResponse(html)
@app.sio.on('client_connect_event')
async def handle_client_connect_event(sid, *args, **kwargs):
print('server says: connection successful')
await app.sio.emit('server_antwort01', {'data': 'connection was successful'})
@app.sio.on('client_train_event')
async def handle_client_train_event(sid, *args, **kwargs):
print('Server says: train_event worked')
model = await trainModel()
await app.sio.emit('server_antwort01',{'data': 'model has been trained'})
@app.sio.on('client_predict_event')
async def handle_client_predict_event(sid, data, **kwargs):
inputValue = data['data']
await classify_message(model, inputValue)
index.html (JS)
<!DOCTYPE html>
<html>
<head>
<script src="//code.jquery.com/jquery-3.3.1.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/3.1.3/socket.io.js"
integrity="sha512-2RDFHqfLZW8IhPRvQYmK9bTLfj/hddxGXQAred2wNZGkrKQkLGj8RCkXfRJPHlDerdHHIzTFaahq4s/P4V6Qig=="
crossorigin="anonymous"></script>
<link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css">
</head>
<body>
<h1>Websocket Demo</h1>
<h2> Press below to display something send from Server</h2>
<form id="train" method="post" action="#">
<input type="submit" value="train">
</form>
<form id="predict" method="post" action="#">
<input id="input" type="text">
<input type="submit" value="predict">
</form>
<h3> Log </h3>
<div id="log0"></div>
</body>
<script type="text/javascript">
$(document).ready(function () {
var socket = io('ws://localhost:8000', {
path: '/ws/socket.io'
});
socket.on('connect', function (event) {
socket.emit('client_connect_event', { data: 'User connected' });
});
socket.on('server_antwort02', function (msg) {
console.log('received message from server');
$('#log0').append('<br>' + $('<div/>').text('logs #' + ': ' + msg.label + msg.spam_probability).html());
});
socket.on('server_antwort01', function (msg) {
console.log('received message from server');
$('#log0').append('<br>' + $('<div/>').text('logs #' + ': ' + msg.data).html());
});
$('form#train').submit(function (event) {
socket.emit('client_train_event', { data: 'Start training' });
return false;
});
$('form#predict').submit(function (msg) {
var input = document.getElementById('input');
socket.emit('client_predict_event', { data: input.value });
return false;
});
});
</script>
</html>