I'm wondering if there's a way to change the order the features in a SHAP beeswarm plot are displayed in. The docs describe "transforms" like using shap_values.abs or shap_values.abs.mean(0) to change how the ordering is calculated, but what I actually want is to put in a list of features or indices and have it order by that.

From the docs:

shap.plots.beeswarm(shap_values, order=shap_values.abs)

This is the resulting plot

Eich Varkin

1 Answer

This is the default implementation of ordering:

import xgboost
import shap

X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)

explainer = shap.Explainer(model, X)
shap_values = explainer(X)

shap.plots.beeswarm(shap_values, max_display=12, order=shap.Explanation.abs.mean(0))

Then, if you want to define the ordering of the output columns manually, map the column names to their indices and pass that list as order:

order = [
    "Country",
    "Workclass",
    "Education-Num",
    "Marital Status",
    "Occupation",
    "Relationship",
    "Race",
    "Sex",
    "Capital Gain",
    "Capital Loss",
    "Hours per week",
    "Age",
]
col2num = {col: i for i, col in enumerate(X.columns)}

order = list(map(col2num.get, order))

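import matplotlib.pyplot as plt

# show=False keeps the figure open so the color bar can be added before showing it
shap.plots.beeswarm(shap_values, max_display=12, show=False, color_bar=False, order=order)
plt.colorbar()
plt.show()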
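One caveat with the mapping above: col2num.get silently returns None for a misspelled column name. A minimal sketch of a stricter version follows. The helper name feature_order is hypothetical (it is not part of SHAP), and appending the unlisted columns at the end is my own choice so that the result always covers every feature:

```python
def feature_order(columns, names):
    """Return column indices with `names` first, then all remaining columns.

    Hypothetical helper, not part of the SHAP API.
    """
    col2num = {col: i for i, col in enumerate(columns)}

    # Fail loudly on typos instead of passing None through to the plot.
    missing = [name for name in names if name not in col2num]
    if missing:
        raise KeyError(f"unknown feature names: {missing}")

    chosen = [col2num[name] for name in names]
    seen = set(chosen)

    # Keep the unlisted columns in their original order after the chosen ones.
    rest = [i for i in range(len(columns)) if i not in seen]
    return chosen + rest


columns = ["Age", "Workclass", "Education-Num", "Sex"]
order = feature_order(columns, ["Sex", "Age"])
print(order)  # [3, 0, 1, 2]
```

With the real data you would call it as feature_order(list(X.columns), [...]) and pass the result as order= exactly as above.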

Sergey Bushmanov
  • This is great, thanks! I think you forgot to put in altered code for the second graph, but I tried: `shap.plots.beeswarm(shap_values, max_display=12, show=False, color_bar=False, order=order)` and that works – Eich Varkin May 06 '22 at 12:57