Edit: please see the update at the bottom.
First off, I see that this error comes up when I search for it, and I can understand why it happens in those cases. But for my specific scenario, I am not sure how to verify what exactly is causing it.
I am running the calculate_bleu function from https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb:
from torchtext.data.metrics import bleu_score

def calculate_bleu(data, src_field, trg_field, model, device, max_len = 50):
    trgs = []
    pred_trgs = []
    for datum in data:
        src = vars(datum)['src']
        trg = vars(datum)['trg']
        pred_trg, _ = translate_sentence(src, src_field, trg_field, model, device, max_len)
        #cut off <eos> token
        pred_trg = pred_trg[:-1]
        pred_trgs.append(pred_trg)
        trgs.append([trg])
    #print(pred_trgs)
    #print(trgs)
    return bleu_score(pred_trgs, trgs)
And I call the function as:
bleu_score = calculate_bleu(test_data, SRC, TRG, model, device)
test_data is the BERT-tokenized text, and this is the other function it calls to translate each source-language sentence into the target language:
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 100):
    model.eval()
    if isinstance(sentence, str):
        tokens = [token for token in bert_tokenizer_de.tokenize(sentence)]
    else:
        tokens = [token for token in sentence]
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    for i in range(max_len):
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)
        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    return trg_tokens[1:], attention
This is the error:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-92-0d45e63957ed> in <module>()
1 start = timer()
----> 2 bleu_score = calculate_bleu(test_data, SRC, TRG, model, device)
3 end = timer()
4
5 print(f'BLEU score = {bleu_score*100:.2f}')
6 frames
<ipython-input-91-ab0f3436cea9> in calculate_bleu(data, src_field, trg_field, model, device, max_len)
11 trg = vars(datum)['trg']
12
---> 13 pred_trg, _ = translate_sentence(src, src_field, trg_field, model, device, max_len)
14
15 #cut off <eos> token
<ipython-input-90-e02b424fde25> in translate_sentence(sentence, src_field, trg_field, model, device, max_len)
14
15 with torch.no_grad():
---> 16 enc_src = model.encoder(src_tensor, src_mask)
17
18 trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
<ipython-input-81-44116f7135fc> in forward(self, src, src_mask)
39 #pos = [batch size, src len]
40
---> 41 src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
42
43 #src = [batch size, src len, hid dim]
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
158 return F.embedding(
159 input, self.weight, self.padding_idx, self.max_norm,
--> 160 self.norm_type, self.scale_grad_by_freq, self.sparse)
161
162 def extra_repr(self) -> str:
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2041 # remove once script supports set_grad_enabled
2042 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2043 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
2044
2045
IndexError: index out of range in self
From the traceback, it looks like the error occurs here:
self.tok_embedding = nn.Embedding(input_dim, hid_dim)
So is there a mismatch with input_dim? I define input_dim as len(SRC.vocab).
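To narrow this down, here is a rough sanity check I could run over the test set. It assumes the encoder exposes tok_embedding and pos_embedding (those names appear in the traceback above) and uses nn.Embedding's num_embeddings attribute; an index at or beyond num_embeddings in either embedding would produce "index out of range in self":

# Sketch: flag any test example whose token indices exceed the token-embedding
# size or whose length exceeds the position-embedding size.
tok_emb = model.encoder.tok_embedding
pos_emb = model.encoder.pos_embedding

for i, datum in enumerate(test_data):
    tokens = [SRC.init_token] + vars(datum)['src'] + [SRC.eos_token]
    src_indexes = [SRC.vocab.stoi[t] for t in tokens]
    if max(src_indexes) >= tok_emb.num_embeddings or len(src_indexes) > pos_emb.num_embeddings:
        print(i,
              'max index:', max(src_indexes), '/', tok_emb.num_embeddings,
              'length:', len(src_indexes), '/', pos_emb.num_embeddings)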
Reference: Embedding in pytorch
Update:
I found out that one sentence in the test set is extremely long. Any idea how I can remove it from the test iterator?
src ['ich', 'still', '##te', 'meinen', 'h', '##unge', '##r', 'nach', 'elter', '##lichem', 'rat', 'mit', 'diesem', 'b', '##uch', 'u', '##ber', 'eine', 'schrift', '##steller', '-', 'und', 'mus', '##iker', '##familie', '.', '[', '"', 'br', '##iefe', 'der', 'famili', '##e', 'von', 'f', '##u', 'lei', '"', ']', 'ich', 'fand', 'mein', 'vor', '##bild', 'in', 'einer', 'unab', '##hang', '##igen', 'fra', '##u', ',', 'wahren', '##d', 'die', 'konf', '##uz', '##ianische', 'tradition', 'geh', '##ors', '##am', 'verlangt', '.', '[', '"', 'ja', '##ne', 'e', '##yr', '##e', '"', ']', 'von', 'diesem', 'b', '##uch', 'habe', 'ich', 'gelernt', 'effiz', '##ient', 'zu', 'sein', '.', '[', '"', 'im', 'du', '##tz', '##end', 'billiger', '"', ']', 'b', '##ucher', 'haben', 'mich', 'dazu', 'inspiriert', ',', 'im', 'aus', '##land', 'zu', 'studieren', '.']
trg ['i', 'satisfied', 'my', 'hunger', 'for', 'parental', 'advice', 'from', 'this', 'book', 'by', 'a', 'family', 'of', 'writers', 'and', 'musicians', '.', '[', '"', 'correspondence', 'in', 'the', 'family', 'of', 'f', '##ou', 'lei', '"', ']', 'i', 'found', 'my', 'role', 'model', 'of', 'an', 'independent', 'woman', 'when', 'con', '##fu', '##cian', 'tradition', 'requires', 'obedience', '.', '[', '"', 'jane', 'eyre', '"', ']', 'and', 'i', 'learned', 'to', 'be', 'efficient', 'from', 'this', 'book', '.', '[', '"', 'cheaper', 'by', 'the', 'dozen', '"', ']', 'and', 'i', 'was', 'inspired', 'to', 'study', 'abroad', 'after', 'reading', 'these', '.']
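One idea I am considering (a sketch only, not sure it is the cleanest way): filter over-long examples out of test_data before building the iterator. This assumes test_data is a torchtext Dataset whose examples are stored in test_data.examples, and that 100 matches the positional-embedding size the model was built with; the -2 leaves room for the <sos> and <eos> tokens:

# Assumed positional-embedding size (max_length in the notebook's Encoder/Decoder)
MAX_LEN = 100

# Keep only examples that fit within the positional embedding on both sides
test_data.examples = [
    ex for ex in test_data.examples
    if len(vars(ex)['src']) <= MAX_LEN - 2 and len(vars(ex)['trg']) <= MAX_LEN - 2
]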