Background
I'm working with a fine-tuned mBART-50 model that I need to speed up for inference, because running the Hugging Face model as-is is fairly slow on my current hardware. I wanted to use TorchScript because I couldn't get ONNX to export this particular model; it seems support for it will come at a later time (I would be glad to be proven wrong).
Converting the Transformer to a PyTorch trace:
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
# Model data
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# torchscript=True makes the model return plain tuples instead of ModelOutput objects, as tracing requires
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", torchscript=True)
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
tokenizer.src_lang = 'en_XX'
dummy = "To celebrate World Oceans Day, we're swimming through a shoal of jack fish just off the coast of Baja, California, in Cabo Pulmo National Park. This Mexican marine park in the Sea of Cortez is home to the northernmost and oldest coral reef on the west coast of North America, estimated to be about 20,000 years old. Jacks are clearly plentiful here, but divers and snorkelers in Cabo Pulmo can also come across many other species of fish and marine mammals, including several varieties of sharks, whales, dolphins, tortoises, and manta rays."
# Target-language code forced as the first generated token
model.config.forced_bos_token_id = 250006
# Pad to a fixed length so the traced graph always sees the same input shape
myTokenBatch = tokenizer(dummy, max_length=192, padding='max_length', truncation=True, return_tensors="pt")
# Trace the forward pass with (input_ids, attention_mask) as example inputs, then save it
traced = torch.jit.trace(model, [myTokenBatch.input_ids, myTokenBatch.attention_mask])
torch.jit.save(traced, "././traced-model/mbart-many.pt")
Inference Step:
import torch
from transformers import MBart50TokenizerFast
# Model data
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
MAX_LENGTH = 192
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
tokenizer.src_lang = 'en_XX'
model = torch.jit.load('././traced-model/mbart-many.pt')
model.to(device)
model.eval()
dummy = "To celebrate World Oceans Day, we're swimming through a shoal of jack fish just off the coast of Baja, California, in Cabo Pulmo National Park. This Mexican marine park in the Sea of Cortez is home to the northernmost and oldest coral reef on the west coast of North America, estimated to be about 20,000 years old. Jacks are clearly plentiful here, but divers and snorkelers in Cabo Pulmo can also come across many other species of fish and marine mammals, including several varieties of sharks, whales, dolphins, tortoises, and manta rays."
myTokenBatch = tokenizer(dummy, max_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors="pt")
# The inputs must live on the same device as the model
input_ids = myTokenBatch.input_ids.to(device)
attention_mask = myTokenBatch.attention_mask.to(device)
encode, pool, norm = model(input_ids, attention_mask)
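Because the model was exported with `torchscript=True`, the traced module returns plain (possibly nested) tuples rather than a Hugging Face output object, and unpacking the result into three names assumes it has exactly three entries. One way to see what each entry actually is before unpacking is a small recursive helper like the sketch below. `describe` is a hypothetical helper, not part of PyTorch or transformers; it only assumes that tensor-like objects expose a `.shape` attribute, as `torch.Tensor` does.

```python
def describe(obj, indent=0):
    """Return lines describing a nested tuple/list of tensor-like objects.

    Tuples and lists are expanded recursively; anything with a `.shape`
    attribute (e.g. a torch.Tensor) is reported by its type name and shape;
    everything else is shown with repr().
    """
    pad = "  " * indent
    if isinstance(obj, (tuple, list)):
        lines = [f"{pad}{type(obj).__name__} of length {len(obj)}"]
        for item in obj:
            lines.extend(describe(item, indent + 1))
        return lines
    if hasattr(obj, "shape"):
        return [f"{pad}{type(obj).__name__} shape={tuple(obj.shape)}"]
    return [f"{pad}{type(obj).__name__}: {obj!r}"]
```

For example, `print("\n".join(describe(model(myTokenBatch.input_ids, myTokenBatch.attention_mask))))` (with the inputs on the model's device) would list the type and shape of every returned element, which makes it easier to tell vocabulary-sized score tensors apart from hidden states.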
Expected Encoding Output:
These are token IDs that can be decoded back to words with MBart50TokenizerFast.
tensor([[250004, 717, 176016, 6661, 55609, 7, 10013, 4, 642,
25, 107, 192298, 8305, 10, 15756, 289, 111, 121477,
67155, 1660, 5773, 70, 184085, 111, 118191, 4, 39897,
4, 23, 143740, 21694, 432, 9907, 5227, 5, 3293,
181815, 122084, 9201, 23, 70, 27414, 111, 48892, 169,
83, 5368, 47, 70, 144477, 9022, 840, 18, 136,
10332, 525, 184518, 456, 4240, 98, 70, 65272, 184085,
111, 23924, 21629, 4, 25902, 3674, 47, 186, 1672,
6, 91578, 5369, 10332, 5, 21763, 7, 621, 123019,
32328, 118, 7844, 3688, 4, 1284, 41767, 136, 120379,
2590, 1314, 23, 143740, 21694, 432, 831, 2843, 1380,
36880, 5941, 3789, 114149, 111, 67155, 136, 122084, 21968,
8080, 4, 26719, 40368, 285, 68794, 111, 54524, 1224,
4, 148, 50742, 7, 4, 13111, 19379, 1779, 4,
43807, 125216, 7, 4, 136, 332, 102, 62656, 7,
5, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1]])
Actual Output:
This is what print(encode) gives, and I don't know what it is:
(tensor([[[[-9.3383e-02, -2.0395e-01, 4.8226e-03, ..., 1.8068e+00,
1.1528e-01, 7.0406e-02],
[-4.4630e-02, -2.2453e-01, 9.5264e-02, ..., 1.6921e+00,
1.4607e-01, 4.8238e-02],
[-7.8206e-01, 1.2699e-01, 1.6467e+00, ..., -1.7057e+00,
8.7768e-01, 8.2230e-01],
...,
[-1.2145e-02, -2.1855e-03, -6.0966e-03, ..., 2.9296e-02,
2.2141e-03, 3.2074e-02],
[-1.4671e-02, -2.8995e-03, -5.8610e-03, ..., 2.8525e-02,
2.4620e-03, 3.1593e-02],
[-1.5877e-02, -3.5165e-03, -4.8743e-03, ..., 2.8930e-02,
2.9877e-03, 3.3892e-02]]]], grad_fn=<CopyBackwards>))
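Token IDs like those in the expected output are not what a single forward pass returns: a forward pass of a sequence-to-sequence model yields a score (logit) for every vocabulary entry at each decoder position, and `generate()` runs the model repeatedly and chooses the next token from those scores (greedily or with beam search, depending on its settings). The sketch below shows plain greedy decoding in pure Python under that assumption. `forward_fn`, `greedy_decode` and `argmax` are hypothetical names; `forward_fn` stands in for a model call that takes the source IDs plus the decoder IDs produced so far. The trace above was recorded with only `input_ids` and `attention_mask`, so the traced module does not take decoder IDs as an input, and driving a loop like this would likely require exporting a module that does.

```python
from typing import Callable, List, Sequence

# Hypothetical signature: (source ids, decoder ids so far) -> one row of
# vocabulary scores (logits) per decoder position.
ForwardFn = Callable[[Sequence[int], Sequence[int]], List[List[float]]]


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score in one row of logits (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def greedy_decode(forward_fn: ForwardFn, input_ids: Sequence[int],
                  start_ids: Sequence[int], eos_id: int,
                  max_length: int) -> List[int]:
    """Greedy decoding: repeatedly append the highest-scoring next token.

    Stops when eos_id is produced or the sequence (including start_ids)
    reaches max_length.
    """
    decoder_ids = list(start_ids)
    while len(decoder_ids) < max_length:
        logits = forward_fn(input_ids, decoder_ids)
        # The last row holds the scores for the token after the current prefix
        next_id = argmax(logits[-1])
        decoder_ids.append(next_id)
        if next_id == eos_id:
            break
    return decoder_ids
```

For mBART, `start_ids` would come from the model configuration (`model.config.decoder_start_token_id`, followed by the `forced_bos_token_id` set during export), and the resulting IDs could be turned back into text with `tokenizer.decode(..., skip_special_tokens=True)`.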