
Why is the number of parameters of the GRU layer 9600?

Shouldn't it be ((16+32)*32 + 32) * 3 * 2 = 9,408?

or, rearranging,

32*(16 + 32 + 1)*3*2 = 9408

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=4500, output_dim=16, input_length=200),
    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(32)),
    tf.keras.layers.Dense(6, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

[model.summary() output: the Bidirectional GRU layer reports 9,600 parameters]

  • Which TensorFlow version are you using? When I run the code on `1.14.0` I get 9408. – thushv89 Aug 02 '19 at 03:13
  • I am using 2.0. I get the expected results for LSTMs, but not for GRU – Abid Orucov Aug 02 '19 at 04:18
  • That's pretty interesting. I went through the TensorFlow 2.0 source, but all the cells I checked still added up to 9408. I'll look into why this is the case. – thushv89 Aug 02 '19 at 05:01
  • Thanks for the answer! The answer below helped me to figure it out: apparently, it is due to the parameter reset_after. Depending on whether it is set to True or False, the model uses a different number of bias terms. – Abid Orucov Aug 02 '19 at 13:40

2 Answers


The key is that TensorFlow keeps separate biases for the input and recurrent kernels when the parameter reset_after=True in GRUCell. You can see this in the GRUCell source code:

if self.use_bias:
    if not self.reset_after:
        bias_shape = (3 * self.units,)
    else:
        # separate biases for input and recurrent kernels
        # Note: the shape is intentionally different from CuDNNGRU biases
        # `(2 * 3 * self.units,)`, so that we can distinguish the classes
        # when loading and converting saved weights.
        bias_shape = (2, 3 * self.units)
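
Plugging in the sizes from the question (units=32, embedding dimension 16), this bias shape is exactly the gap between the two totals; a quick sanity check:

units, input_dim = 32, 16
kernel = input_dim * 3 * units          # 1536 weights on the input
recurrent_kernel = units * 3 * units    # 3072 weights on the previous state
bias_tf1 = 3 * units                    # 96  (reset_after=False: one bias set)
bias_tf2 = 2 * 3 * units                # 192 (reset_after=True: two bias sets)
print(2 * (kernel + recurrent_kernel + bias_tf1))  # 9408 (x2 for Bidirectional)
print(2 * (kernel + recurrent_kernel + bias_tf2))  # 9600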

Taking the reset gate as an example, we generally see the following formula:

$$r_t = \sigma(x_t W_{xr} + h_{t-1} W_{hr} + b_r)$$

But if we set reset_after=True, the actual formula is as follows, with a separate bias for each of the two matrix products:

$$r_t = \sigma(x_t W_{xr} + b_{xr} + h_{t-1} W_{hr} + b_{hr})$$

As you can see above, the default for GRU is reset_after=True in TensorFlow 2, whereas it is reset_after=False in TensorFlow 1.x.

So in TensorFlow 2 the number of parameters of this GRU layer is ((16+32)*32 + 32 + 32) * 3 * 2 = 9600.
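
As a quick check, building the same model as in the question, only with reset_after=False, should bring the count back down:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=4500, output_dim=16, input_length=200),
    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(32, reset_after=False)),
    tf.keras.layers.Dense(6, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.summary()  # the Bidirectional GRU layer should now report 9,408 parameters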

giser_yugang
  • Thank you! I tried both True and False for reset_after, and it is exactly as you said. Do you know what the point of adding two separate bias terms is? The model can always set b_combined = b_input + b_recurrent, so what is the point? (As far as I understand, the only way it can make any difference is if the same biases are used somewhere else in the model for calculations.) – Abid Orucov Aug 02 '19 at 13:37
  • @AbidOrucov As the code comment says, `we can distinguish the classes when loading and converting saved weights` - for example, when we take the weights from one layer and set them on another layer. – giser_yugang Aug 02 '19 at 15:40
  • This doesn't make a lot of sense to me. `b_input` and `b_recurrent` are just 1-D vectors, with the same dimension as `br`. You can just add the vectors, call the sum `br`, and then the two formulas are exactly identical. In other words, out of these 9600 parameters, 3*64=192 are redundant. – MSalters Apr 28 '22 at 12:02

I figured out a little bit more about this, as an addition to the accepted answer. What Keras does in GRUCell.call() is:

With reset_after=False (default in TensorFlow 1):

$$z_t = \sigma(x_t W_{xz} + b_z + h_{t-1} W_{hz})$$
$$r_t = \sigma(x_t W_{xr} + b_r + h_{t-1} W_{hr})$$
$$\tilde{h}_t = \tanh(x_t W_{xh} + b_h + (r_t \circ h_{t-1}) W_{hh})$$

With reset_after=True (default in TensorFlow 2):

$$z_t = \sigma(x_t W_{xz} + b_{xz} + h_{t-1} W_{hz} + b_{hz})$$
$$r_t = \sigma(x_t W_{xr} + b_{xr} + h_{t-1} W_{hr} + b_{hr})$$
$$\tilde{h}_t = \tanh(x_t W_{xh} + b_{xh} + r_t \circ (h_{t-1} W_{hh} + b_{hh}))$$
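
For illustration, here is a minimal NumPy sketch of a single reset_after=True step. The names mirror the formulas above; this is a simplification, not the actual Keras implementation:

import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step_reset_after(x, h, W_x, W_h, b_x, b_h):
    # W_x: (input_dim, 3*units), W_h: (units, 3*units)
    # b_x, b_h: (3*units,) each -- the two separate bias sets
    xz, xr, xh = np.split(x @ W_x + b_x, 3, axis=-1)
    hz, hr, hh = np.split(h @ W_h + b_h, 3, axis=-1)
    z = sigmoid(xz + hz)                # update gate
    r = sigmoid(xr + hr)                # reset gate
    h_candidate = np.tanh(xh + r * hh)  # reset applied AFTER the matmul
    return z * h + (1.0 - z) * h_candidate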

After training with reset_after=True, b_xz equals b_hz and b_xr equals b_hr, because (I assume) the two biases in each of these pairs receive identical gradients and start from the same zero initialization, so they effectively behave as one single combined parameter vector - just like the OP pointed out in a comment above. However, that's not the case for b_xh and b_hh: since r_t multiplies (h_{t-1} W_hh + b_hh) element-wise, they can and will be different, so they cannot be combined into one vector, and that's why the total parameter count is higher.
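
One way to see the two bias sets is to inspect the layer's weights directly; with reset_after=True the bias variable has shape (2, 3 * units), matching the source quoted in the accepted answer:

import tensorflow as tf

gru = tf.keras.layers.GRU(32, reset_after=True)
gru.build(input_shape=(None, None, 16))  # (batch, time, features)
kernel, recurrent_kernel, bias = gru.get_weights()
print(kernel.shape, recurrent_kernel.shape, bias.shape)
# (16, 96) (32, 96) (2, 96): bias[0] = input biases, bias[1] = recurrent biases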

jerha202