With the Sequential API

If I create an LSTM with the Sequential API of Keras using the following code:

from keras.models import Sequential
from keras.layers import LSTM

model = Sequential()
model.add(LSTM(2, input_dim=3))

then

model.summary()

returns 48 parameters, which is OK as indicated in this Stack Overflow question.

model.summary() with the Sequential API

Quick details:

input_dim = 3, output_dim = 2
n_params = 4 * output_dim * (output_dim + input_dim + 1) = 4 * 2 * (2 + 3 + 1) = 48
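As a quick check of the formula above, here is a minimal sketch in plain Python that splits the count by weight array. It assumes the usual Keras LSTM weight layout (a kernel of shape (input_dim, 4 * units), a recurrent kernel of shape (units, 4 * units) and a bias of length 4 * units); the function name lstm_param_breakdown is made up for illustration.

```python
def lstm_param_breakdown(input_dim, units):
    """Return the parameter counts of one LSTM layer, split by weight array.

    Assumes four gate blocks (input, forget, cell, output), each with its
    own input weights, recurrent weights and bias vector.
    """
    kernel = input_dim * 4 * units        # input features -> 4 gate blocks
    recurrent_kernel = units * 4 * units  # previous hidden state -> 4 gate blocks
    bias = 4 * units                      # one bias vector per gate block
    return {
        "kernel": kernel,
        "recurrent_kernel": recurrent_kernel,
        "bias": bias,
        "total": kernel + recurrent_kernel + bias,
    }


breakdown = lstm_param_breakdown(input_dim=3, units=2)
print(breakdown)
```

For input_dim = 3 and units = 2 this gives 24 + 16 + 8 = 48, matching the summary.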

With the Functional API

But if I do the same with the functional API with the following code:

from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM

inputs = Input(shape=(3, 1))
lstm = LSTM(2)(inputs)
model = Model(inputs=inputs, outputs=lstm)

then

model.summary()

returns 32 parameters.

model.summary() with the Functional API

Why is there such a difference?

today
Manu NALEPA

2 Answers

The difference is that passing input_dim=x to an RNN layer, including an LSTM layer, means the input shape is (None, x): a variable number of timesteps, each of which is a vector of length x. In the functional API example, however, you specify shape=(3, 1) as the input shape, which means 3 timesteps with one feature each. The input dimension in the formula is therefore 1, not 3, and the number of parameters is: 4 * output_dim * (output_dim + input_dim + 1) = 4 * 2 * (2 + 1 + 1) = 32, which is the number shown in the model summary.
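To make the point about the last axis concrete, here is a small sketch in plain Python (no Keras required). It assumes the shape is given as (timesteps, features) without the batch axis, as in Input(shape=...); the helper name lstm_params_from_shape is hypothetical.

```python
def lstm_params_from_shape(shape, units):
    """Count LSTM parameters for an input of shape (timesteps, features).

    Only the last axis (features per timestep) enters the formula; the
    number of timesteps, which may be None, does not.
    """
    _timesteps, features = shape
    return 4 * units * (units + features + 1)


# Compare the shape implied by input_dim=3 with the two Input(shape=...) variants.
for shape in [(None, 3), (3, 1), (1, 3)]:
    print(shape, lstm_params_from_shape(shape, units=2))
```

Under these assumptions, (None, 3) and (1, 3) both give 48, while (3, 1) gives 32.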

Further, if you use Keras 2.x.x, you get a warning when passing the input_dim argument to an RNN layer:

UserWarning: The input_dim and input_length arguments in recurrent layers are deprecated. Use input_shape instead.

UserWarning: Update your LSTM call to the Keras 2 API: LSTM(2, input_shape=(None, 3))

today
If, in the Functional API, I replace inputs = Input(shape=(3, 1)) with inputs = Input(shape=(1, 3)), I get 48 parameters, as expected. Thanks! – Manu NALEPA Oct 11 '18 at 12:19
I worked it out as follows:

Case 1:
m (input) = 3
n (output) = 2

params = 4 * ( (input * output) + (output ^ 2) + output)
       = 4 * (3*2 + 2^2 + 2)
       = 4 * (6 + 4 + 2)
       = 4 * 12
       = 48



Case 2:
m (input) = 1
n (output) = 2

params = 4 * ( (input * output) + (output ^ 2) + output)
       = 4 * (1*2 + 2^2 + 2)
       = 4 * (2 + 4 + 2)
       = 4 * 8
       = 32
Biranchi