
Google's BERT is pretrained on next sentence prediction tasks, but I'm wondering if it's possible to call the next sentence prediction function on new data.

The idea is: given a sentence A and a sentence B, I want a probabilistic label for whether or not sentence B follows sentence A. Since BERT is pretrained on a huge amount of data, I was hoping to apply its next sentence prediction to new sentence pairs. I can't figure out whether this next sentence prediction function can be called and, if so, how. Thanks for your help!

Paul



The answer by Aerin is outdated. The HuggingFace library (now called transformers) has changed a lot over the last couple of months. Here is an example of how to use the next sentence prediction (NSP) model and how to extract probabilities from it. NOTE: this only works well if the model you load comes with pretrained weights for the NSP head.

from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer


seq_A = 'I like cookies !'
seq_B = 'Do you like them ?'

# load pretrained model and a pretrained tokenizer
model = BertForNextSentencePrediction.from_pretrained('bert-base-cased')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# encode the two sequences as a single input pair: passing seq_B as 'text_pair'
# makes the tokenizer join them as [CLS] seq_A [SEP] seq_B [SEP]
encoded = tokenizer.encode_plus(seq_A, text_pair=seq_B, return_tensors='pt')
print(encoded)
# {'input_ids': tensor([[  101,   146,  1176, 18621,   106,   102,  2091,  1128,  1176,  1172, 136,   102]]),
#  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]),
#  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
# NOTE how the token_type_ids are 0 for all tokens in seq_A and 1 for seq_B, 
# this way the model knows which token belongs to which sequence

# the model's output is a tuple (a ModelOutput in newer versions, which can still
# be indexed like a tuple); its first item holds the NSP logits of shape [batch_size, 2]
seq_relationship_logits = model(**encoded)[0]

# we still need softmax to convert the logits into probabilities
# index 0: sequence B is a continuation of sequence A
# index 1: sequence B is a random sequence
probs = softmax(seq_relationship_logits, dim=1)

print(seq_relationship_logits)
print(probs)
# tensor([[9.9993e-01, 6.7607e-05]], grad_fn=<SoftmaxBackward>)
# very high value for index 0: high probability of seq_B being a continuation of seq_A
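
The two steps the example relies on, packing A and B into one input and turning the two logits into a probability, can be sketched without any library. This is a minimal illustration, not the tokenizer itself: it assumes the sequences are already WordPiece ids (the ids below are copied from the printed `encoded` output), takes 101/102 as the [CLS]/[SEP] ids from that same output, and ignores truncation, padding and batching. The names `pack_pair` and `is_next_probability` are hypothetical.

```python
import math

# [CLS] and [SEP] ids as they appear in the bert-base-cased output printed above
CLS_ID, SEP_ID = 101, 102


def pack_pair(ids_a, ids_b, cls_id=CLS_ID, sep_id=SEP_ID):
    """Hypothetical helper: lay out one pair as [CLS] A [SEP] B [SEP]."""
    input_ids = [cls_id] + ids_a + [sep_id] + ids_b + [sep_id]
    # segment 0 covers [CLS], A and the first [SEP]; segment 1 covers B and the final [SEP]
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    # a single unpadded pair, so every position is attended to
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask}


def is_next_probability(logits):
    """Softmax over the two NSP logits; index 0 means 'B follows A'."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    return exps[0] / sum(exps)


# ids for 'I like cookies !' and 'Do you like them ?' taken from the printed output
encoded = pack_pair([146, 1176, 18621, 106], [2091, 1128, 1176, 1172, 136])
print(encoded["token_type_ids"])  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
print(round(is_next_probability([2.0, 0.0]), 4))  # 0.8808
```

Only the probability at index 0 is needed for a "B follows A" score, because the two entries of the softmax always sum to 1.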
Bram Vanroy
  • In the Huggingface documentation (https://huggingface.co/transformers/model_doc/bert.html#bertfornextsentenceprediction) for `BertForNextSentencePrediction`, they provide an example where `model()` is given only one sentence ("Hello, my dog is cute"). Do you know what they give only one sentence instead of two? – stackoverflowuser2010 Apr 27 '20 at 04:18
  • Do you mean why? Probably to simplify the example a bit. You would not do this in training, where you want to use batches. – Bram Vanroy Apr 27 '20 at 09:02
  • Yes, I meant `why`. Thank you. For `BertForNextSentencePrediction`, at either training time or prediction time, isn't it the case that it always takes two sentences? – stackoverflowuser2010 Apr 27 '20 at 18:17
  • This seems to give high scores for almost any sentence in seq_B. E.g. I tried out `seq_B = 'blah blah blah'` and `seq_A` remains 'I like cookies !' ... the model still output a very high value at index 0 of the `probs` tensor: tensor([[0.9704, 0.0296]]) – AruniRC Jun 18 '20 at 20:00
  • hm, it might have changed. The example for `BertForNextSentencePrediction` has 2 sentences now. – NatalieL Aug 09 '20 at 18:14
  • @AruniRC This will only work well if the model is trained on this specific task and those weights are available. – Bram Vanroy May 23 '22 at 16:44
  • @BramVanroy wrt your last comment, shouldn't it be the case for `'bert-base-cased'` to be trained on NSP and to have weights available? – amiola Mar 01 '23 at 22:50
  • @amiola If I recall correctly, the weights of the NSP classification head are not available and were never made available. But I guess that is easy to test for yourself! – Bram Vanroy Mar 02 '23 at 08:44

Hugging Face has already implemented this: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L854

class BertForNextSentencePrediction(BertPreTrainedModel):
    """BERT model with next sentence prediction head.
    This module comprises the BERT model followed by the next sentence classification head.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `next_sentence_label`: labels for the next sentence classification loss: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1].
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.
    Outputs:
        if `next_sentence_label` is not `None`:
            Outputs the next sentence classification loss.
        if `next_sentence_label` is `None`:
            Outputs the next sentence classification logits of shape [batch_size, 2].
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForNextSentencePrediction(config)
    seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                     output_all_encoded_layers=False)
        seq_relationship_score = self.cls(pooled_output)

        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            return next_sentence_loss
        else:
            return seq_relationship_score
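
When `next_sentence_label` is given, `forward` returns `CrossEntropyLoss(ignore_index=-1)` over the `[batch_size, 2]` logits. Here is a stdlib-only sketch of that computation, assuming PyTorch's default `'mean'` reduction (averaging only over labels that are not ignored). The name `nsp_loss` is hypothetical and the logits are made-up illustration values, not model outputs.

```python
import math


def nsp_loss(logits, labels, ignore_index=-1):
    """Hypothetical sketch: mean cross-entropy over [batch_size, 2] NSP logits.

    Labels follow the docstring: 0 = B is the continuation, 1 = B is random.
    Entries equal to ignore_index are skipped, mirroring ignore_index=-1.
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]  # equals -log(softmax(row)[label])
        count += 1
    if count == 0:
        return float("nan")  # nothing left to average
    return total / count


# made-up logits: one confident "is next", one confident "random", one ignored row
logits = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
labels = [0, 1, -1]
print(round(nsp_loss(logits, labels), 4))  # 0.1269
```

The ignored third row does not affect the result, which is why the average is taken over `count` rather than over the full batch size.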
aerin