I've followed this word level neural language model and everything works right up until the last generate_seq() function. Below is the code to set up the function and the error message I get. I've looked up this error but the solutions I'm seeing don't solve this issue. Ultimately I'm trying to see what additional words the model will product given 50 words fed to it from the sequence text file. Anyone know how to fix the index == yhat
part of for
loop?
from random import randint
from pickle import load
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences
# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text
# generate a sequence from a language model
def generate_seq(model, tokenizer, seq_length, seed_text, n_words):
result = list()
in_text = seed_text
# generate a fixed number of words
for _ in range(n_words):
# encode the text as integer
encoded = tokenizer.texts_to_sequences([in_text])[0]
# truncate sequences to a fixed length
encoded = pad_sequences([encoded], maxlen=seq_length, truncating='pre')
# predict probabilities for each word
yhat = model.predict_classes(encoded, verbose=0)
# map predicted word index to word
out_word = ''
for word, index in tokenizer.word_index.items():
if index == yhat:
out_word = word
break
# append to input
in_text += ' ' + out_word
result.append(out_word)
return ' '.join(result)
# load cleaned text sequences
in_filename = 'republic_sequences.txt'
doc = load_doc(in_filename)
lines = doc.split('\n')
seq_length = len(lines[0].split()) - 1
# load the model
model = load_model('model.h5')
# load the tokenizer
tokenizer = load(open('tokenizer.pkl', 'rb'))
# select a seed text
seed_text = lines[randint(0,len(lines))]
print(seed_text + '\n')
# generate new text
generated = generate_seq(model, tokenizer, seq_length, seed_text, 50)
print(generated)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-28-7889eaf7d4eb> in <module>
49
50 # generate new text
---> 51 generated = generate_seq(model, tokenizer, seq_length, seed_text, 50)
52 print(generated)
53
<ipython-input-28-7889eaf7d4eb> in generate_seq(model, tokenizer, seq_length, seed_text, n_words)
24 out_word = ''
25 for word, index in tokenizer.word_index.items():
---> 26 if index == yhat:
27 out_word = word
28 break
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()