1

I have created a custom dataset and trained on it a custom T5ForConditionalGeneration model that predicts solutions to quadratic equations like this:

Input: "4*x^2 + 4*x + 1" Output: D = 4 ^ 2 - 4 * 4 * 1 4 * 1 4 * 1 4 * 1 4 * 1 4

I need to get accuracy for this model but I get only loss when I use Trainer so I used a custom metric function (I didn't write it but took it from a similar project):

def compute_metrics4token(eval_pred):
    batch_size = 4
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Rouge expects a newline after each sentence
    decoded_preds =  ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels =  ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    answer_accuracy = []
    token_accuracy = []
    num_correct, num_total = 0, 0
    num_answer = 0
    number_eq = 0
    for p, l in zip(decoded_preds, decoded_labels):
        text_pred = p.split(' ')
        text_labels = l.split(' ')
        m = min(len(text_pred), len(text_labels))
        if np.array_equal(text_pred, text_labels):
            num_answer += 1
        for i, j in zip(text_pred, text_labels):
            if i == j:
                num_correct += 1
        num_total += len(text_labels)
        number_eq += 1
    token_accuracy = num_correct / num_total
    answer_accuracy = num_answer / number_eq
    result = {'token_acc': token_accuracy, 'answer_acc': answer_accuracy}
    result = {key: value for key, value in result.items()}
    for key, value in result.items():
        wandb.log({key: value})        
    return {k: round(v, 4) for k, v in result.items()}

Problem is that it doesn't work and I don't really understand why and what can I do to get accuracy for my model. I get this error when I use the function:

args = Seq2SeqTrainingArguments(
    output_dir='./',
    num_train_epochs=10,
    overwrite_output_dir = True,
    evaluation_strategy = 'steps',         
    learning_rate = 1e-4,                 
    logging_steps = 100,                    
    eval_steps = 100,                      
    save_steps = 100,
    load_best_model_at_end = True,
    push_to_hub=True, 
    weight_decay = 0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4
    )

trainer = Seq2SeqTrainer(model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, args=args, 
                  data_collator=data_collator, tokenizer=tokenizer, compute_metrics=compute_metrics4token)
<ipython-input-48-ff7980f6dd66> in compute_metrics4token(eval_pred)
      4     # predictions = np.argmax(logits[0])
      5     # print(predictions)
----> 6     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
      7     # Replace -100 in the labels as we can't decode them.
      8     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in batch_decode(self, sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3444             `List[str]`: The list of decoded sentences.
   3445         """
-> 3446         return [
   3447             self.decode(
   3448                 seq,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in <listcomp>(.0)
   3445         """
   3446         return [
-> 3447             self.decode(
   3448                 seq,
   3449                 skip_special_tokens=skip_special_tokens,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3484         token_ids = to_py_obj(token_ids)
   3485 
-> 3486         return self._decode(
   3487             token_ids=token_ids,
   3488             skip_special_tokens=skip_special_tokens,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_fast.py in _decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
    547         if isinstance(token_ids, int):
    548             token_ids = [token_ids]
--> 549         text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    550 
    551         clean_up_tokenization_spaces = (

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

When I print out predictions I get a tuple:

(array([[[-32.777344, -34.593437, -36.065685, ..., -34.78577 ,
         -34.77546 , -34.061115],
        [-58.633934, -32.23472 , -31.735909, ..., -40.335655,
         -40.28701 , -37.208904],
        [-56.650974, -33.564095, -34.409576, ..., -36.94467 ,
         -43.246735, -37.469246],
        ...,
        [-56.62741 , -24.561722, -34.11228 , ..., -35.34798 ,
         -42.287125, -38.889412],
        [-56.632545, -24.470266, -34.0792  , ..., -35.313175,
         -42.235626, -38.891712],
        [-56.687027, -24.391508, -34.12526 , ..., -35.30828 ,
         -42.204193, -38.88395 ]],

       [[-29.79866 , -32.22621 , -32.689865, ..., -32.106445,
         -31.46681 , -31.706667],
        [-62.101192, -33.327423, -30.900173, ..., -38.046883,
         -42.26345 , -38.97748 ],
        [-54.726807, -29.13115 , -30.294558, ..., -28.370876,
         -41.23722 , -37.91609 ],
        ...,
        [-57.279373, -23.954525, -34.066246, ..., -35.047447,
         -41.599922, -38.489853],
        [-57.31298 , -23.879845, -34.0837  , ..., -35.03614 ,
         -41.557755, -38.530064],
        [-57.39132 , -23.831306, -34.120094, ..., -35.039547,
         -41.525337, -38.55728 ]],

       [[-29.858566, -32.452713, -34.05892 , ..., -33.93065 ,
         -32.109177, -32.874695],
        [-61.375793, -33.656853, -32.95248 , ..., -42.28087 ,
         -42.637173, -39.21142 ],
        [-58.43721 , -32.496166, -36.44046 , ..., -39.33864 ,
         -42.139664, -38.695328],
        ...,
        [-59.654663, -24.117435, -34.266438, ..., -35.734142,
         -40.55384 , -38.467537],
        [-38.54418 , -18.533113, -29.775307, ..., -26.856483,
         -33.07976 , -29.934727],
        [-27.716005, -14.610603, -23.752686, ..., -21.140053,
         -26.855148, -24.429493]],

       ...,

       [[-33.252697, -34.72487 , -36.395184, ..., -36.87368 ,
         -35.207897, -34.468285],
        [-59.911736, -32.730076, -32.622803, ..., -43.382267,
         -42.25615 , -38.35135 ],
        [-54.982887, -31.847572, -32.773827, ..., -38.500675,
         -43.97969 , -37.41088 ],
        ...,
        [-56.896988, -23.213766, -34.04734 , ..., -35.88832 ,
         -42.176086, -38.953568],
        [-56.994152, -23.141619, -34.054848, ..., -35.875816,
         -42.176453, -38.97729 ],
        [-57.076714, -23.05831 , -34.048904, ..., -35.888298,
         -42.165287, -39.020435]],

       [[-30.070187, -32.049232, -34.63928 , ..., -35.02118 ,
         -32.14465 , -32.891876],
        [-61.720093, -32.994057, -32.988144, ..., -42.054638,
         -42.18583 , -38.990112],
        [-57.74364 , -31.431454, -35.969643, ..., -38.593002,
         -42.276768, -38.895355],
        ...,
        [-58.677704, -23.567434, -35.6751  , ..., -36.018696,
         -40.343582, -38.681267],
        [-58.682228, -23.563087, -35.668964, ..., -36.019753,
         -40.336178, -38.67661 ],
        [-58.718002, -23.609531, -35.67758 , ..., -36.001644,
         -40.366055, -38.67864 ]],

       [[-30.320919, -33.430378, -34.84311 , ..., -37.259563,
         -32.59662 , -33.03912 ],
        [-61.275875, -34.824192, -34.07767 , ..., -44.637024,
         -41.718002, -38.974827],
        [-54.49349 , -30.689342, -35.539658, ..., -39.984665,
         -39.87059 , -37.038437],
        ...,
        [-58.939384, -23.831846, -34.525368, ..., -35.930893,
         -40.29633 , -37.637936],
        [-58.95117 , -23.824234, -34.520042, ..., -35.931396,
         -40.297188, -37.636852],
        [-58.966076, -23.795956, -34.519627, ..., -35.901787,
         -40.261116, -37.612514]]], dtype=float32), array([[[-1.43104442e-03, -2.98473001e-01,  9.49775204e-02, ...,
         -1.77978892e-02,  1.79805323e-01,  1.33578405e-01],
        [-2.35560730e-01,  1.53045550e-01,  5.15255742e-02, ...,
         -1.57466665e-01,  3.49459350e-01,  7.28092641e-02],
        [ 1.60562042e-02, -1.40354022e-01,  5.29232398e-02, ...,
         -2.38162443e-01, -7.72500336e-02,  6.80136457e-02],
        ...,
        [ 7.33550191e-02, -3.35853845e-01,  2.25579832e-03, ...,
         -1.93636306e-02,  1.08121082e-01,  5.24416938e-02],
        [ 8.32231194e-02, -3.11688155e-01, -2.13681534e-02, ...,
          3.23344418e-03,  1.08062990e-01,  7.20862746e-02],
        [ 9.58326831e-02, -3.00361574e-01, -3.02627794e-02, ...,
          3.01265554e-03,  1.20107472e-01,  9.56629887e-02]],

       [[-1.16950013e-01, -3.43173921e-01,  1.87818244e-01, ...,
         -2.71256089e-01,  7.42092952e-02,  5.77520356e-02],
        [-1.62564963e-01, -3.87467295e-01,  1.71134964e-01, ...,
         -7.83916116e-02, -3.65173034e-02,  2.08234787e-01],
        [-3.71523261e-01, -8.74521434e-02,  1.39187068e-01, ...,
         -3.08779895e-01,  3.88156146e-01,  9.99216512e-02],
        ...,
        [ 2.14628279e-02, -3.35561454e-01, -3.76663893e-03, ...,
         -1.29795140e-02,  1.44181430e-01,  1.15508482e-01],
        [ 3.47745977e-02, -3.30934107e-01,  1.10013550e-02, ...,
         -1.84394475e-02,  1.52143195e-01,  1.38157398e-01],
        [ 3.02720107e-02, -3.37626845e-01,  1.35379741e-02, ...,
         -3.80427912e-02,  1.50906458e-01,  1.38765752e-01]],

       [[-6.50129542e-02, -2.63762653e-01,  2.16862872e-01, ...,
         -1.66922837e-01,  1.09285273e-01, -6.40013069e-02],
        [-5.23199737e-01, -2.32228413e-01,  1.44963071e-01, ...,
         -1.41557693e-01,  1.90811172e-01, -2.22496167e-01],
        [-2.24985227e-01, -3.69372189e-01,  7.32450858e-02, ...,
          6.57786876e-02,  9.70033705e-02,  7.83021152e-02],
        ...,
        [-1.93579309e-03, -3.92921537e-01, -1.28203649e-02, ...,
         -8.74079913e-02,  1.13596492e-01,  9.25250202e-02],
        [ 4.55581211e-03, -3.65802884e-01, -2.60831695e-02, ...,
         -4.12549600e-02,  1.17429778e-01,  1.05997331e-01],
        [ 2.46201381e-02, -3.47863257e-01, -4.48134281e-02, ...,
         -2.53352951e-02,  1.16753690e-01,  1.36296600e-01]],

       ...,

       [[-6.47678748e-02, -3.45555365e-01,  7.19114989e-02, ...,
         -9.16809738e-02,  2.15520635e-01,  1.01671875e-01],
        [-7.61077851e-02, -1.51827012e-03,  9.52102616e-02, ...,
         -1.39335945e-01,  1.05894208e-01,  3.23191588e-03],
        [-3.24888170e-01, -2.17741728e-03,  5.32661797e-03, ...,
         -2.78430730e-01,  3.59415114e-01,  1.19439401e-01],
        ...,
        [ 6.89201057e-02, -3.63149673e-01,  7.96841756e-02, ...,
         -3.25191446e-04,  1.26513481e-01,  1.36511743e-01],
        [ 8.16355348e-02, -3.54205281e-01,  7.69739375e-02, ...,
         -2.90949806e-03,  1.31863236e-01,  1.56503588e-01],
        [ 8.36645439e-02, -3.38536322e-01,  8.00612345e-02, ...,
         -9.39210225e-03,  1.29102767e-01,  1.64855778e-01]],

       [[-1.63163885e-01, -3.34902078e-01,  1.11728966e-01, ...,
         -1.10363133e-01,  1.19786285e-01, -9.18702483e-02],
        [-3.36889774e-01, -3.34888607e-01,  1.30680993e-01, ...,
          1.22191897e-03,  1.45059675e-01, -1.27688542e-01],
        [-5.92090450e-02, -2.07585752e-01,  2.05589265e-01, ...,
         -6.80094585e-02,  2.11224273e-01,  3.92790437e-01],
        ...,
        [ 4.86238785e-02, -4.19503808e-01, -3.39424387e-02, ...,
         -1.76134892e-02,  1.00283481e-01,  1.38210282e-01],
        [ 5.81516996e-02, -4.04477298e-01, -4.19086292e-02, ...,
         -1.02474755e-02,  1.06062084e-01,  1.59754634e-01],
        [ 6.70261905e-02, -3.86263877e-01, -4.19785343e-02, ...,
          9.05385148e-03,  1.01594023e-01,  1.69663757e-01]],

       [[-1.22184128e-01, -3.67584258e-01,  3.60302597e-01, ...,
         -4.39502299e-02,  1.33717149e-01,  1.53699834e-02],
        [-3.37780178e-01, -4.05100137e-01,  2.02614054e-01, ...,
         -5.41410968e-02,  1.55447468e-01, -9.28792357e-02],
        [ 1.81227952e-01, -2.29236633e-01,  2.40814224e-01, ...,
          1.39913429e-02,  7.61386827e-02,  3.62152725e-01],
        ...,
        [ 1.47830993e-02, -4.26465064e-01, -1.54972840e-02, ...,
          3.74358669e-02,  1.52016997e-01,  1.53155088e-01],
        [ 3.46656404e-02, -4.00052220e-01, -3.53843644e-02, ...,
          2.64652576e-02,  1.62517026e-01,  1.66649833e-01],
        [ 4.50411513e-02, -3.61773074e-01, -5.50217964e-02, ...,
          3.68298292e-02,  1.67936400e-01,  1.76781893e-01]]],
      dtype=float32))

I thought that maybe I need to take argmax from these values but then I still get errors.

If something is unclear I would be happy to provide additional information. Thanks for any help.

EDIT:

I am adding an example of an item in the dataset:

dataset['test'][0:5]

{'text': ['3*x^2 + 9*x + 6 = 0',
'59*x^2 + -59*x + 14 = 0',
'-10*x^2 + 0*x + 0 = 0',
'3*x^2 + 63*x + 330 = 0',
'1*x^2 + -25*x + 156 = 0'],
'label': ['D = 9^2 - 4 * 3 * 6 = 9; x1 = (-9 + (9)**0.5) // (2 * 3) 
= -1.0; x2 = (-9 - (9)**0.5) // (2 * 3) = -2.0',
'D = -59^2 - 4 * 59 * 14 = 177; x1 = (59 + (177)**0.5) // (2 * 59) 
= 0.0; x2 = (59 - (177)**0.5) // (2 * 59) = 0.0',
'D = 0^2 - 4 * -10 * 0 = 0; x = 0^2 // (2 * -10) = 0',
'D = 63^2 - 4 * 3 * 330 = 9; x1 = (-63 + (9)**0.5) // (2 * 3) = 
-10.0; x2 = (-63 - (9)**0.5) // (2 * 3) = -11.0',
'D = -25^2 - 4 * 1 * 156 = 1; x1 = (25 + (1)**0.5) // (2 * 1) = 
13.0; x2 = (25 - (1)**0.5) // (2 * 1) = 12.0'],
'__index_level_0__': [10803, 14170, 25757, 73733, 25059]}
alvas
  • 115,346
  • 109
  • 446
  • 738
ALiCe P.
  • 231
  • 1
  • 10
  • Is the `compute_metrics` code you're using just calculating accuracy? If so, maybe try https://huggingface.co/spaces/evaluate-metric/seqeval? – alvas May 08 '23 at 13:42
  • Take a look at https://stackoverflow.com/a/75717951/610569 – alvas May 08 '23 at 13:43
  • @alvas I still get the same problem when `argmax` is taken: ` in compute_metrics2(p) 6 def compute_metrics2(p): 7 predictions, labels = p ----> 8 predictions = predictions.argmax(axis=2) 9 # Remove ignored index (special tokens) 10 true_predictions = [ AttributeError: 'tuple' object has no attribute 'argmax'` – ALiCe P. May 08 '23 at 14:23
  • @alvas could you explain a bit more what should I use as `label_list`? – ALiCe P. May 08 '23 at 14:27
  • Quick question, could you give ~5 data samples? That'll be easier to help you with some examples read into the code. – alvas May 08 '23 at 14:50
  • @alvas I edited the question, added a sample of a dataset – ALiCe P. May 08 '23 at 15:39
  • Thanks, take a look at https://www.kaggle.com/alvations/how-to-train-a-t5-seq2seq-model-using-custom-data – alvas May 08 '23 at 19:19
  • Let me know if the answer helps! BTW, where did you find the code to `compute_metrics4token`? It might be easier to ask the blogpost author or code maintainer to correct it there if you want custom metrics computation. – alvas May 08 '23 at 21:15
  • 1
    It seems like I didn't have `predict_with_generate=True` in my `TrainingArguments` which was giving errors with model output. – ALiCe P. May 09 '23 at 08:14

1 Answers1

1

It seems like the task you're trying to achieve is some sort of "translation" task so the most appropriate model is to use the AutoModelForSeq2SeqLM.

And in the case of unspecified sequence, it might be more appropriate to use

  • BLEU / ChrF or newer neural-based metrics for translation
  • ROUGE for summarization

You can take a look at various translation-related metrics on https://www.kaggle.com/code/alvations/huggingface-evaluate-for-mt-evaluations


Treating it as a normal Machine Translation task

To read the data, you'll have to make sure that the model's forward function

  • sees the data point as {"text": [0, 1, 2, ... ], "labels": [0, 9, 8, ...]} in your datasets.Dataset object
  • use the collator to do batch, e.g. DataCollatorForSeq2Seq

And here's a working snippet of how the code (in parts) can be ran: https://www.kaggle.com/alvations/how-to-train-a-t5-seq2seq-model-using-custom-data

Data processing part.

from datasets import Dataset
import evaluate
from transformers import AutoModelForSeq2SeqLM, Trainer, AutoTokenizer, DataCollatorForSeq2Seq

math_data = {'text': ['3*x^2 + 9*x + 6 = 0',
  '59*x^2 + -59*x + 14 = 0',
  '-10*x^2 + 0*x + 0 = 0',
  '3*x^2 + 63*x + 330 = 0',
  '1*x^2 + -25*x + 156 = 0'],
 'target': ['D = 9^2 - 4 * 3 * 6 = 9; x1 = (-9 + (9)**0.5) // (2 * 3)  = -1.0; x2 = (-9 - (9)**0.5) // (2 * 3) = -2.0',
  'D = -59^2 - 4 * 59 * 14 = 177; x1 = (59 + (177)**0.5) // (2 * 59)  = 0.0; x2 = (59 - (177)**0.5) // (2 * 59) = 0.0',
  'D = 0^2 - 4 * -10 * 0 = 0; x = 0^2 // (2 * -10) = 0',
  'D = 63^2 - 4 * 3 * 330 = 9; x1 = (-63 + (9)**0.5) // (2 * 3) =  -10.0; x2 = (-63 - (9)**0.5) // (2 * 3) = -11.0',
  'D = -25^2 - 4 * 1 * 156 = 1; x1 = (25 + (1)**0.5) // (2 * 1) =  13.0; x2 = (25 - (1)**0.5) // (2 * 1) = 12.0']}

math_data_eval = {'text': ["10 + 9x(x+3y) - 3x^3"], "target": ["10 + 9x^2 + 27xy - 3x^3"]}

ds_train = Dataset.from_dict(math_data)

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
data_collator = DataCollatorForSeq2Seq(tokenizer)
ds_train = ds_train.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=512)
)
ds_train = ds_train.map(lambda y: 
    {"labels": tokenizer(y["target"], truncation=True, padding="max_length", max_length=512)['input_ids']}
)

ds_eval = Dataset.from_dict(math_data_eval)
ds_eval = ds_eval.map(lambda x: tokenizer(x["text"], 
    truncation=True, padding="max_length", max_length=512))
ds_eval = ds_eval.map(lambda y: 
    {"labels": tokenizer(y["target"], truncation=True, padding="max_length", max_length=512)['input_ids']}
)

Metric definition part.

import numpy as np

metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Replace -100s used for padding as we can't decode them
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

Trainer setup part.

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments


# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=2,  # set to 1000 for full training
    save_steps=16,    # set to 500 for full training
    eval_steps=4,     # set to 8000 for full training
    warmup_steps=1,   # set to 2000 for full training
    max_steps=16,     # delete for full training
    # overwrite_output_dir=True,
    save_total_limit=1,
    #fp16=True, 
)


# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=ds_train.with_format("torch"),
    eval_dataset=ds_eval.with_format("torch"),
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

That works and good. But why is the output of the model still so bad?

  • Most probably you need to tune some hyperparameter, batch_size, more data, different learning rates or increase no. of max_steps
  • It can also be that your vocab is pretrained for natural language but your data isn't, in that case, I'll suggest to try modifying the tokenizer before training, e.g. How to add new tokens to an existing Huggingface tokenizer?
alvas
  • 115,346
  • 109
  • 446
  • 738