
I have created a model with an LSTM layer as shown below and want to get the internal state (hidden state and cell state) after the training step and save it. After the training step, I will use the network for a prediction and want to reinitialize the LSTM with the saved internal state before the next training step. This way I can continue from the same point after each training step. I haven't been able to find anything helpful for the current version of TensorFlow, i.e. 2.x.

import tensorflow as tf

class LTSMNetwork(object):
    def __init__(self, num_channels, num_hidden_neurons, learning_rate, time_steps, batch_size):
        self.num_channels = num_channels
        self.num_hidden_neurons = num_hidden_neurons
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.batch_size = batch_size

    def lstm_model(self):
        self.model = tf.keras.Sequential()

        self.model.add(tf.keras.layers.LSTM(batch_input_shape=(self.batch_size, self.time_steps, self.num_channels),
                                            units=self.num_hidden_neurons[0], 
                                            activation='tanh', recurrent_activation='sigmoid', 
                                            return_sequences=True, stateful=True))
        #self.model.add(tf.keras.layers.LSTM(units=self.num_hidden_neurons[1], stateful=True))
        hidden_layer = tf.keras.layers.Dense(units=self.num_hidden_neurons[1], activation=tf.nn.sigmoid)
        self.model.add(hidden_layer)
        self.model.add(tf.keras.layers.Dense(units=self.num_channels, name="output_layer", activation=tf.nn.tanh))
        self.model.compile(optimizer=tf.optimizers.Adam(learning_rate=self.learning_rate), 
                            loss='mse', metrics=['binary_accuracy'])
        return self.model


if __name__ == '__main__':

    num_channels = 3
    num_hidden_neurons = [150, 100]
    learning_rate = 0.001
    time_steps = 1
    batch_size = 1

    lstm_network = LTSMNetwork(num_channels=num_channels, num_hidden_neurons=num_hidden_neurons, 
                                learning_rate=learning_rate, time_steps=time_steps, batch_size=batch_size)
    model = lstm_network.lstm_model()
    model.summary()

2 Answers


You can define a custom Callback and save the hidden and cell states at every epoch, for example. Afterwards, you can choose the epoch from which you want to extract the states and then use `lstm_layer.reset_states(*)` to set the initial state again:

import tensorflow as tf

class LTSMNetwork(object):
    def __init__(self, num_channels, num_hidden_neurons, learning_rate, time_steps, batch_size):
        self.num_channels = num_channels
        self.num_hidden_neurons = num_hidden_neurons
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.batch_size = batch_size

    def lstm_model(self):
        self.model = tf.keras.Sequential()

        self.model.add(tf.keras.layers.LSTM(batch_input_shape=(self.batch_size, self.time_steps, self.num_channels),
                                            units=self.num_hidden_neurons[0], 
                                            activation='tanh', recurrent_activation='sigmoid', 
                                            return_sequences=True, stateful=True))
        hidden_layer = tf.keras.layers.Dense(units=self.num_hidden_neurons[1], activation=tf.nn.sigmoid)
        self.model.add(hidden_layer)
        self.model.add(tf.keras.layers.Dense(units=self.num_channels, name="output_layer", activation=tf.nn.tanh))
        self.model.compile(optimizer=tf.optimizers.Adam(learning_rate=self.learning_rate), 
                            loss='mse', metrics=['binary_accuracy'])
        return self.model


states = {}
class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, lstm_layer):
        self.lstm_layer = lstm_layer

    def on_epoch_end(self, epoch, logs=None):
        # Store the layer's state variables (hidden state and cell state) for this epoch
        states[epoch] = self.lstm_layer.states

num_channels = 3
num_hidden_neurons = [150, 100]
learning_rate = 0.001
time_steps = 1
batch_size = 1

lstm_network = LTSMNetwork(num_channels=num_channels, num_hidden_neurons=num_hidden_neurons, 
                            learning_rate=learning_rate, time_steps=time_steps, batch_size=batch_size)
model = lstm_network.lstm_model()
lstm_layer = model.layers[0]
x = tf.random.normal((1, 1, 3))
y = tf.random.normal((1, 1, 3))
model.fit(x, y, epochs=5, callbacks=[CustomCallback(lstm_layer)])
model.summary()
lstm_layer.reset_states(states[0]) # Sets hidden state from first epoch.

`states` now holds the internal states (a hidden-state/cell-state pair) for each of the 5 epochs.
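
For illustration, each saved entry is a pair of state tensors whose shapes follow from the model above (batch_size=1, 150 units); this is just a quick way to inspect what was stored:

# Each entry holds the hidden state h and the cell state c of the LSTM layer.
h, c = states[0]
print(h.shape)  # (1, 150) -> (batch_size, units)
print(c.shape)  # (1, 150) -> (batch_size, units)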

AloneTogether
  • I am getting the following error: `TypeError: __array__() takes 1 positional argument but 2 were given`, raised from `x.assign(np.asarray(value, dtype=dtype_numpy(x)))` in `tensorflow\python\keras\backend.py` (line 3820, in `batch_set_value`). – Python_Learner Feb 21 '22 at 10:41
  • This is a NumPy error... it has nothing to do with the code you posted or the answer. Did you try the answer? It runs without any problems – AloneTogether Feb 21 '22 at 10:41
  • Yes, it is running without any issues. – Python_Learner Feb 21 '22 at 10:47
  • So, it is hard to say where your error is coming from when I showed you that it runs without any problems. Can you elaborate on your problem? – AloneTogether Feb 21 '22 at 10:48
  • The problem is coming from the following line in the code: `lstm_layer.reset_states(states[0]) # Sets hidden state from first epoch.` – Python_Learner Feb 21 '22 at 10:51
  • If I save the internal state after training as you mentioned and then ask the network for a prediction, I need the states to be the same after the prediction and before the next training step. However, after the prediction the internal states change, and the call reset_states(states[0]) sets the new states to the ones after the prediction. This way, the network does not start from the same point but from a different internal state. `model.fit(x, y, epochs=1, callbacks=[CustomCallback(lstm_layer)]); print(states[0]); model.predict(x); lstm_layer.reset_states(states[0]); print(states[0])` – Python_Learner Feb 22 '22 at 19:23
  • So what exactly is your question? :) How to keep the same states before and after prediction? – AloneTogether Feb 22 '22 at 19:30
  • Yes, exactly! I want to keep the states the same before and after the prediction. – Python_Learner Feb 22 '22 at 19:54
  • `lstm_layer.get_initial_state()` shows that everything is set to zero. – Python_Learner Feb 22 '22 at 20:14
  • So the state was reset when calling model.predict? Looking at this: https://stackoverflow.com/questions/39196945/in-keras-when-does-lstm-state-reset-in-the-call-to-model-predict I am not sure if there is a way to prevent that. What happens if you try `model(x)` instead of `model.predict(x)`? – AloneTogether Feb 22 '22 at 20:21
  • The internal states are different before and after both calls, i.e. `model(test_data)` and `model.predict(test_data)`. – Python_Learner Feb 22 '22 at 20:26
  • Yeah that’s something hard to change I think.. – AloneTogether Feb 22 '22 at 20:28
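
To make the behavior discussed in these comments concrete, here is a minimal check (an illustration assuming the model, lstm_layer, and x from the answer above) showing that a prediction call updates the state variables:

# Snapshot the current state values, run a prediction, and compare.
before = [s.numpy() for s in lstm_layer.states]
model.predict(x)
after = [s.numpy() for s in lstm_layer.states]
print(any((b != a).any() for b, a in zip(before, after)))  # True: the prediction changed the states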

I have managed to save the internal state of the LSTM after the training step and to reinitialize the LSTM with the saved internal state before the next training step. You can create a new variable and assign it the value currently stored in the state variable, as described in How can I copy a variable in tensorflow:

import numpy as np
import tensorflow as tf

states_ = {}

# Save the hidden state: copy lstm_layer.states[0] into a fresh variable.
# The shape (1, 150) corresponds to (batch_size, units) of the model above.
internal_state_h = lstm_layer.states[0]
v1 = tf.Variable(initial_value=np.zeros((1, 150)), dtype=tf.float32, shape=(1, 150))
copy_state_h = v1.assign(internal_state_h)

# Save the cell state: copy lstm_layer.states[1] the same way.
internal_state_c = lstm_layer.states[1]
v2 = tf.Variable(initial_value=np.zeros((1, 150)), dtype=tf.float32, shape=(1, 150))
copy_state_c = v2.assign(internal_state_c)

# Create a tuple of the copies and add it to the dictionary.
states_[0] = (copy_state_h, copy_state_c)

# Reset the internal state from the saved copies.
lstm_layer.reset_states(states_[0])

A prediction call changes the internal states; however, by following these steps you can restore the internal states of the RNN to what they were before the prediction.
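
Putting it together, here is a minimal sketch of the full cycle the question asks for (train, save the states, predict, restore the states, continue training). It assumes the model, lstm_layer, x, and y from the first answer; snapshot_states is a hypothetical helper name, not a Keras API:

def snapshot_states(layer):
    # Copy the current values of the layer's state variables [h, c] as numpy arrays.
    return [s.numpy() for s in layer.states]

for step in range(3):
    model.fit(x, y, epochs=1, verbose=0)   # training step
    saved = snapshot_states(lstm_layer)    # save the states reached after training
    model.predict(x)                       # the prediction changes the states...
    lstm_layer.reset_states(saved)         # ...restore them before the next training step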