I want to generate poems based on Robert Frost's Poems. I have preprocessed my dataset:
max_sentence_len = max(len(l) for l in corpus_int)
input_seq = np.array(tf.keras.preprocessing.sequence.pad_sequences(corpus_int,padding = 'pre',truncating = 'pre',maxlen = max_sentence_len))
predictors, label = input_seq[:,:-1],input_seq[:,-1]#predictors everything except last, label only last
label = ku.to_categorical(label, num_classes=total_words,dtype='int32')
predictors
array([[ 0, 0, 0, ..., 10, 5, 544],
[ 0, 0, 0, ..., 64, 8, 854],
[ 0, 0, 0, ..., 855, 174, 2],
...,
[ 0, 0, 0, ..., 129, 49, 94],
[ 0, 0, 0, ..., 183, 159, 60],
[ 0, 0, 3, ..., 3, 2157, 4]], dtype=int32)
label
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 1]], dtype=int32)
After that i have built my model using encoder - decoder arhcitecture:
class seq2seq(tf.keras.Model):
def __init__(self,max_sequence_len,total_words):
super(seq2seq,self).__init__()
self.max_sequence_len = max_sequence_len
self.total_words = total_words
self.input_len = self.max_sequence_len - 1
self.total_words = self.total_words
#Encoder
self.enc_embedding = tf.keras.layers.Embedding(input_dim = total_words,output_dim = 300,input_length = max_sentence_len - 1)
self.enc_lstm_1 = tf.keras.layers.LSTM(units = 300, activation = 'tanh')
self.enc_lstm_2 = tf.keras.layers.LSTM(units = 300, activation = 'tanh', return_state = True)
#decoder
self.dec_embedding = tf.keras.layers.Embedding(input_dim = total_words,output_dim = 300,input_length = max_sentence_len - 1)
self.dec_lstm_1 = tf.keras.layers.LSTM(units = 300, activation = 'tanh')
self.dec_lstm_2 = tf.keras.layers.LSTM(units = 300, activation = 'tanh', return_state = True,return_sequences = True)
#Dense layer and output:
self.dense = tf.keras.layers.Dense(total_words, activation='softmax')
def call(self,inputs):
#Encoding
enc_x = self.enc_embedding(inputs)
enc_x = self.enc_lstm_1(enc_x)
enc_outputs, state_h, state_c = self.enc_lstm_2(enc_x)
#Decoding:
dec_x = self.dec_embedding(enc_outputs)
dec_x = self.dec_lstm_1(dec_x,initial_state = [state_h, state_c])
dec_outputs, _, _ = self.enc_lstm_2(dec_x)
output_dense = self.dense(dec_outputs)
return output_dense
model = seq2seq(max_sequence_len = max_sentence_len,total_words = total_words)
model.compile(optimizer = tf.keras.optimizers.RMSprop(lr=0.0001),loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(predictors,label,epochs=5, batch_size=128)
But at the end I get the following error:
ValueError Traceback (most recent call last)
<ipython-input-4-1c349573302d> in <module>()
37 model = seq2seq(max_sequence_len = max_sentence_len,total_words = total_words)
38 model.compile(optimizer = tf.keras.optimizers.RMSprop(lr=0.0001),loss='categorical_crossentropy', metrics=['accuracy'])
---> 39 model.fit(predictors,label,epochs=5, batch_size=128)
8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/impl/api.py in wrapper(*args, **kwargs)
235 except Exception as e: # pylint:disable=broad-except
236 if hasattr(e, 'ag_error_metadata'):
--> 237 raise e.ag_error_metadata.to_exception(e)
238 else:
239 raise
ValueError: in converted code:
<ipython-input-4-1c349573302d>:27 call *
enc_outputs, state_h, state_c = self.enc_lstm_2(enc_x)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/recurrent.py:623 __call__
return super(RNN, self).__call__(inputs, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py:812 __call__
self.name)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/input_spec.py:177 assert_input_compatibility
str(x.shape.as_list()))
ValueError: Input 0 of layer lstm_1 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 300]
I understand, that the problem is in the input shape(as it was answered in the post expected ndim=3, found ndim=2).
But I don't know how should I reshape my data for the tensorflow 2.0. Can you help me with the problem?