Following code produces error in the last few lines only. Please see the last lines of code and tell the solution of the error which is related to some tensor issure
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset
load custom datasets
dataset = load_dataset('csv', data_files={
'train': ['train.csv'],
'eval': ['eval.csv']},
cache_dir="./data/"
)
Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
cache_dir="./models/"
)
Create trainer
trainer = SetFitTrainer(
model=model,
train_dataset=dataset['train'],
eval_dataset=dataset['eval'],
loss_class=CosineSimilarityLoss,
metric="accuracy",
batch_size=16,
num_iterations=20, # The number of text pairs to generate for contrastive learning
num_epochs=1, # The number of epochs to use for contrastive learning
column_mapping={"text": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)
Train and evaluate
trainer.train()
metrics = trainer.evaluate()
save
trainer.model._save_pretrained(save_directory="./output/")
from setfit import SetFitModel
model = SetFitModel.from_pretrained("./output/", local_files_only=True)
sentiment_dict = {"negative": 0, "positive": 1}
inverse_dict = {value: key for (key, value) in sentiment_dict.items()}
Run inference
text_list = [
"i loved the spiderman movie!",
"pineapple on pizza is the worst",
"what the fuck is this piece",
"good morning, lady boss",
"the product is excellent",
"a piece of rubbish"
]
preds = model(text_list)
'''for i in range(len(text_list)):
print(text_list[i])
print(inverse_dict[preds[i]])
print('\n')'''
The error comes in the following way.
i loved the spiderman movie!
---------------------------------------------------------------------------
**KeyError Traceback (most recent call last)
<ipython-input-14-bf6d34450e7a> in <module>
2 for i in range(len(text_list)):
3 print(text_list[i])
----> 4 print(inverse_dict[preds[i]])
5 print('\n')
KeyError: tensor(1)**
'''