
I am experimenting with the BERT architecture and found that most fine-tuning tasks take the final hidden layer as the text representation and then pass it to other models for further downstream tasks.

BERT's last layer looks like this:

[Figure: BERT's final hidden layer]

We take the [CLS] token of each sentence:

[Figure: the final hidden state of the [CLS] token taken for each sentence]
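As a rough illustration, "taking the [CLS] token" amounts to slicing position 0 out of the last layer. The sketch below assumes the final hidden states arrive as a NumPy array of shape (batch_size, seq_len, hidden_size) with [CLS] first. Random numbers stand in for real BERT output, and the function name cls_vector is hypothetical.

```python
import numpy as np

# Stand-in for BERT's final hidden layer: 2 sentences, 6 tokens each,
# hidden size 768 (as in BERT-base). Real values would come from the model.
rng = np.random.default_rng(0)
batch_size, seq_len, hidden_size = 2, 6, 768
last_hidden_state = rng.standard_normal((batch_size, seq_len, hidden_size))

def cls_vector(hidden_states: np.ndarray) -> np.ndarray:
    """Return the hidden state at position 0, where the [CLS] token sits."""
    return hidden_states[:, 0, :]

sentence_vectors = cls_vector(last_hidden_state)
print(sentence_vectors.shape)  # (2, 768)
```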

I went through many discussions on this (a Hugging Face issue, a Data Science forum question, a GitHub issue). Most data scientists give this explanation:

BERT is bidirectional, the [CLS] is encoded including all representative information of all tokens through the multi-layer encoding procedure. The representation of [CLS] is individual in different sentences.

My question is: why did the authors ignore the other information (each token's vector) and use the [CLS] token for classification, rather than taking the average, max pooling, or some other method that makes use of all of it?

How does the [CLS] token help compared to the average of all token vectors?
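For comparison, here is a minimal sketch of the three options the question mentions, applied to the same stand-in hidden states. It assumes a 0/1 attention mask that marks padding, so that mean and max pooling skip padded positions. Whether [CLS] and [SEP] belong in the average is a design choice; here they are counted as ordinary tokens. The helper names are hypothetical, not a library API.

```python
import numpy as np

# Stand-in hidden states: 2 sentences padded to 5 tokens, hidden size 4.
rng = np.random.default_rng(1)
hidden = rng.standard_normal((2, 5, 4))
# 1 = real token ([CLS], words, [SEP]), 0 = padding.
mask = np.array([[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0]])

def cls_pool(h):
    # Vector of the [CLS] token at position 0.
    return h[:, 0, :]

def mean_pool(h, m):
    # Average over real tokens only; padding positions get weight 0.
    m = m[:, :, None].astype(h.dtype)
    return (h * m).sum(axis=1) / m.sum(axis=1)

def max_pool(h, m):
    # Element-wise max over real tokens; padding is set to -inf so it never wins.
    masked = np.where(m[:, :, None] == 1, h, -np.inf)
    return masked.max(axis=1)

cls_vecs = cls_pool(hidden)
mean_vecs = mean_pool(hidden, mask)
max_vecs = max_pool(hidden, mask)
print(cls_vecs.shape, mean_vecs.shape, max_vecs.shape)  # (2, 4) (2, 4) (2, 4)
```

All three produce one fixed-size vector per sentence; they differ only in which token states contribute and how they are weighted.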

desertnaut
Aaditya Ura

2 Answers

The use of the [CLS] token to represent the entire sentence comes from the original BERT paper, section 3:

The first token of every sequence is always a special classification token ([CLS]). The final hidden state corresponding to this token is used as the aggregate sequence representation for classification tasks.
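To make "used as the aggregate sequence representation for classification" concrete, here is a minimal sketch assuming the simplest fine-tuning head: one linear layer from the [CLS] vector C to K label scores, trained with cross-entropy. Weights and data are random placeholders; in practice C comes from BERT and the head is trained together with BERT during fine-tuning.

```python
import numpy as np

# Hypothetical sizes: batch of 3 sentences, hidden size 768, 2 labels.
rng = np.random.default_rng(2)
batch_size, hidden_size, num_labels = 3, 768, 2
C = rng.standard_normal((batch_size, hidden_size))          # [CLS] final hidden states
W = rng.standard_normal((num_labels, hidden_size)) * 0.02   # classifier weights (learned)
b = np.zeros(num_labels)
labels = np.array([0, 1, 1])

def log_softmax(x):
    # Numerically stable log-softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

logits = C @ W.T + b                                       # (batch_size, num_labels)
log_probs = log_softmax(logits)
loss = -log_probs[np.arange(batch_size), labels].mean()    # cross-entropy loss
```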

Your intuition is correct that averaging the vectors of all the tokens may produce superior results. In fact, that is exactly what is mentioned in the Hugging Face documentation for BertModel:

Returns

pooler_output (torch.FloatTensor: of shape (batch_size, hidden_size)):

Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pre-training.

This output is usually not a good summary of the semantic content of the input, you’re often better with averaging or pooling the sequence of hidden-states for the whole input sequence.
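The pooler_output described in that documentation can be sketched as follows, assuming a hidden_size × hidden_size Linear layer followed by Tanh, applied to the first token's hidden state. The weights here are random stand-ins, not the pre-trained next-sentence-prediction weights. The last line shows the plain mean over tokens that the docs recommended instead, which is valid here only because the toy batch has no padding.

```python
import numpy as np

rng = np.random.default_rng(3)
batch_size, seq_len, hidden_size = 2, 6, 768
last_hidden_state = rng.standard_normal((batch_size, seq_len, hidden_size))
# Random stand-ins for the pooler's Linear layer; a real checkpoint ships
# weights pre-trained with the next sentence prediction objective.
W_pool = rng.standard_normal((hidden_size, hidden_size)) * 0.02
b_pool = np.zeros(hidden_size)

def pooler(hidden_states):
    first_token = hidden_states[:, 0, :]               # [CLS] hidden state
    return np.tanh(first_token @ W_pool.T + b_pool)    # Linear + Tanh

pooler_output = pooler(last_hidden_state)      # (batch_size, hidden_size)
mean_pooled = last_hidden_state.mean(axis=1)   # averaging alternative (no padding here)
```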

Update: Hugging Face removed that statement ("This output is usually not a good summary of the semantic content ...") in v3.1.0. You'll have to ask them why.

stackoverflowuser2010
  • Maybe a lot of experiments proved that statement wrong? – avocado Aug 20 '21 at 22:14
  • One dumb question about the [CLS] token: since every input sequence uses the same [CLS] token as its first token, the same embedding vector is shared by all input sequences, right? So how can we use the final hidden state of this first token for a later classification task? Since the input embedding of the [CLS] token is shared across all sequences, how much difference can the final hidden state of the first token represent? – avocado Aug 20 '21 at 22:56
  • The embeddings in BERT and other contextual language models are not static. The embedding for CLS (that is, the actual 768 floating-point values) will differ depending on the input sequence because it's computed using attention (i.e. a weighted average) over all the input token embeddings. – stackoverflowuser2010 Aug 21 '21 at 00:50
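The point in this reply can be checked with a toy example: one self-attention layer in which two sequences share an identical [CLS] input vector but differ elsewhere. This is a deliberate simplification assumed for illustration (a single head, random weights, and no position embeddings, layer norm, or feed-forward sublayer); it only shows that the output at position 0 depends on the other tokens.

```python
import numpy as np

rng = np.random.default_rng(4)
d = 8  # toy hidden size
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
cls_embedding = rng.standard_normal(d)  # same [CLS] input vector for every sequence

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X):
    # Single-head scaled dot-product attention: each output row is a
    # weighted average of the value vectors of all tokens in X.
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    weights = softmax(Q @ K.T / np.sqrt(d))
    return weights @ V

# Two sequences: identical [CLS] at position 0, different remaining tokens.
seq_a = np.vstack([cls_embedding, rng.standard_normal((4, d))])
seq_b = np.vstack([cls_embedding, rng.standard_normal((4, d))])
cls_out_a = self_attention(seq_a)[0]
cls_out_b = self_attention(seq_b)[0]
```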

BERT is designed primarily for transfer learning, i.e., fine-tuning on task-specific datasets. If you average the states, every state gets the same weight, including stop words and other tokens that are not relevant to the task. The [CLS] vector is computed using self-attention (like everything in BERT), so it can learn to collect only the relevant information from the rest of the hidden states. So, in some sense, the [CLS] vector is also an average over token vectors, only computed more cleverly and specifically for the tasks you fine-tune on.
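The "average, only computed more cleverly" view can be sketched with a single attention query over token vectors. This is a toy, hypothetical pooling function, not how BERT actually computes [CLS] (which goes through many layers and heads). With all scores equal the weights are uniform and it reduces to plain averaging, while a query that matches one token shifts most of the weight onto it. Token vectors are normalized to unit length so the matching token is guaranteed to score highest.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_pool(H, q):
    # Weighted average of token vectors H (seq_len, d); the weights come
    # from how well each token matches the query vector q.
    weights = softmax(H @ q / np.sqrt(H.shape[1]))
    return weights, weights @ H

rng = np.random.default_rng(5)
d = 16
H = rng.standard_normal((6, d))
H = H / np.linalg.norm(H, axis=1, keepdims=True)  # unit-length toy token vectors

# Zero query: every score is equal, so this is exactly mean pooling.
w_uniform, pooled_uniform = attention_pool(H, np.zeros(d))

# Query aligned with token 2 (a hypothetical "relevant" token): it gets the
# largest weight and the other tokens are down-weighted.
w_focused, pooled_focused = attention_pool(H, 40 * H[2])
```

Fine-tuning, in this picture, is what learns where the query should point.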

Also, my experience is that when I keep the weights fixed and do not fine-tune BERT, using the token average yields better results.

Jindřich