
I am new to machine learning, so I apologize in advance if this question is somewhat recurrent; I haven't been able to find a satisfying answer to my problem. As a pedagogical exercise I have been trying to train an ANN to predict a sine wave. My problem is that although the network learns the shape of the sine accurately on the training data, it fails to do so on the validation set and for larger inputs. I start by defining my input and output as

import numpy as np

x = np.arange(800).reshape(-1, 1) / 50
y = np.sin(x) / 2

The rest of the code goes as

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

model = Sequential()
model.add(Dense(20, input_shape=(1,), activation='tanh', use_bias=True))
model.add(Dense(20, activation='tanh', use_bias=True))
model.add(Dense(1, activation='tanh', use_bias=True))
model.compile(loss='mean_squared_error', optimizer=Adam(learning_rate=0.005), metrics=['mean_squared_error'])


history = model.fit(x,y,validation_split=0.2, epochs=2000, batch_size=400, verbose=0)

I then devise a test, which is defined as

x1 = np.arange(1600).reshape(-1, 1) / 50
y1 = np.sin(x1) / 2
prediction = model.predict(x1, verbose=1)

So the problem is that the ANN clearly starts to fail on the validation set, and it cannot predict a continuation of the sine wave.

Weird behaviour for the validation set:

[image: weird behaviour for validation set]

Failure to predict anything other than the training set:

[image: prediction]

So, what am I doing wrong? Is the ANN incapable of continuing the sine wave? I have tried to fine-tune most of the available parameters without success. Most of the FAQs on similar issues attribute them to overfitting, but I haven't been able to solve the problem that way.

  • Try an RNN like LSTM – Nithin Jan 14 '19 at 14:11
  • Why should I, @Nithin? If a simple ANN doesn't work, I still need the intuition of why not, and if an RNN does work, why does it? I am sorry about all the questions, but my theoretical-physics background prompts me to want to understand more than trial and error. – M.Π.B Jan 14 '19 at 14:24
  • As is pointed out in some answers, a normal ANN will only work in the domain it was trained on, so whatever you do, the ANN won't be able to grasp the cyclic nature of the data. Even thinking about it logically: you are only giving it values in a range and teaching it those, so how is it supposed to know that it should follow the same pattern for values outside the range you have taught it? – Nithin Jan 14 '19 at 15:32
  • Hence it needs some prior information that the data is cyclic, so that even given just one cycle's worth of data it can predict all other values. This is a case where you use an RNN (LSTM), which inherently handles sequential data. – Nithin Jan 14 '19 at 15:34
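The LSTM suggestion in the comments above could be sketched as follows. This is a minimal, untrained setup, not the asker's code: the task is recast as "predict the next value from the previous `window` values", and the window length and layer sizes are arbitrary assumptions.

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# Recast the problem as next-step prediction over sliding windows,
# which is the input form an LSTM expects.
window = 50  # arbitrary assumption
series = np.sin(np.arange(800) / 50) / 2

X = np.array([series[i:i + window] for i in range(len(series) - window)])
y = series[window:]
X = X.reshape(-1, window, 1)  # (samples, timesteps, features)

model = Sequential([
    LSTM(32, input_shape=(window, 1)),
    Dense(1),
])
model.compile(loss='mse', optimizer='adam')
# model.fit(X, y, epochs=50, batch_size=64)
# To extend the curve past the training range, predict one step,
# append it to the window, and repeat.
```

Whether this actually extrapolates well still depends on training; the point is only that the model sees the local shape of the wave rather than the absolute value of x.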

2 Answers


Congratulations on stumbling on one of the fundamental issues of deep learning on your first try :)

What you did is correct, and, indeed, the ANN (in its current form) is incapable of continuing the sine wave.

However, you can see signs of overfitting in your MSE graph, starting around epoch 800, when the validation error starts to increase.

  • Indeed there are signs of overfitting. I am running many epochs so that I can understand the evolution of the NN. Nevertheless, even if I reduce the number of epochs, it doesn't solve the problem. And I guess that if you are correct, there aren't many options in this framework. – M.Π.B Jan 14 '19 at 14:29
  • Reducing the number of epochs does not solve overfitting; you should add regularization for that (e.g. dropout, weight decay, a smaller batch size) and/or try a simpler model (fewer layers and/or fewer parameters). And the problem is not in the framework (Keras), it's rather in the methodology (deep learning). – BlackBear Jan 14 '19 at 14:33
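The regularization options mentioned in that comment could look like the sketch below. The dropout rate and weight-decay strength are illustrative assumptions, not tuned values, and this addresses only the overfitting, not the extrapolation problem.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

model = Sequential([
    Dense(20, input_shape=(1,), activation='tanh',
          kernel_regularizer=l2(1e-4)),  # weight decay on this layer
    Dropout(0.2),                        # randomly zero 20% of units while training
    Dense(20, activation='tanh', kernel_regularizer=l2(1e-4)),
    Dropout(0.2),
    Dense(1, activation='tanh'),
])
model.compile(loss='mean_squared_error', optimizer=Adam(learning_rate=0.005))
# Also consider a smaller batch size in model.fit, e.g. 32 instead of 400.
```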

As pointed out above, your NN is not capable of "grasping" the cyclic nature of your data.

You can think of a DNN made only of dense layers as a smarter version of linear regression: the point of using a DNN is to get high-level, non-linear, abstract features that the network learns by itself, instead of engineering features by hand. The downside is that these features are mostly hard to describe and understand.

So, in general, DNNs are good at predicting unknown points "in the middle" of the training range (interpolation); the farther your x is from the training set, the less accurate the prediction will be. Again, in general.

To predict things that are cyclic in nature, you should either use a more sophisticated architecture or pre-process your data, e.g. by exploiting known "seasonality" or a "base frequency".
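For this particular problem the pre-processing can be very simple: fold every input onto a single period before feeding it to the network. This assumes the base period is known (here 2π, since y = sin(x)/2); in a real dataset you would have to estimate it first.

```python
import numpy as np

def fold_to_period(x, period=2 * np.pi):
    """Map x onto [0, period); sin(fold(x)) equals sin(x)."""
    return np.mod(x, period)

x1 = np.arange(1600).reshape(-1, 1) / 50
# The targets are unchanged by the folding, so a network trained on one
# folded cycle generalizes to arbitrarily large x.
assert np.allclose(np.sin(fold_to_period(x1)), np.sin(x1))
```

After this transform, every input the network ever sees lies inside the training domain, so the "extrapolation" problem disappears by construction.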
