
I want to exhaustively search the entire parameter space of my support vector classifier using GridSearchCV. However, LinearSVC forbids some combinations of parameters and raises an exception when it encounters them. In particular, some combinations of the dual, penalty, and loss parameters are mutually exclusive.

For example, this code:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'dual': [True, False], 'penalty': ['l1', 'l2'],
              'loss': ['hinge', 'squared_hinge']}
svc = svm.LinearSVC()
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)

This raises ValueError: Unsupported set of arguments: The combination of penalty='l2' and loss='hinge' are not supported when dual=False, Parameters: penalty='l2', loss='hinge', dual=False

My question is: is it possible to make GridSearchCV skip combinations of parameters which the model forbids? If not, is there an easy way to construct a parameter space which won't violate the rules?
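A quick way to see which combinations LinearSVC itself rejects is to probe each one with a direct fit and catch the ValueError. This is a sketch that assumes iris is an acceptable throwaway dataset and that fitting every candidate once is cheap; `valid`, `forbidden`, and `safe_grid` are illustrative names, not scikit-learn API.

```python
from itertools import product

from sklearn import datasets, svm

iris = datasets.load_iris()

valid, forbidden = [], []
for dual, penalty, loss in product([True, False], ['l1', 'l2'],
                                   ['hinge', 'squared_hinge']):
    params = {'dual': dual, 'penalty': penalty, 'loss': loss}
    try:
        # Forbidden combinations raise ValueError before the solver runs;
        # valid ones are fitted in full (and may emit ConvergenceWarning).
        svm.LinearSVC(**params).fit(iris.data, iris.target)
    except ValueError:
        forbidden.append(params)
    else:
        valid.append(params)

# GridSearchCV wants a list of candidate values per key, so wrap each scalar.
safe_grid = [{key: [value] for key, value in params.items()} for params in valid]

print("valid:", valid)
print("forbidden:", forbidden)
```

The resulting `safe_grid` is a list of one-point grids, which could be passed as `GridSearchCV(svm.LinearSVC(), safe_grid)`.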

crypdick
  • This would still be a problem, but a lesser problem, if we could at least suppress the FitFailedWarning statements in this case. I face the same battle where I know some combinations are illegal but the logic (as explained below) to prevent these combinations is way too ugly. – demongolem Jun 05 '20 at 13:04

2 Answers


I solved this problem by passing error_score=0.0 to GridSearchCV:

error_score : ‘raise’ (default) or numeric

Value to assign to the score if an error occurs in estimator fitting. If set to ‘raise’, the error is raised. If a numeric value is given, FitFailedWarning is raised. This parameter does not affect the refit step, which will always raise the error.

UPDATE: newer versions of sklearn print out a lot of ConvergenceWarning and FitFailedWarning messages. I had a hard time suppressing them with contextlib.suppress (which only suppresses exceptions, not warnings), but there is a workaround using a context manager from sklearn's testing utilities:

from sklearn import svm, datasets
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import FitFailedWarning, ConvergenceWarning
from sklearn.model_selection import GridSearchCV

# Use a tuple: warning filters match categories with issubclass(), which
# accepts a class or a tuple of classes, not a list.
with ignore_warnings(category=(ConvergenceWarning, FitFailedWarning)):
    iris = datasets.load_iris()
    parameters = {'dual': [True, False], 'penalty': ['l1', 'l2'],
                  'loss': ['hinge', 'squared_hinge']}
    svc = svm.LinearSVC()
    clf = GridSearchCV(svc, parameters, error_score=0.0)
    clf.fit(iris.data, iris.target)
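Since sklearn.utils._testing is a private module that may change without notice, the standard-library warnings.catch_warnings context manager is a possible alternative. A sketch under the same setup; note that these filters apply to the current process, so with n_jobs > 1 warnings raised in worker processes may still be printed.

```python
import warnings

from sklearn import datasets, svm
from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'dual': [True, False], 'penalty': ['l1', 'l2'],
              'loss': ['hinge', 'squared_hinge']}

# catch_warnings() saves the current filter list and restores it on exit.
with warnings.catch_warnings():
    # Register one "ignore" filter per warning category to silence.
    for category in (ConvergenceWarning, FitFailedWarning):
        warnings.simplefilter('ignore', category=category)
    clf = GridSearchCV(svm.LinearSVC(), parameters, error_score=0.0)
    clf.fit(iris.data, iris.target)

print("best:", clf.best_params_, clf.best_score_)
```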
crypdick

If you want to completely avoid exploring specific combinations (without waiting to run into errors), you have to construct the grid yourself. GridSearchCV can take a list of dicts, where the grids spanned by each dictionary in the list are explored.

In this case, the conditional logic was not so bad, but it would be really tedious for something more complicated:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from itertools import product

iris = datasets.load_iris()

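duals = [True, False]
penalties = ['l1', 'l2']
losses = ['hinge', 'squared_hinge']
all_params = list(product(duals, penalties, losses))
# Wrap each value in a one-element list: every dict becomes a one-point grid.
filtered_params = [{'dual': [dual], 'penalty': [penalty], 'loss': [loss]}
                   for dual, penalty, loss in all_params
                   # l1 penalty with hinge loss is never supported
                   if not (penalty == 'l1' and loss == 'hinge')
                   # l1 penalty with squared hinge loss requires dual=False
                   and not (penalty == 'l1' and loss == 'squared_hinge' and dual is True)
                   # l2 penalty with hinge loss requires dual=True
                   and not (penalty == 'l2' and loss == 'hinge' and dual is False)]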

svc = svm.LinearSVC()
clf = GridSearchCV(svc, filtered_params)
clf.fit(iris.data, iris.target)
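For a constraint set this small, the valid combinations can also be written by hand as a few sub-grids instead of one dict per combination. The grouping below is derived from the same three exclusion rules used in the filter above; ParameterGrid is used only to show the expansion GridSearchCV would search, and `param_grid` could be passed as `GridSearchCV(svm.LinearSVC(), param_grid)`.

```python
from sklearn.model_selection import ParameterGrid

param_grid = [
    # l2 penalty + squared hinge loss: supported with dual=True and dual=False
    {'penalty': ['l2'], 'loss': ['squared_hinge'], 'dual': [True, False]},
    # l2 penalty + hinge loss: supported only with dual=True
    {'penalty': ['l2'], 'loss': ['hinge'], 'dual': [True]},
    # l1 penalty: supported only with squared hinge loss and dual=False
    {'penalty': ['l1'], 'loss': ['squared_hinge'], 'dual': [False]},
]

# ParameterGrid expands a list of dicts the same way GridSearchCV does.
for params in ParameterGrid(param_grid):
    print(params)
```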
crypdick
    I appreciate your effort, but this seems like a slightly sketchy solution that would result in a lot of verbose code for a problem with a large number of restrictions – GRoutar Dec 24 '18 at 20:58
    @Khabz agreed, this code is cursed! If there are a bazillion conditionals, one possibility is to programmatically construct the list of conditionals in `filtered_params`, join them with `' and '.join(conditionals_list)`, and finally `eval()` the string to build the list comprehension. – crypdick Dec 25 '18 at 21:06
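An alternative to building strings for eval() is to keep each forbidden combination as a small predicate function and drop any combination that one of them matches. A sketch in plain Python; `forbidden_rules` and `allowed_combinations` are hypothetical names, and the three rules mirror the conditions in the list comprehension above.

```python
from itertools import product

grid = {'dual': [True, False], 'penalty': ['l1', 'l2'],
        'loss': ['hinge', 'squared_hinge']}

# Hypothetical rule list for LinearSVC: each lambda returns True for a
# forbidden combination.
forbidden_rules = [
    lambda p: p['penalty'] == 'l1' and p['loss'] == 'hinge',
    lambda p: p['penalty'] == 'l1' and p['loss'] == 'squared_hinge' and p['dual'],
    lambda p: p['penalty'] == 'l2' and p['loss'] == 'hinge' and not p['dual'],
]


def allowed_combinations(grid, rules):
    """Yield each combination of grid values that no rule forbids."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        params = dict(zip(keys, values))
        if not any(rule(params) for rule in rules):
            yield params


# GridSearchCV expects a list of values per key, so wrap each scalar.
filtered_params = [{k: [v] for k, v in p.items()}
                   for p in allowed_combinations(grid, forbidden_rules)]
print(len(filtered_params), filtered_params)
```

Adding a restriction then means appending one lambda rather than editing a long boolean expression, and no code is generated from strings.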