Is there any way to restrict the vocabulary of the decoder in a Huggingface BERT encoder-decoder model? I'd like to force the decoder to choose from a small vocabulary when generating text rather than BERT's entire ~30k vocabulary.
1 Answer
The `generate` method has a `bad_words_ids` argument where you can provide a list of token IDs that you do not want to appear in the output. A sketch of how this can be used to keep generation inside a small vocabulary is shown below.
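A minimal sketch of that idea: ban every token ID that is *not* in your small target vocabulary. The model names, the `allowed_words` list, and the input text are placeholders, not values from the question; note that building a ban list over almost the whole ~30k vocabulary works but is a fairly blunt and slow way to restrict decoding.

```python
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Hypothetical small target vocabulary -- replace with your own word list.
allowed_words = ["yes", "no", "maybe"]
allowed_ids = set(tokenizer.all_special_ids)  # keep [CLS]/[SEP]/[PAD] usable
for word in allowed_words:
    allowed_ids.update(tokenizer(word, add_special_tokens=False)["input_ids"])

# Ban every vocabulary ID that is not in the allowed set.
bad_words_ids = [[i] for i in range(len(tokenizer)) if i not in allowed_ids]

inputs = tokenizer("Some input text", return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    bad_words_ids=bad_words_ids,
    max_length=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```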
If you only want to decrease the probability of some tokens being generated, you can try manipulating the bias parameter of the model's output layer. If your model is based on BERT, the last layer of the decoder is a `BertLMPredictionHead`. Assuming your seq2seq model is in a variable `model`, you can access the bias via `model.decoder.cls.predictions.decoder.bias` and decrease it for the token IDs that you would like to appear with lower probability.
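A minimal sketch of that bias adjustment, assuming `model` and `tokenizer` are the BERT2BERT model and tokenizer from above; the token list and the -10.0 offset are purely illustrative.

```python
import torch

# Hypothetical tokens to down-weight and an illustrative offset of -10.0.
tokens_to_suppress = ["maybe", "perhaps"]
ids_to_suppress = tokenizer.convert_tokens_to_ids(tokens_to_suppress)

# The output-layer bias of the BERT decoder head, shape (vocab_size,).
bias = model.decoder.cls.predictions.decoder.bias
with torch.no_grad():
    bias[ids_to_suppress] -= 10.0  # lower the logits (and thus probabilities)
```

Unlike `bad_words_ids`, this is a soft penalty: the suppressed tokens can still be generated if their logits remain high enough.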
Note also that if you initialize the seq2seq model from BERT (as in the example in the Huggingface Transformers documentation), you will need to fine-tune the model heavily, because the cross-attention is initialized randomly and the decoder part needs to adapt to left-to-right generation.

- Thanks, that's helpful. Sounds like `bad_words_ids` might be the most appropriate for me. Yes, I am fine-tuning BERT! – Joseph Harvey Oct 07 '21 at 14:31
- @JosephHarvey did it work for you? I also need to limit the vocab of the decoder – Brans Jan 12 '23 at 10:35