I'm trying to use Keras' fit_generator with a custom data generator to feed an LSTM network.
What works
To illustrate the problem, I have created a toy example trying to predict the next number in a simple ascending sequence, and I use the Keras TimeseriesGenerator to create a Sequence instance:
import numpy as np
from keras.preprocessing.sequence import TimeseriesGenerator

WINDOW_LENGTH = 4
data = np.arange(0, 100).reshape(-1, 1)
data_gen = TimeseriesGenerator(data, data, length=WINDOW_LENGTH,
                               sampling_rate=1, batch_size=1)
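For reference, here is a plain-Python sketch of the samples this setup should produce, assuming that with length=4 and sampling_rate=1 each sample is four consecutive values and the target is the value that follows. The helper name `make_windows` is made up for illustration and is not part of Keras.

```python
# Plain-Python stand-in for the toy TimeseriesGenerator setup (an illustrative
# sketch under the assumption stated above, not the Keras implementation).
WINDOW_LENGTH = 4
data = list(range(100))


def make_windows(series, length):
    """Return (window, target) pairs: `length` consecutive values and the next one."""
    return [(series[i - length:i], series[i]) for i in range(length, len(series))]


windows = make_windows(data, WINDOW_LENGTH)
first_x, first_y = windows[0]
last_x, last_y = windows[-1]

print(len(windows))      # 96 samples, i.e. 96 batches when batch_size=1
print(first_x, first_y)  # [0, 1, 2, 3] 4
print(last_x, last_y)    # [95, 96, 97, 98] 99
```

With batch_size=1 this gives 96 batches, so steps_per_epoch=32 covers only a third of the data per epoch.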
I use a simple LSTM network:
from keras.layers import Input, LSTM, Dense
from keras.models import Model

data_dim = 1
input1 = Input(shape=(WINDOW_LENGTH, data_dim))
lstm1 = LSTM(100)(input1)
hidden = Dense(20, activation='relu')(lstm1)
output = Dense(data_dim, activation='linear')(hidden)
model = Model(inputs=input1, outputs=output)
model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
and train it using the fit_generator function:
model.fit_generator(generator=data_gen,
                    steps_per_epoch=32,
                    epochs=10)
This trains well: the loss goes down consistently, and the model makes predictions as expected.
The problem
In my non-toy situation, I want to process the data coming out of the TimeseriesGenerator before feeding it into fit_generator. As a step towards this, I created a generator function that just wraps the TimeseriesGenerator used previously.
def get_generator(data, targets, window_length=5, batch_size=32):
    while True:
        data_gen = TimeseriesGenerator(data, targets, length=window_length,
                                       sampling_rate=1, batch_size=batch_size)
        for i in range(len(data_gen)):
            x, y = data_gen[i]
            yield x, y

data_gen_custom = get_generator(data, data,
                                window_length=WINDOW_LENGTH, batch_size=1)
The strange thing is that when I train the model as before, but with this generator as the input,
model.fit_generator(generator=data_gen_custom,
                    steps_per_epoch=32,
                    epochs=10)
there is no error, but the training loss jumps up and down instead of decreasing consistently as it did with the first approach, and the model doesn't learn to make good predictions.
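To make the batch order of the wrapper explicit, the plain-Python sketch below mimics it without Keras. It assumes that fit_generator pulls steps_per_epoch batches per epoch from a plain generator in exactly the order they are yielded. The Keras documentation describes the `shuffle` argument of fit_generator as applying only to Sequence instances, so the two setups may not see the data in the same order. The name `ordered_targets` is a made-up stand-in for get_generator that yields only the target of each batch.

```python
from itertools import islice

WINDOW_LENGTH = 4
STEPS_PER_EPOCH = 32
data = list(range(100))


def ordered_targets(series, length):
    """Hypothetical stand-in for get_generator: yield targets in order, forever."""
    while True:
        for i in range(length, len(series)):
            yield series[i]


gen = ordered_targets(data, WINDOW_LENGTH)
epoch_ranges = []
for epoch in range(4):
    # Each "epoch" consumes the next STEPS_PER_EPOCH batches from the generator.
    targets = list(islice(gen, STEPS_PER_EPOCH))
    epoch_ranges.append((targets[0], targets[-1]))
    print(f"epoch {epoch + 1}: targets {targets[0]}..{targets[-1]}")

# epoch 1: targets 4..35
# epoch 2: targets 36..67
# epoch 3: targets 68..99
# epoch 4: targets 4..35
```

Under these assumptions each epoch trains on a different, strictly ordered value range. Whether this ordering explains the erratic loss is a guess and has not been verified here.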
Any ideas what I'm doing wrong with my custom generator approach?