I want to extend scikit-learn's `ClassificationCriterion` class, which is defined as a Cython extension type in the internal module `sklearn.tree._criterion`. I would like to do this in Python, since I normally don't have access to sklearn's `.pyx`/`.pxd` files (so I cannot `cimport` them). However, when I try to subclass `ClassificationCriterion`, I get the error `TypeError: __cinit__() takes exactly 2 positional arguments (0 given)`. The MWE below reproduces the error and shows that it is raised inside `__new__` (after `new` is printed) but before `__init__` runs.
Is there any way to extend a Cython class like this?
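```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._criterion import ClassificationCriterion


class MaxChildPrecision(ClassificationCriterion):
    def __new__(cls, *args, **kwargs):
        print('new')
        # Allocates the Cython object; this is the step that runs __cinit__.
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        print('init')
        # Setup is done in __cinit__, so no arguments are forwarded here.
        super().__init__()


# Prints 'new', then raises:
# TypeError: __cinit__() takes exactly 2 positional arguments (0 given)
clf = DecisionTreeClassifier(criterion=MaxChildPrecision())
```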
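The ordering in the output follows from how Python constructs objects: `MaxChildPrecision()` calls `__new__` first, and the base type's allocator (for a Cython class, the generated code that runs `__cinit__`) receives whatever arguments were passed to the constructor. `__init__` only runs if `__new__` returns successfully. The sketch below is a pure-Python stand-in, not sklearn code: the hypothetical `FakeCriterion.__new__` plays the role of a `__cinit__` that requires two positional arguments (named `n_outputs` and `n_classes` here as an assumption), so both the failure and the effect of forwarding the arguments can be checked without sklearn installed.

```python
# Pure-Python stand-in for a Cython extension type whose __cinit__
# requires two positional arguments (hypothetical; not sklearn code).
class FakeCriterion:
    def __new__(cls, n_outputs, n_classes):
        # Plays the role of __cinit__: runs during allocation, before __init__.
        self = super().__new__(cls)
        self.n_outputs = n_outputs
        self.n_classes = list(n_classes)
        return self


calls = []  # records the order in which the subclass hooks run


class MaxChildPrecision(FakeCriterion):
    def __new__(cls, *args, **kwargs):
        calls.append("new")
        # Whatever reaches the constructor is forwarded to the base allocator.
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        calls.append("init")
        super().__init__()


# Calling with no arguments fails inside __new__, so __init__ never runs.
try:
    MaxChildPrecision()
except TypeError as exc:
    print("TypeError:", exc)
print(calls)  # ['new']

# Supplying the two required arguments lets construction complete.
calls.clear()
crit = MaxChildPrecision(1, [2])
print(calls, crit.n_outputs, crit.n_classes)  # ['new', 'init'] 1 [2]
```

In this analogue the error disappears once the required arguments reach the base allocator. For the real class that would mean constructing the criterion with its two arguments rather than as `MaxChildPrecision()` (assumed here, from scikit-learn's source, to be the number of outputs and an integer NumPy array of class counts; check this against your installed version).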