I'm following this tutorial on training a causal language model from scratch. My dataset is a corpus of text:
```python
my_dataset = [
    "some_text... _112_ some_text... _113_ some_text... _114_ some_text...",
    "some_text... _1423_ some_text... _1424_ some_text... _1425_ some_text...",
    "some_text... _1111_ some_text... _1111_ some_text... _1111_ some_text...",
]
```
The issue is that my dataset contains a clear pattern of numbers in each text (either the numbers are consecutive or they repeat).
I would like to mask out the previously predicted numbers in this pattern as the model predicts the next tokens (note that they always have the pattern _X_, where X is a number, so I don't want to mask out just any previous number, only those that correspond to the pattern).
For example, given the first text, after the model predicts _112_, I'd like to mask the number 112 in that sequence for the subsequent token predictions (e.g., "some_text... _MaskToken_ some_text...").
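To make the intended behaviour concrete, here is a plain-string sketch of what I mean (the _MaskToken_ placeholder and the mask_previous_numbers helper are names I made up for illustration, not anything from the library):

```python
import re

# Context as the model would see it right after it has predicted "_112_".
context = "some_text... _112_ some_text..."

def mask_previous_numbers(context: str, mask_token: str = "_MaskToken_") -> str:
    """Replace every _X_ pattern already present in the context with a mask token."""
    return re.sub(r"_\d+_", mask_token, context)

print(mask_previous_numbers(context))
# -> "some_text... _MaskToken_ some_text..."
```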
I found this SO question that I believe asked something similar a couple of years ago, but it was left unanswered and the method it used was inefficient. From the tutorial I'm using, it seems like the DataCollatorForLanguageModeling collator might be the way to go about this:
"Besides stacking and padding batches, it also takes care of creating the language model labels — in causal language modeling the inputs serve as labels too (just shifted by one element), and this data collator creates them on the fly during training so we don’t need to duplicate the input_ids."
From this reddit post I understand that the DataCollatorForLanguageModeling will:
"Duplicate the training sentence. If the masking is performed every time a sequence is fed to the model, the model sees different versions of the same sentence with masks on different positions."
The tutorial also mentions:
"Shifting the inputs and labels to align them happens inside the model, so the data collator just copies the inputs to create the labels."
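For reference, that shifting happens in GPT2LMHeadModel.forward when labels are passed. Simplified with dummy tensors standing in for the model outputs (the exact source differs between transformers versions), it is roughly:

```python
import torch
import torch.nn as nn

# Dummy stand-ins: lm_logits is the model output, labels is the collator's copy of input_ids.
vocab_size, seq_len = 100, 8
lm_logits = torch.randn(1, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (1, seq_len))

# Shift so that tokens < n predict token n.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = nn.CrossEntropyLoss()  # positions labelled -100 are ignored by the loss
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss)
```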
But going over the source code of GPT2LMHeadModel and of the data collator, it is not clear to me how to do this either.