I'm using TensorFlow with Keras to train a char-RNN on Google Colab. I train my model for 10 epochs and save it with 'model.save()', as shown in the documentation for saving models. Immediately afterwards I load it again as a check and call model.fit() on the loaded model with the exact same training set, and I get a "Dimensions must be equal" error. The training data is a TensorFlow dataset organised in batches, as shown in the documentation for tf datasets. Here is a minimal working example:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
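# Toy "text": 10,000 random character ids in [0, 50)
X = np.random.randint(0,50,(10000))
seq_len = 150
batch_size = 20
dataset = tf.data.Dataset.from_tensor_slices(X)
# Cut the stream into windows of seq_len+1 ids
dataset = dataset.batch(seq_len+1,drop_remainder=True)
# Input is the window minus its last id, target is the window shifted by one
dataset = dataset.map(lambda x: (x[:-1],x[1:]))
dataset = dataset.shuffle(20).batch(batch_size,drop_remainder=True)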
def make_model(vocabulary_size,embedding_dimension,rnn_units,batch_size,stateful):
    model = Sequential()
    model.add(Embedding(vocabulary_size,embedding_dimension,
                        batch_input_shape=[batch_size,None]))
    model.add(LSTM(rnn_units,return_sequences=True,stateful=stateful))
    model.add(Dense(vocabulary_size))
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer='adam',metrics=['accuracy'])
    model.summary()
    return model
vocab_size = 51
emb_dim = 20
rnn_units = 10
model = make_model(vocab_size,emb_dim,rnn_units,batch_size,False)
model.fit(dataset,epochs=10)
model.save('/content/test_model')
model2 = tf.keras.models.load_model('/content/test_model')
model2.fit(dataset,epochs=10)
The first training call, model.fit(dataset,epochs=10), runs fine, but the last line, model2.fit(dataset,epochs=10), raises this error:
ValueError: Dimensions must be equal, but are 20 and 150 for '{{node Equal}} = Equal[T=DT_INT64, incompatible_shape_error=true](ArgMax, ArgMax_1)' with input shapes: [20], [20,150].
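For what it's worth, the two shapes in the message, [20] and [20,150], can be reproduced with NumPy alone if I assume that the loaded model computes accuracy by taking an argmax over the last axis of both the labels and the logits (what a non-sparse categorical accuracy would do), instead of comparing the integer labels directly. This is only a guess based on the shapes, not something I have confirmed in the TensorFlow source, and the array names below are made up for illustration:

```python
import numpy as np

batch_size, seq_len, vocab_size = 20, 150, 51

# Stand-ins for one batch: integer targets and the model's logits.
y_true = np.random.randint(0, 50, (batch_size, seq_len))
y_pred = np.random.rand(batch_size, seq_len, vocab_size)

# Assumed (hypothetical) behaviour: argmax over the last axis of both tensors.
true_ids = np.argmax(y_true, axis=-1)   # collapses the sequence axis
pred_ids = np.argmax(y_pred, axis=-1)   # collapses the vocabulary axis
print(true_ids.shape, pred_ids.shape)

# What a sparse accuracy would compare instead: the labels used as-is.
sparse_acc = np.mean(y_true == pred_ids)
print(sparse_acc)
```

If that reading is right, the first comparison pairs an array with one entry per batch element against one with an entry per batch element and time step, which matches the Equal node in the error, while the second comparison is well defined.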
I want to be able to resume training later, as my real dataset is much larger. Therefore, saving only the weights is not an ideal option.
Any advice? Thanks!