Please see the Python code below. I added comments where I felt a point needed emphasis.
import keras
import numpy


def build_model():
    model = keras.models.Sequential()
    # Number of LSTM cells in this layer = 3.
    model.add(keras.layers.LSTM(3, input_shape=(3, 1), activation='elu'))
    return model


def build_data():
    inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    outputs = [10, 11, 12, 13, 14, 15, 16, 17, 18]
    inputs = numpy.array(inputs)
    outputs = numpy.array(outputs)
    # Number of samples = 3, number of input vectors in each sample = 3,
    # size of each input vector = 1.
    inputs = inputs.reshape(3, 3, 1)
    # Number of target samples = 3, number of outputs per target sample = 3.
    outputs = outputs.reshape(3, 3)
    return inputs, outputs


def train():
    model = build_model()
    model.summary()
    model.compile(optimizer='adam', loss='mean_absolute_error', metrics=['accuracy'])
    x, y = build_data()
    model.fit(x, y, batch_size=1, epochs=4000)
    model.save("LSTM_testModel")


def apply():
    model = keras.models.load_model("LSTM_testModel")
    # One sample holding 3 input vectors of size 1 (renamed to avoid shadowing the built-in input()).
    sample = numpy.array([[[7], [8], [9]]])
    print(model.predict(sample))


def main():
    train()


main()
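To double-check the shape comments in build_data(), here is a small NumPy-only sketch of the same reshape. It only inspects the array and does not touch the model; the names samples, timesteps and features are labels I chose for the three axes, on the assumption that this is the axis order the LSTM layer reads.

```python
import numpy

# Same nine values as in build_data(), reshaped the same way.
inputs = numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 1)

# Axis order assumed here: (samples, input vectors per sample, size of each vector).
samples, timesteps, features = inputs.shape
print(samples, timesteps, features)  # 3 3 1

# The first sample holds three input vectors, each of size 1.
print(inputs[0].tolist())  # [[1], [2], [3]]
```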
My understanding is that each input sample contains 3 input vectors, and that each input vector goes to one LSTM cell, i.e. for sample 1, input vector 1 goes to LSTM cell 1, input vector 2 goes to LSTM cell 2, and so on.
In tutorials on the internet, I've seen the number of LSTM cells set much higher than the number of input vectors, e.g. 300 LSTM cells.
So if I have 3 input vectors per sample, what input goes to the remaining 297 LSTM cells?
I tried compiling the model with 2 LSTM cells, and it still accepted the 3 input vectors per sample, although I had to change the dimensions of the target outputs in the training data to accommodate this. So what happened to the third input vector of each sample? Is it ignored?
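To frame the question, here is a rough NumPy sketch of what I assume a single LSTM layer computes for one sample, using the textbook LSTM gate equations with tanh (not the 'elu' activation from my model). The function lstm_forward and its random weights are hypothetical and only meant to show shapes; this is not Keras's actual implementation. In this sketch every input vector is fed, one timestep at a time, to all units, and the number of units only sets the size of the output.

```python
import numpy


def sigmoid(z):
    return 1.0 / (1.0 + numpy.exp(-z))


def lstm_forward(x, units, seed=0):
    """Run one sample of shape (timesteps, features) through a toy LSTM layer.

    Weights are random placeholders (hypothetical); only the shapes matter here.
    """
    rng = numpy.random.default_rng(seed)
    timesteps, features = x.shape
    W = rng.normal(size=(features, 4 * units))  # input weights for the 4 gates
    U = rng.normal(size=(units, 4 * units))     # recurrent weights for the 4 gates
    b = numpy.zeros(4 * units)
    h = numpy.zeros(units)  # hidden state, one value per unit
    c = numpy.zeros(units)  # cell state, one value per unit
    for t in range(timesteps):
        # The input vector at step t contributes to every unit's gates.
        z = x[t] @ W + h @ U + b
        i, f, g, o = numpy.split(z, 4)
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        c = f * c + i * numpy.tanh(g)
        h = o * numpy.tanh(c)
    return h  # final hidden state: shape (units,)


sample = numpy.array([[7.0], [8.0], [9.0]])  # 3 input vectors of size 1
for units in (2, 3, 300):
    print(units, lstm_forward(sample, units).shape)
```

If this sketch matches what Keras does, then with 2 units the third input vector would not be dropped; it would just be the last step of the loop.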
I believe the image above shows each input vector (of an arbitrary scenario) being mapped to a specific RNN cell, though I may be misinterpreting it. The image is taken from the following URL: http://karpathy.github.io/2015/05/21/rnn-effectiveness/