import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder


# rf is a fitted model and X_test/y_test a held-out split (defined earlier in the docs example)
result = permutation_importance(rf,
                                X_test,
                                y_test,
                                n_repeats=10,
                                random_state=42,
                                n_jobs=2)
sorted_idx = result.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
           vert=False,
           labels=X_test.columns[sorted_idx])

ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()

In the code above, taken from this example in the documentation, is there a way to plot the top 3 features only instead of all the features?

user308827
  • This doesn't really have to do with scikit-learn, and there's too much boilerplate code to get to the point where `rf`, `X_test`, and `y_test` are defined. Even then, none of the code here is important because this question boils down to "how do I get the last `n` elements from a list", which is answered [here](https://stackoverflow.com/questions/646644/how-to-get-last-items-of-a-list-in-python) and probably loads of other places. However, because there's a bounty, the question remains open. – ddejohn Sep 23 '21 at 03:00

1 Answer


`argsort` "returns the indices that would sort an array," so here `sorted_idx` contains the feature indices in order of least to most important. Since you want only the 3 most important features, take just the last 3 indices:

sorted_idx = result.importances_mean.argsort()[-3:]
# array([4, 0, 1])
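As a quick standalone sanity check of the negative slice (made-up importance values, not the numbers from the example above):

```python
import numpy as np

# hypothetical mean importances for 5 features
importances_mean = np.array([0.02, 0.31, 0.05, 0.01, 0.17])

sorted_idx = importances_mean.argsort()  # least -> most important
top3 = sorted_idx[-3:]                   # last 3 = 3 most important
print(sorted_idx)  # [3 0 2 4 1]
print(top3)        # [2 4 1]
```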

Then the plotting code can remain as is, but now it will only plot the top 3 features:

# same plotting code; a smaller figsize suits just 3 boxes
fig, ax = plt.subplots(figsize=(6, 3))
ax.boxplot(result.importances[sorted_idx].T,
           vert=False, labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()
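As a side note, a quick shape check on synthetic data (made-up numbers, assuming 5 features and `n_repeats=10`) confirms the slice hands the boxplot exactly 3 columns, one box per top feature:

```python
import numpy as np

rng = np.random.default_rng(42)
importances = rng.random((5, 10))   # (n_features, n_repeats), like result.importances
sorted_idx = importances.mean(axis=1).argsort()[-3:]

data = importances[sorted_idx].T    # what ax.boxplot receives
print(data.shape)  # (10, 3)
```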


Note that if you prefer to leave `sorted_idx` untouched (e.g., to use the full index array elsewhere in the code), you can:

  • either change `sorted_idx` to `sorted_idx[-3:]` inline:

    sorted_idx = result.importances_mean.argsort() # unchanged
    
    ax.boxplot(result.importances[sorted_idx[-3:]].T, # replace sorted_idx with sorted_idx[-3:]
               vert=False, labels=X_test.columns[sorted_idx[-3:]]) # replace sorted_idx with sorted_idx[-3:]
    
  • or store the filtered indices in a separate variable:

    sorted_idx = result.importances_mean.argsort() # unchanged
    top3_idx = sorted_idx[-3:] # store top 3 indices
    
    ax.boxplot(result.importances[top3_idx].T, # replace sorted_idx with top3_idx
               vert=False, labels=X_test.columns[top3_idx]) # replace sorted_idx with top3_idx
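Incidentally, when there are many features and only the top k matter, `np.argpartition` avoids a full sort. A hedged sketch with made-up numbers (purely illustrative, not part of the original answer):

```python
import numpy as np

importances_mean = np.array([0.02, 0.31, 0.05, 0.01, 0.17])
k = 3

# argpartition places the indices of the k largest values in the
# last k slots (unordered); a small argsort then orders just those k
topk = np.argpartition(importances_mean, -k)[-k:]
topk = topk[np.argsort(importances_mean[topk])]
print(topk)  # [2 4 1], least -> most important, same as argsort()[-k:]
```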
    
tdy