I have built my first ANN with Keras. It's a Linear Regression Model with 5 features and 1 output. I plotted the training and test loss (MSE) per epoch, and these are the results. Can we say that it is a good model? In addition, R^2 = 0.91. Is this the right way to assess it?

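from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt

# X_train, y_train, X_test, y_test are assumed to be prepared already
classifier = Sequential()
classifier.add(Dense(5, input_dim=5, kernel_initializer='normal', activation='relu'))
classifier.add(Dense(5, activation='relu'))
classifier.add(Dense(1, activation='linear'))  # linear output unit for regression

classifier.compile(loss='mse', optimizer='adam', metrics=['mse', 'mae'])

history = classifier.fit(X_train, y_train, batch_size=10, validation_data=(X_test, y_test), epochs=200, verbose=0)

y_pred = classifier.predict(X_test)

# evaluate() returns [loss, mse, mae] for the metrics compiled above
train_mse = classifier.evaluate(X_train, y_train, verbose=0)

# Plot training vs. test loss per epoch
plt.title('Loss / Mean Squared Error')
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.legend()
plt.show()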

[Figure: training and test loss (MSE) per epoch]

Apantazo

1 Answer

Apart from some terminology details (NN regression is not linear regression, and we usually don't call such a model a classifier), your model does look good: both errors (training and test) decrease smoothly, and there are no signs of overfitting.
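If you want a numeric check to go with the visual one, here is a minimal sketch that summarizes the two curves. It assumes a dict shaped like Keras's `history.history` with `loss` and `val_loss` lists; the function name `summarize_loss_curves` and the toy numbers are made up for illustration and are not from your run.

```python
def summarize_loss_curves(history):
    """Summarize train/validation loss lists (one value per epoch)."""
    train = history["loss"]
    val = history["val_loss"]
    # Epoch (0-based) at which the validation loss was lowest
    best_epoch = min(range(len(val)), key=lambda i: val[i])
    return {
        "final_train_loss": train[-1],
        "final_val_loss": val[-1],
        "final_gap": val[-1] - train[-1],
        "best_val_epoch": best_epoch,
        # Many epochs since the best validation loss hints at overfitting
        "epochs_since_best": len(val) - 1 - best_epoch,
    }


# Toy curves (hypothetical values): both losses keep decreasing together
healthy = {"loss": [4.0, 2.0, 1.0, 0.5, 0.4],
           "val_loss": [4.2, 2.3, 1.2, 0.7, 0.6]}

# Toy curves (hypothetical values): validation loss turns upward after epoch 2
overfit = {"loss": [4.0, 2.0, 1.0, 0.5, 0.2],
           "val_loss": [4.2, 2.0, 1.5, 1.8, 2.4]}

healthy_summary = summarize_loss_curves(healthy)
overfit_summary = summarize_loss_curves(overfit)
print(healthy_summary)
print(overfit_summary)
```

With real training you would pass `history.history` instead of the toy dicts. A validation loss that bottomed out many epochs ago while the training loss keeps falling is the pattern to watch for; your plot does not show it.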

Although an R^2 value of 0.91 sounds pretty good, the use of the metric in predictive settings, like here, is quite problematic; quoting from my own answer in another SO thread:

the whole R-squared concept comes in fact directly from the world of statistics, where the emphasis is on interpretative models, and it has little use in machine learning contexts, where the emphasis is clearly on predictive models; at least AFAIK, and beyond some very introductory courses, I have never (I mean never...) seen a predictive modeling problem where the R-squared is used for any kind of performance assessment; neither it's an accident that popular machine learning introductions, such as Andrew Ng's Machine Learning at Coursera, do not even bother to mention it. And, as noted in the Github thread above (emphasis added):

In particular when using a test set, it's a bit unclear to me what the R^2 means.

with which I certainly concur.
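To make the comparison concrete, here is a small pure-Python sketch that computes MSE, MAE, and R^2 on held-out data. The arrays are made-up numbers, not your data; in your case `y_pred` would come from `classifier.predict(X_test)`. The helper name `regression_report` is hypothetical.

```python
def regression_report(y_true, y_pred):
    """Return MSE, MAE and R^2 for two equal-length sequences of numbers."""
    n = len(y_true)
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    mse = sum(r * r for r in residuals) / n
    mae = sum(abs(r) for r in residuals) / n
    mean_true = sum(y_true) / n
    # Total sum of squares around the mean of THIS set of targets
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    ss_res = sum(r * r for r in residuals)
    r2 = 1.0 - ss_res / ss_tot
    return {"mse": mse, "mae": mae, "r2": r2}


# Made-up test targets and predictions, for illustration only
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.5, 4.0]

report = regression_report(y_true, y_pred)
print(report)
```

Note that R^2 here equals 1 - MSE / Var(y_true), so the same prediction errors give a different R^2 on a test set whose targets are more or less spread out. MSE and MAE are in the units of the target (squared, for MSE), so they tell you directly how far off the predictions are.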

desertnaut