The following code combines cross_validate with GridSearchCV to perform a nested cross-validation for an SVC on the iris dataset.
(Modified example of the following documentation page: https://scikit-learn.org/stable/auto_examples/model_selection/plot_nested_cross_validation_iris.html#sphx-glr-auto-examples-model-selection-plot-nested-cross-validation-iris-py.)
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, cross_validate, KFold
import numpy as np
np.set_printoptions(precision=2)
# Load the dataset
iris = load_iris()
X_iris = iris.data
y_iris = iris.target
# Set up possible values of parameters to optimize over
p_grid = {"C": [1, 10],
          "gamma": [.01, .1]}
# We will use a Support Vector Classifier with "rbf" kernel
svm = SVC(kernel="rbf")
# Choose techniques for the inner and outer loop of nested cross-validation
inner_cv = KFold(n_splits=5, shuffle=True, random_state=1)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)
# Perform nested cross-validation
# (The iid argument was removed in scikit-learn 0.24, so it is omitted here.)
clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv)
clf.fit(X_iris, y_iris)
best_estimator = clf.best_estimator_
cv_dic = cross_validate(clf, X_iris, y_iris, cv=outer_cv, scoring=['accuracy'], return_estimator=False, return_train_score=True)
mean_val_score = cv_dic['test_accuracy'].mean()
print('nested_train_scores: ', cv_dic['train_accuracy'])
print('nested_val_scores: ', cv_dic['test_accuracy'])
print('mean score: {0:.2f}'.format(mean_val_score))
cross_validate splits the data set into a training set and a test set for each fold. In each fold, the input estimator is then trained on the training set associated with that fold. The estimator passed in here is clf, a parameterized GridSearchCV estimator, i.e. an estimator that cross-validates itself again.
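As a rough sketch of what this description implies, the outer loop can be unrolled by hand. This assumes that cross_validate clones the estimator and fits it once per outer training split, and that GridSearchCV is left at its default refit=True. The names fold_clf, outer_scores and reference are only illustrative, and the loop is a manual approximation, not scikit-learn's internal code.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold, cross_validate
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
p_grid = {"C": [1, 10], "gamma": [0.01, 0.1]}
inner_cv = KFold(n_splits=5, shuffle=True, random_state=1)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)
clf = GridSearchCV(estimator=SVC(kernel="rbf"), param_grid=p_grid, cv=inner_cv)

# Manual version of the outer loop (illustrative, not scikit-learn internals)
outer_scores = []
for train_idx, test_idx in outer_cv.split(X, y):
    fold_clf = clone(clf)  # fresh, unfitted copy of the grid search
    # fit() runs the inner 5-fold search on the outer training split only,
    # then (refit=True) refits the best parameter combination on that whole split
    fold_clf.fit(X[train_idx], y[train_idx])
    # score() is delegated to the refitted best_estimator_ of this fold
    outer_scores.append(fold_clf.score(X[test_idx], y[test_idx]))
outer_scores = np.array(outer_scores)

# Reference: the same procedure through cross_validate
reference = cross_validate(clf, X, y, cv=outer_cv, scoring=["accuracy"])["test_accuracy"]
print("manual loop:   ", outer_scores)
print("cross_validate:", reference)
```

If the two printed arrays agree, that supports the reading that the outer test set is never seen by the inner search and that a new grid search is trained in every outer fold.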
I have three questions about the whole thing:
- If clf is used as the estimator for cross_validate, does it (in the course of the GridSearchCV cross-validation) split the above-mentioned training set into a subtraining set and a validation set in order to determine the best hyperparameter combination?
- Out of all models tested via GridSearchCV, does cross_validate validate only the model stored in the best_estimator_ attribute?
- Does cross_validate train a model at all (if so, why?) or is the model stored in best_estimator_ validated directly via the test set?
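One way to probe these questions empirically is to ask cross_validate to hand back the fitted estimators via return_estimator=True and inspect them. The snippet below is only a sketch under that assumption; the variable names result, fitted_search and n_candidates are made up for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold, cross_validate
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
p_grid = {"C": [1, 10], "gamma": [0.01, 0.1]}
inner_cv = KFold(n_splits=5, shuffle=True, random_state=1)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)
clf = GridSearchCV(estimator=SVC(kernel="rbf"), param_grid=p_grid, cv=inner_cv)

# return_estimator=True keeps the estimator fitted in each outer fold
result = cross_validate(clf, X, y, cv=outer_cv, scoring=["accuracy"],
                        return_estimator=True)

for fold, fitted_search in enumerate(result["estimator"]):
    # Each entry is a separate GridSearchCV fitted on one outer training split
    n_candidates = len(fitted_search.cv_results_["params"])
    print("outer fold", fold,
          "| best params:", fitted_search.best_params_,
          "| candidates tried:", n_candidates,
          "| inner splits:", fitted_search.n_splits_)
```

Each outer fold yields its own fitted GridSearchCV object with its own best_params_, which may differ from fold to fold and from the clf.best_estimator_ obtained by fitting on the full data set.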
To make it clearer how the questions are meant, here is an illustration of how I imagine the double cross validation at the moment.