0

Following code produces error in the last few lines only. Please see the last lines of code and tell the solution of the error which is related to some tensor issure

from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset

load custom datasets

dataset = load_dataset('csv', data_files={
    'train': ['train.csv'],
    'eval': ['eval.csv']},
    cache_dir="./data/"
)

Load a SetFit model from Hub

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    cache_dir="./models/"
)

Create trainer

trainer = SetFitTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['eval'],
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
    column_mapping={"text": "text", "label": "label"}  # Map dataset columns to text/label expected by trainer
)

Train and evaluate

trainer.train()
metrics = trainer.evaluate()

save

trainer.model._save_pretrained(save_directory="./output/")

from setfit import SetFitModel

model = SetFitModel.from_pretrained("./output/", local_files_only=True)

sentiment_dict = {"negative": 0, "positive": 1}
inverse_dict = {value: key for (key, value) in sentiment_dict.items()}

Run inference

text_list = [
    "i loved the spiderman movie!",
    "pineapple on pizza is the worst",
    "what the fuck is this piece",
    "good morning, lady boss",
    "the product is excellent",
    "a piece of rubbish"
]

preds = model(text_list)

'''for i in range(len(text_list)):
    print(text_list[i])
    print(inverse_dict[preds[i]])
    print('\n')'''

The error comes in the following way.

i loved the spiderman movie!
---------------------------------------------------------------------------
**KeyError                                  Traceback (most recent call last)
<ipython-input-14-bf6d34450e7a> in <module>
      2 for i in range(len(text_list)):
      3     print(text_list[i])
----> 4     print(inverse_dict[preds[i]])
      5     print('\n')
KeyError: tensor(1)**
'''

0 Answers0