This is exactly how the BERT model was trained: mask some random words in a sentence and have the network predict them. So yes, it is feasible. And no, it is not necessary to have the list of suggested words as a training input. However, the suggested words should be part of the overall vocabulary with which this BERT has been trained.
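For instance, you can quickly check whether your suggested words exist as single tokens in that vocabulary. A minimal sketch, assuming the same bert-base-uncased tokenizer that is loaded in the code below (the sample words are just illustrations; anything missing from the vocabulary is split into word pieces by the tokenizer):

from pytorch_pretrained_bert import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
for word in ['umbrella', 'worried', 'skateboarding']:
    # tokenizer.vocab maps each known token to its id
    print(word, word in tokenizer.vocab)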
I adapted this answer to show how the completion function may work.
# install this package to obtain the pretrained model
# ! pip install -U pytorch-pretrained-bert
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval(); # turning off the dropout
def fill_the_gaps(text):
    # wrap the input in the special tokens BERT expects
    text = '[CLS] ' + text + ' [SEP]'
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # a single sentence, so every token belongs to segment 0
    segments_ids = [0] * len(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)
    # collect the most probable token for each masked position
    results = []
    for i, t in enumerate(tokenized_text):
        if t == '[MASK]':
            predicted_index = torch.argmax(predictions[0, i]).item()
            predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
            results.append(predicted_token)
    return results
print(fill_the_gaps(text = 'I bought an [MASK] because its rainy .'))
print(fill_the_gaps(text = 'Im sad because you are [MASK] .'))
print(fill_the_gaps(text = 'Im worried because you are [MASK] .'))
print(fill_the_gaps(text = 'Im [MASK] because you are [MASK] .'))
The [MASK] symbol indicates the missing words (there can be any number of them). [CLS] and [SEP] are BERT-specific special tokens. The outputs for these particular prints are
['umbrella']
['here']
['worried']
['here', 'here']
The duplication is not surprising: transformer NNs are generally good at copying words. And from a semantic point of view, these symmetric continuations do indeed look very likely.
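If you actually have a fixed list of suggested words for a gap, you do not need the argmax over the whole vocabulary: you can read off the logits of just those candidates at the [MASK] position and rank them. A minimal sketch, reusing the tokenizer and model loaded above and assuming each candidate is a single in-vocabulary token (the candidate lists here are my own examples):

def rank_candidates(text, candidates):
    text = '[CLS] ' + text + ' [SEP]'
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [0] * len(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)
    mask_position = tokenized_text.index('[MASK]')  # first masked slot
    candidate_ids = tokenizer.convert_tokens_to_ids(candidates)
    scores = predictions[0, mask_position, candidate_ids]  # logits of the candidates only
    # higher logit = more plausible completion according to BERT
    return sorted(zip(candidates, scores.tolist()), key=lambda pair: -pair[1])

print(rank_candidates('I bought an [MASK] because its rainy .', ['umbrella', 'apple', 'elephant']))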
Moreover, if the missing word is not a random one somewhere in the middle, but exactly the last word (or the last several words), you can utilize any language model (e.g. another famous SOTA language model, GPT-2) to complete the sentence.
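For completeness, here is a minimal next-word sketch with the GPT-2 classes shipped in the same pytorch-pretrained-bert package; the 'gpt2' model name and the example sentence are my own assumptions, and a real completion would typically sample several tokens rather than only the single most likely one:

import torch
from pytorch_pretrained_bert import GPT2Tokenizer, GPT2LMHeadModel

gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')
gpt2_model.eval()

indexed_tokens = gpt2_tokenizer.encode('I bought an umbrella because its')
tokens_tensor = torch.tensor([indexed_tokens])
with torch.no_grad():
    predictions, _ = gpt2_model(tokens_tensor)  # logits for every position
predicted_index = torch.argmax(predictions[0, -1, :]).item()  # most likely next token
print(gpt2_tokenizer.decode([predicted_index]))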