I've followed this word-level neural language model tutorial, and everything works right up until the last function, generate_seq(). Below is the code that sets up the function and the error message I get. I've looked up this error, but the solutions I've found don't solve this issue. Ultimately, I'm trying to see what additional words the model will produce given 50 words fed to it from the sequence text file. Does anyone know how to fix the `if index == yhat` part of the for loop?

from random import randint
from pickle import load
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences

# load doc into memory
def load_doc(filename):
    # open the file as read only
    file = open(filename, 'r')
    # read all text
    text = file.read()
    # close the file
    file.close()
    return text

# generate a sequence from a language model
def generate_seq(model, tokenizer, seq_length, seed_text, n_words):
    result = list()
    in_text = seed_text
    # generate a fixed number of words
    for _ in range(n_words):
        # encode the text as integer
        encoded = tokenizer.texts_to_sequences([in_text])[0]
        # truncate sequences to a fixed length
        encoded = pad_sequences([encoded], maxlen=seq_length, truncating='pre')
        # predict probabilities for each word
        yhat = model.predict_classes(encoded, verbose=0)
        # map predicted word index to word
        out_word = ''
        for word, index in tokenizer.word_index.items():
            if index == yhat:
                out_word = word
                break
        # append to input
        in_text += ' ' + out_word
        result.append(out_word)
    return ' '.join(result)

# load cleaned text sequences
in_filename = 'republic_sequences.txt'
doc = load_doc(in_filename)
lines = doc.split('\n')
seq_length = len(lines[0].split()) - 1

# load the model
model = load_model('model.h5')

# load the tokenizer
tokenizer = load(open('tokenizer.pkl', 'rb'))

# select a seed text
seed_text = lines[randint(0, len(lines) - 1)]  # randint is inclusive, so stop at the last valid index
print(seed_text + '\n')

# generate new text
generated = generate_seq(model, tokenizer, seq_length, seed_text, 50)
print(generated)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-28-7889eaf7d4eb> in <module>
     49 
     50 # generate new text
---> 51 generated = generate_seq(model, tokenizer, seq_length, seed_text, 50)
     52 print(generated)
     53 

<ipython-input-28-7889eaf7d4eb> in generate_seq(model, tokenizer, seq_length, seed_text, n_words)
     24                 out_word = ''
     25                 for word, index in tokenizer.word_index.items():
---> 26                         if index == yhat:
     27                                 out_word = word
     28                                 break

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
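For readers hitting the same message, here is a minimal NumPy-only reproduction. It assumes (the traceback does not show this) that `yhat` ends up as an array with more than one element, for example a row of per-word probabilities. `index` is a plain Python `int` from `tokenizer.word_index`; comparing it with a one-element array gives a one-element boolean array, which `if` accepts, while comparing it with a longer array gives several booleans, and `if` refuses to pick one.

```python
import numpy as np

index = 3  # a plain int, like the values stored in tokenizer.word_index

# One-element array (what a single-sample class prediction looks like):
one_hit = index == np.array([3])
print(one_hit)        # [ True]
print(bool(one_hit))  # True: a single-element array has an unambiguous truth value

# Multi-element array (e.g. a row of per-word probabilities):
many = index == np.array([0.1, 0.7, 0.2])
print(many)           # [False False False]

try:
    if many:  # same situation as `if index == yhat:` in generate_seq()
        pass
except ValueError as err:
    print(err)  # the same "truth value ... is ambiguous" message as above
```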
jr599
  • Welcome to Stack Overflow. In your own words, where the code says `if index == yhat`, **what type** do you expect `index` and `yhat` to have? What should `==` do with them, and why? – Karl Knechtel Oct 19 '22 at 21:09
  • Hey Karl - I honestly don't understand yet what's expected in those arrays. Have you tried following the tutorial from the link I provided? – jr599 Oct 19 '22 at 21:18
  • "I honestly don't understand yet what's expected in those arrays." Then please learn Numpy first, before trying to learn Tensorflow or Keras. – Karl Knechtel Oct 19 '22 at 21:19

1 Answer

When you compare an array with another array (or with a scalar), the result is also an array, one of boolean values. For example:

np.ones((2,)) == np.ones((2,))

will return:

array([ True,  True])

If you want to check whether all values in that boolean array are true (meaning the arrays contain equal values), you have to use all:

all(np.ones((2,))==np.ones((2,)))

Output:

True

So in your case, since you are comparing arrays, you should do:

if all(index == yhat):
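One caveat that may explain why this suggestion did not help: Python's built-in `all()` iterates over the first axis of an array. That works for a 1-D comparison, but if `yhat` is 2-D (for example shape `(1, vocab_size)`, which is an assumption about the asker's setup), each item `all()` sees is itself a row array, and converting that row to `bool` raises the same `ValueError`. NumPy's own `np.all()` (or the `.all()` / `.any()` methods) reduces over every element instead. The sketch uses a small made-up array in place of real model output.

```python
import numpy as np

index = 2
yhat = np.array([[0.1, 0.2, 0.7]])  # hypothetical 2-D output, shape (1, 3)

comparison = index == yhat  # boolean array, shape (1, 3)

try:
    # Built-in all() receives one item: the row array([False, False, False]),
    # whose truth value is ambiguous, so this raises ValueError.
    all(comparison)
except ValueError as err:
    print("built-in all():", err)

print("np.all():", np.all(comparison))  # np.all(): False
print(".any():", comparison.any())      # .any(): False
```

Even with `np.all()`, comparing a word index against probabilities is not meaningful, so the underlying fix is to turn `yhat` into a single integer index before the comparison.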
AndrzejO
  • Thanks AndrzejO! I tried your suggestion above but I still get the same error as shown in my initial message :( Have you tried running the code with the text as per the webpage I provided? Any other suggestions you may have? – jr599 Oct 19 '22 at 21:17
  • Before the `if`, print out both `index` and `yhat`. If they are arrays, print out their shapes. – AndrzejO Oct 19 '22 at 21:21
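Following that advice, here is a sketch of what a possible fix could look like, assuming `yhat` comes from `model.predict(...)`, which returns a 2-D array of per-word probabilities with shape `(1, vocab_size)`. (`predict_classes` was deprecated and later removed from TensorFlow's Keras; its deprecation message pointed to `np.argmax(model.predict(x), axis=-1)` as the replacement.) Taking `np.argmax` over the last axis yields the most likely word index as a single integer, which can then be looked up in a reversed `word_index` instead of being compared with `==` inside a loop. The helper name `index_to_word` and the toy vocabulary are hypothetical stand-ins for the tokenizer and model from the question.

```python
import numpy as np


def index_to_word(yhat, word_index):
    """Hypothetical helper: map a model output to a word.

    Accepts either a probability array of shape (1, vocab_size), as from
    model.predict, or an index array of shape (1,), as from the old
    predict_classes. Returns '' if the index is not in word_index.
    """
    yhat = np.asarray(yhat)
    if yhat.ndim >= 2:
        # Probabilities: pick the most likely class along the last axis.
        predicted = int(np.argmax(yhat, axis=-1)[0])
    else:
        # Already a class index: unwrap the single element to a plain int.
        predicted = int(yhat.reshape(-1)[0])

    # Reverse lookup (index -> word) replaces the `if index == yhat` scan.
    reverse = {i: w for w, i in word_index.items()}
    return reverse.get(predicted, '')


# Toy vocabulary standing in for tokenizer.word_index (hypothetical data).
# Keras Tokenizer indices start at 1, so column 0 of the probabilities is padding.
word_index = {'the': 1, 'republic': 2, 'justice': 3}

probs = np.array([[0.0, 0.1, 0.2, 0.7]])
print(probs.shape)                               # (1, 4)
print(index_to_word(probs, word_index))          # justice
print(index_to_word(np.array([2]), word_index))  # republic
```

Inside `generate_seq()`, the inner `for word, index in ...` loop could then be replaced by something like `out_word = index_to_word(model.predict(encoded, verbose=0), tokenizer.word_index)`, again assuming the model's output has the shape described above.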