
Given a zero-shot classification task via Hugging Face, as follows:

from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

example_text = "This is an example text about snowflakes in the summer"
labels = ["weather", "sports", "computer industry"]
        
output = classifier(example_text, labels, multi_label=True)
output 
{'sequence': 'This is an example text about snowflakes in the summer',
'labels': ['weather', 'sports'],
'scores': [0.9780895709991455, 0.021910419687628746]}

I am trying to extract the SHAP values to generate a text-based explanation of the prediction, as shown here: SHAP for Transformers

I already tried the following, based on the URL above:

import shap
from transformers import AutoModelForSequenceClassification, AutoTokenizer, ZeroShotClassificationPipeline

model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

pipe = ZeroShotClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)

def score_and_visualize(text):
    prediction = pipe([text])
    print(prediction[0])

    explainer = shap.Explainer(pipe)
    shap_values = explainer([text])

    shap.plots.text(shap_values)

score_and_visualize(example_text)

Any suggestions? Thanks for your help in advance!

As an alternative to the pipeline above, the following also works:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, ZeroShotClassificationPipeline

model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

classifier = ZeroShotClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)

example_text = "This is an example text about snowflakes in the summer"
labels = ["weather", "sports"]

output = classifier(example_text, labels)
output 
{'sequence': 'This is an example text about snowflakes in the summer',
'labels': ['weather', 'sports'],
'scores': [0.9780895709991455, 0.021910419687628746]}
Pete

2 Answers


The ZeroShotClassificationPipeline is currently not supported by shap, but you can use a workaround. The workaround is required because:

  1. The shap Explainer forwards only one parameter to the model (a pipeline in this case), but the ZeroShotClassificationPipeline requires two parameters, namely the text and the labels.
  2. The shap Explainer will access the config of your model and use its label2id and id2label properties. They do not match the labels returned from the ZeroShotClassificationPipeline and will result in an error.
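Issue 1 can be illustrated without downloading a model. The sketch below is model-free and hypothetical: `fake_zero_shot` stands in for the two-argument pipeline (its scores are dummy values), and `make_single_arg_classifier` binds the labels once, so the resulting callable takes only a batch of texts and returns one list of `{"label", "score"}` dicts per text, the same shape the workaround below returns.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a two-argument zero-shot classifier. It only mimics the
# shape of ZeroShotClassificationPipeline output (one dict per text); no model is used.
def fake_zero_shot(texts: List[str], candidate_labels: List[str]) -> List[Dict]:
    results = []
    for text in texts:
        # Dummy scores: uniform over the candidate labels.
        scores = [1.0 / len(candidate_labels)] * len(candidate_labels)
        results.append({"sequence": text, "labels": list(candidate_labels), "scores": scores})
    return results

def make_single_arg_classifier(classify: Callable, labels: List[str]) -> Callable[[List[str]], List[List[Dict]]]:
    """Bind the labels so the returned callable needs only the texts (issue 1)."""
    def predict(texts: List[str]) -> List[List[Dict]]:
        outputs = classify(list(texts), labels)
        if isinstance(outputs, dict):  # a single input may come back as a bare dict
            outputs = [outputs]
        # One list of {"label", "score"} dicts per input text, not just for the first one.
        return [[{"label": l, "score": s} for l, s in zip(o["labels"], o["scores"])] for o in outputs]
    return predict

predict = make_single_arg_classifier(fake_zero_shot, ["weather", "sports"])
print(predict(["snowflakes in the summer", "a football match"]))
```

Returning one entry per input matters because an explainer typically evaluates many masked variants of the text in a single call.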

Below is one possible workaround. I recommend opening an issue in the shap repository and requesting official support for Hugging Face's ZeroShotClassificationPipeline.

import shap
from transformers import AutoModelForSequenceClassification, AutoTokenizer, ZeroShotClassificationPipeline
from typing import Union, List

weights = "valhalla/distilbart-mnli-12-3"

model = AutoModelForSequenceClassification.from_pretrained(weights)
tokenizer = AutoTokenizer.from_pretrained(weights)

# Create your own pipeline that only requires the text parameter 
# for the __call__ method and provides a method to set the labels
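class MyZeroShotClassificationPipeline(ZeroShotClassificationPipeline):
    # Overwrite the __call__ method: shap passes a single argument (a batch of texts)
    def __call__(self, *args):
        outputs = super().__call__(args[0], self.workaround_labels)
        # A single input can come back as a bare dict instead of a list
        if isinstance(outputs, dict):
            outputs = [outputs]
        # Return one list of {"label", "score"} dicts per input text, not only for the first one
        return [[{"label": x[0], "score": x[1]} for x in zip(o["labels"], o["scores"])] for o in outputs]

    def set_labels_workaround(self, labels: Union[str, List[str]]):
        self.workaround_labels = labels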

example_text = "This is an example text about snowflakes in the summer"
labels = ["weather","sports"]

# In the following, we address issue 2.
model.config.label2id.update({v:k for k,v in enumerate(labels)})
model.config.id2label.update({k:v for k,v in enumerate(labels)})

pipe = MyZeroShotClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
pipe.set_labels_workaround(labels)

def score_and_visualize(text):
    prediction = pipe([text])
    print(prediction[0])

    explainer = shap.Explainer(pipe)
    shap_values = explainer([text])

    shap.plots.text(shap_values)


score_and_visualize(example_text)

Output: (image of the shap text plot)
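To see why the label2id/id2label update above keeps the zero-shot pipeline working, here is a model-free sketch of its effect on plain dicts. The starting mapping `{'contradiction': 0, 'neutral': 1, 'entailment': 2}` is the one quoted in the first comment below and is assumed to match `valhalla/distilbart-mnli-12-3`:

```python
# Original NLI mapping of the MNLI checkpoint (assumed, as quoted in the comments).
label2id = {"contradiction": 0, "neutral": 1, "entailment": 2}
id2label = {v: k for k, v in label2id.items()}

labels = ["weather", "sports"]

# The same two updates as in the workaround above.
label2id.update({v: k for k, v in enumerate(labels)})
id2label.update({k: v for k, v in enumerate(labels)})

print(label2id)  # {'contradiction': 0, 'neutral': 1, 'entailment': 2, 'weather': 0, 'sports': 1}
print(id2label)  # {0: 'weather', 1: 'sports', 2: 'entailment'}
```

The entailment id stays 2, so the zero-shot pipeline can still find it, while ids 0 and 1 now carry the custom label names that shap reads from the config. The model itself still produces the same three NLI logits per premise/hypothesis pair; only the names in the config change.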

cronoik
  • You're throwing away the original `{'contradiction': 0, 'entailment': 2, 'neutral': 1}` and substituting it with arbitrary desired labels. Can you explain how this is going to work at the model level? – Sergey Bushmanov Oct 22 '21 at 21:23
  • The underlying model was trained to predict 3 classes. Are you saying you can arbitrarily change the number and meaning of the labels without retraining the model? – Sergey Bushmanov Oct 22 '21 at 21:51
  • No, you can't. Throwing away the original labels was a copy&paste mistake by me. The `ZeroShotClassificationPipeline` requires the `entailment` label. I have corrected my answer. Thanks for your comment. @SergeyBushmanov – cronoik Oct 22 '21 at 21:54
  • Still not very convincing. A sentence/label pair is a `premise/hypothesis` pair in their parlance. It's not at all clear that one can pass a hypothesis as a pretrained label id. – Sergey Bushmanov Oct 22 '21 at 21:58
  • The `ZeroShotClassificationPipeline` creates the `premise/hypothesis` pair. After tokenization, it passes the sequence `[CLS] This is an example text about snowflakes in the summer [SEP] This example is sports. [SEP]` to the model and uses the `entailment` logits for its prediction. That's why it is called zero-shot. @SergeyBushmanov – cronoik Oct 22 '21 at 22:05
  • The reasoning why `ZeroShotClassificationPipeline` won't pass through `TransformersPipeline` is correct. Why your suggested solution is correct still needs to be explained. – Sergey Bushmanov Oct 22 '21 at 22:05
  • Let us continue this discussion in [chat](https://chat.stackoverflow.com/rooms/238443/zeroshotclassificationdiscussion). – cronoik Oct 22 '21 at 22:09

This is a follow-up to the discussion with @cronoik, which may help others understand why tinkering with label2id works.

The docs for ZeroShotClassificationPipeline state:

NLI-based zero-shot classification pipeline using a ModelForSequenceClassification trained on NLI (natural language inference) tasks.

Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis pair and passed to the pretrained model. Then, the logit for entailment is taken as the logit for the candidate label being valid. Any NLI model can be used, but the id of the entailment label must be included in the model config's `label2id`.

In other words (see the accompanying source code):

  • labels supplied through the __call__ method are inserted into the hypothesis template, and each sequence/label combination is passed to the underlying trained model as a premise/hypothesis sentence pair
  • if you overwrite label2id manually, the entailment label must still be in label2id (you'll get a warning otherwise). Nothing else needs to be added.
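Both points can be sketched without a model. The two helpers below, `build_nli_pairs` and `find_entailment_id`, are hypothetical and only mirror my understanding of the pipeline's behavior: the default `hypothesis_template` is assumed to be `"This example is {}."`, and the entailment id is assumed to be found by looking for a `label2id` key that starts with `"entail"` (case-insensitive), falling back to -1. The real pipeline additionally tokenizes the pairs and runs the model.

```python
from itertools import product
from typing import Dict, List, Tuple

def build_nli_pairs(sequences: List[str], candidate_labels: List[str],
                    hypothesis_template: str = "This example is {}.") -> List[Tuple[str, str]]:
    """Pose every sequence/label combination as a (premise, hypothesis) pair (bullet 1)."""
    return [(seq, hypothesis_template.format(label))
            for seq, label in product(sequences, candidate_labels)]

def find_entailment_id(label2id: Dict[str, int]) -> int:
    """Return the id of the first label starting with 'entail', else -1 (bullet 2)."""
    for label, idx in label2id.items():
        if label.lower().startswith("entail"):
            return idx
    return -1  # the pipeline warns when it has to fall back like this

text = "This is an example text about snowflakes in the summer"
print(build_nli_pairs([text], ["weather", "sports"]))
# Custom labels next to 'entailment' are harmless: only the entailment id is looked up.
print(find_entailment_id({"weather": 0, "sports": 1, "entailment": 2}))  # 2
print(find_entailment_id({"weather": 0, "sports": 1}))  # -1
```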

Once these conditions are met, the pipeline returns a dictionary with a score for each provided label: the sigmoid/softmax of the entailment logit for a premise/hypothesis pair like

"<cls> sequence to classify <sep> This example is {label}. <sep>"

which is taken as the probability that the label applies.
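The sigmoid/softmax part can be illustrated with made-up logits. This is a sketch of my understanding of the two modes, not the library code: with `multi_label=False` the entailment logits are softmaxed across all candidate labels, and with `multi_label=True` each label gets a softmax over its own [contradiction, entailment] logits, which equals a sigmoid of their difference. The logit order [contradiction, neutral, entailment] is assumed from the mapping quoted in the comments above.

```python
import math
from typing import Dict, List

def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

CONTRADICTION, ENTAILMENT = 0, 2  # assumed index order: [contradiction, neutral, entailment]

# Made-up NLI logits for "<premise> / This example is {label}." per candidate label.
logits: Dict[str, List[float]] = {
    "weather": [-2.0, 0.0, 3.0],
    "sports": [2.0, 0.5, -1.0],
}

# multi_label=False: one softmax over the entailment logits of all labels (scores sum to 1).
single_label = dict(zip(logits, softmax([v[ENTAILMENT] for v in logits.values()])))

# multi_label=True: independent softmax over [contradiction, entailment] per label,
# i.e. sigmoid(entailment - contradiction); the scores need not sum to 1.
multi_label = {label: softmax([v[CONTRADICTION], v[ENTAILMENT]])[1] for label, v in logits.items()}

print(single_label)
print(multi_label)
```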

For this type of classification pipeline, the label2id entries are simply used as placeholders to keep the labels and pass them to other parts of the pipeline.
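One such "other part" is shap itself: the extra label2id entries let it put each returned score into a fixed output column. Below is a simplified, hypothetical sketch of that bookkeeping (my reading of what shap's transformers-pipeline wrapper does, not its actual code; `dicts_to_matrix` is a made-up name):

```python
from typing import Dict, List

def dicts_to_matrix(pipeline_output: List[List[Dict]], label2id: Dict[str, int],
                    num_labels: int) -> List[List[float]]:
    """Hypothetical helper: one row per input, each score placed at its label's label2id column."""
    matrix = [[0.0] * num_labels for _ in pipeline_output]
    for row, per_input in zip(matrix, pipeline_output):
        for item in per_input:
            row[label2id[item["label"]]] = item["score"]  # KeyError if the label is unknown
    return matrix

# label2id after the workaround's update (entailment keeps its original id 2).
label2id = {"contradiction": 0, "neutral": 1, "entailment": 2, "weather": 0, "sports": 1}
output = [[{"label": "weather", "score": 0.98}, {"label": "sports", "score": 0.02}]]
print(dicts_to_matrix(output, label2id, num_labels=3))  # [[0.98, 0.02, 0.0]]
```

A label missing from label2id raises a KeyError here, which is the kind of mismatch described as issue 2 in the other answer.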

Sergey Bushmanov