I am trying to modify the code in https://keras.io/examples/nlp/lstm_seq2seq/ so that it uses GRU instead of LSTM. I have managed to get it to train properly, and I have constructed the encoder-decoder model for inference using code from "Implementing Seq2Seq with GRU in Keras". However, I get the error below when running the decode_sequence function:
Cell In[19], line 30, in decode_sequence(input_seq)
28 decoded_sentence = ""
29 while not stop_condition:
---> 30 output_tokens, h = decoder_model.predict([target_seq] + states_value, verbose=0)
32 # Sample a token
33 sampled_token_index = np.argmax(output_tokens[0, -1, :])
ValueError: operands could not be broadcast together with shapes (1,1,1,84) (1,2048)
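While debugging, I noticed the two shapes in the error seem to match my setup: I believe (1, 1, 1, 84) comes from num_decoder_tokens = 84 and (1, 2048) from latent_dim = 2048. I can reproduce the same ValueError in plain NumPy whenever states_value is a bare array rather than a list of arrays; this is my current suspicion, not a confirmed diagnosis:

import numpy as np

# Stand-ins with the shapes from the traceback (assumed to correspond to
# num_decoder_tokens = 84 and latent_dim = 2048 in my setup).
target_seq = np.zeros((1, 1, 84))
states_value = np.zeros((1, 2048))  # a bare ndarray, NOT a list of arrays

# Because states_value is an ndarray, "+" triggers NumPy broadcasting
# instead of list concatenation: np.array([target_seq]) has shape
# (1, 1, 1, 84), which cannot broadcast against (1, 2048).
combined = [target_seq] + states_value  # raises the same ValueError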
Here is the LSTM and GRU code used to build the model for the fitting process (the LSTM parts are commented out with ###):
# Define an input sequence and process it.
encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
# LSTM
###encoder = keras.layers.LSTM(latent_dim, return_state=True)
###encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
###encoder_states = [state_h, state_c]
# GRU
encoder = keras.layers.GRU(latent_dim, return_state=True)
outputs = encoder(encoder_inputs)
encoder_output, encoder_states = outputs[0], outputs[1:]
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
# LSTM
###decoder_lstm = keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)
###decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
# GRU
decoder = keras.layers.GRU(latent_dim, return_sequences=True, return_state=True)
outputs = decoder(decoder_inputs, initial_state=tuple(encoder_states))
decoder_outputs, decoder_state = outputs[0], outputs[1:]
decoder_dense = keras.layers.Dense(num_decoder_tokens, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
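As a sanity check on the return values, here is a minimal standalone sketch (toy shapes, nothing from my dataset) of how I understand return_state to behave: LSTM returns two state tensors while GRU returns one, so a model built on the LSTM states has two outputs and one built on the GRU state has only one:

import numpy as np
from tensorflow import keras

x = keras.Input(shape=(None, 8))
# LSTM with return_state=True yields (outputs, state_h, state_c)...
_, state_h, state_c = keras.layers.LSTM(16, return_state=True)(x)
# ...while GRU yields only (outputs, state).
_, gru_state = keras.layers.GRU(16, return_state=True)(x)

lstm_states_model = keras.Model(x, [state_h, state_c])
gru_state_model = keras.Model(x, gru_state)

sample = np.zeros((1, 5, 8), dtype="float32")
print(type(lstm_states_model.predict(sample, verbose=0)))  # list of two arrays
print(type(gru_state_model.predict(sample, verbose=0)))    # a single ndarray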
Here is the LSTM code for the inference process:
### LSTM
# Define sampling models
# Restore the model and construct the encoder and decoder.
encoder_model = keras.Model(encoder_inputs, encoder_states)
decoder_state_input_h = keras.Input(shape=(latent_dim,))
decoder_state_input_c = keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq, verbose=0)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index["\t"]] = 1.0

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ""
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value, verbose=0)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if sampled_char == "\n" or len(decoded_sentence) > max_decoder_seq_length:
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0

        # Update states
        states_value = [h, c]
    return decoded_sentence
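This LSTM version works for me, and I think that is because encoder_model has two outputs, so predict returns a genuine Python list and [target_seq] + states_value is list concatenation. A small self-contained sketch of what I mean, with hypothetical stand-in arrays:

import numpy as np

# Stand-ins for what the LSTM encoder_model.predict returns: a Python
# list of the two state arrays h and c.
h = np.zeros((1, 2048))
c = np.zeros((1, 2048))
states_value = [h, c]

target_seq = np.zeros((1, 1, 84))
model_inputs = [target_seq] + states_value  # list concatenation
print(len(model_inputs))  # 3, matching decoder_model's three inputs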
And here is the code for the GRU inference. The line that throws the error is marked with a comment:
### GRU
# Define sampling models
# Restore the model and construct the encoder and decoder.
encoder_model = keras.Model(encoder_inputs, encoder_states)
decoder_states_inputs = keras.Input(shape=(latent_dim,))
decoder_outputs, decoder_states = decoder(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = keras.Model([decoder_outputs] + [decoder_states])
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq, verbose=0)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index["\t"]] = 1.0

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ""
    while not stop_condition:
        output_tokens, h = decoder_model.predict([target_seq] + states_value, verbose=0)  # Error is thrown here

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if sampled_char == "\n" or len(decoded_sentence) > max_decoder_seq_length:
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0

        # Update states
        states_value = [h]
    return decoded_sentence
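For reference, this is how I expected the GRU inference models to mirror the LSTM pattern; it is a sketch of my intent rather than verified code, and decoder_state_input is my own name:

# Sketch: GRU inference models built like the LSTM ones, but with a
# single state tensor instead of [h, c].
encoder_model = keras.Model(encoder_inputs, encoder_states)

decoder_state_input = keras.Input(shape=(latent_dim,))
decoder_outputs, decoder_state = decoder(
    decoder_inputs, initial_state=decoder_state_input)
decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = keras.Model(
    [decoder_inputs, decoder_state_input],  # inputs
    [decoder_outputs, decoder_state])       # outputs

If that construction is right, I assume the sampling loop would also need states_value kept as a list (for example states_value = [encoder_model.predict(input_seq, verbose=0)], since a single-output encoder model seems to return a bare array), but I have not been able to verify this.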