I am using the stateful mode of LSTMs in tf.keras, where I need to call reset_states() manually once I have processed my sequence data, as described here. It seems that people normally call model.reset_states(), but in my case my LSTM layer is embedded in a much more complex network that includes all kinds of other layers like Dense, Conv, and so forth. My question is: if I just call model.reset_states() on my main model, which has an LSTM embedded in it (and only one LSTM), should I be worried about that reset affecting other layers in the model such as the Dense or Conv layers? Would it be better to hunt down the LSTM layer and isolate the reset_states() call to just that layer?

Nicolas Gervais
adamconkey

2 Answers

TLDR: Layers like LSTM/GRU have weights and states, whereas layers like Conv/Dense/Embedding have only weights. reset_states() only affects layers that hold states.

For an LSTM, reset_states() resets the layer's stored hidden state h_t and cell state c_t (to zeros by default). These are the same values you obtain as extra outputs when you create the layer with LSTM(n, return_state=True).

Embedding, Dense and Conv layers don't hold such states, so model.reset_states() will not affect those kinds of feed-forward layers. It only affects recurrent layers such as LSTM and GRU, and only when they were created with stateful=True.
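
Both points can be checked directly. The following is a minimal sketch, assuming TensorFlow 2.x up to 2.15 (where tf.keras is Keras 2; Keras 3 changed several state-related APIs). The layer sizes, names and random input are arbitrary choices for illustration. It runs one batch through a stateful LSTM followed by a Dense layer, then checks that model.reset_states() zeroes the LSTM's stored h_t and c_t while leaving the Dense kernel untouched:

```python
import numpy as np
import tensorflow as tf  # assumption: TF 2.x <= 2.15, i.e. tf.keras is Keras 2

# A stateful RNN needs a fixed batch size so it can carry state between batches.
inputs = tf.keras.Input(shape=(5, 3), batch_size=1)
lstm = tf.keras.layers.LSTM(4, stateful=True, name="lstm")
dense = tf.keras.layers.Dense(2, name="dense")
model = tf.keras.Model(inputs, dense(lstm(inputs)))

x = np.random.default_rng(0).random((1, 5, 3)).astype("float32")
model(x)  # forward pass; the stateful LSTM stores its final h_t and c_t

h_before, c_before = [s.numpy() for s in lstm.states]
kernel_before = dense.kernel.numpy().copy()

model.reset_states()  # resets only layers that have reset_states() and stateful=True

h_after, c_after = [s.numpy() for s in lstm.states]
print("h_t before reset:", h_before)
print("h_t after reset: ", h_after)  # all zeros
print("Dense kernel unchanged:", np.array_equal(dense.kernel.numpy(), kernel_before))
```

The Dense layer only owns weights, so there is nothing for the reset to touch; the LSTM's weights are likewise unchanged, only its two state variables are zeroed.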

If you like, you can look at the source code and verify that this method checks whether each layer has a reset_states attribute (which feed-forward layers don't have) and is stateful before resetting it.

thushv89

Any layer with a settable stateful attribute is subject to reset_states(): the method iterates over the model's layers, checks whether each one has stateful=True, and if so calls that layer's own reset_states() method; see source.
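
To make that dispatch concrete without needing TensorFlow, here is a toy sketch. ToyDense, ToyLSTM and ToyModel are hypothetical stand-ins: ToyModel.reset_states reproduces only the per-layer check described above (as in TF 2.x tf.keras), not any real layer math.

```python
class ToyDense:
    """Hypothetical feed-forward layer: weights only, no reset_states method."""
    def __init__(self, name):
        self.name = name
        self.kernel = [[0.5, -0.2], [0.1, 0.3]]


class ToyLSTM:
    """Hypothetical recurrent layer that keeps h_t and c_t between calls."""
    def __init__(self, name, units, stateful=False):
        self.name = name
        self.stateful = stateful
        self.units = units
        self.h = [0.0] * units
        self.c = [0.0] * units

    def reset_states(self):
        self.h = [0.0] * self.units
        self.c = [0.0] * self.units


class ToyModel:
    def __init__(self, layers):
        self.layers = layers

    def reset_states(self):
        # Same condition the tf.keras loop uses: the layer must have a
        # reset_states method AND report stateful=True.
        for layer in self.layers:
            if hasattr(layer, "reset_states") and getattr(layer, "stateful", False):
                layer.reset_states()


dense = ToyDense("dense")
lstm = ToyLSTM("lstm", units=2, stateful=True)
model = ToyModel([lstm, dense])

lstm.h, lstm.c = [0.7, -0.1], [1.2, 0.4]  # pretend a batch left state behind
model.reset_states()
print(lstm.h, lstm.c)  # [0.0, 0.0] [0.0, 0.0]
print(dense.kernel)    # [[0.5, -0.2], [0.1, 0.3]]
```

The Dense stand-in is skipped simply because it has no reset_states attribute, which is the same reason real feed-forward layers are left alone.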

In Keras, all recurrent layers, including ConvLSTM2D, have a settable stateful attribute; I'm not aware of any others. tensorflow.keras, however, has plenty of custom layer implementations that may, so you can use the code below to check for sure:

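def print_statefuls(model):
    """Print every top-level layer that model.reset_states() would reset."""
    for layer in model.layers:
        # Same test model.reset_states() applies: has reset_states() and stateful=True
        if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
            print(layer.name, "is stateful")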
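
print_statefuls only inspects model.layers, i.e. the top-level layers, so an LSTM that sits inside a nested submodel is not listed under its own name. If you would rather isolate the reset to the recurrent layer itself, as the question suggests, a sketch like the following walks nested submodels and resets only what it finds. It is duck-typed and the function names are hypothetical; it assumes, as in TF 2.x tf.keras, that nested Model objects expose a layers attribute and that stateful layers expose stateful and reset_states().

```python
def find_stateful_layers(model):
    """Collect stateful leaf layers, descending into nested submodels.

    Hypothetical, duck-typed helper: anything with a `layers` attribute is
    treated as a container; anything with `reset_states` and `stateful=True`
    is treated as a stateful layer.
    """
    found = []
    for layer in model.layers:
        if hasattr(layer, "layers"):
            found.extend(find_stateful_layers(layer))  # nested submodel
        elif hasattr(layer, "reset_states") and getattr(layer, "stateful", False):
            found.append(layer)
    return found


def reset_stateful_layers(model):
    """Reset only the layers returned by find_stateful_layers; return their names."""
    stateful_layers = find_stateful_layers(model)
    for layer in stateful_layers:
        layer.reset_states()
    return [layer.name for layer in stateful_layers]
```

With a tf.keras model you would call reset_stateful_layers(model) in place of model.reset_states(). When the LSTM is the only stateful layer, both end up resetting the same layer, so the helper mainly helps when you want to see or control exactly which layers are affected.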

OverLordGoldDragon