I want to extend scikit-learn's ClassificationCriterion class, which is defined as a Cython class in the internal module sklearn.tree._criterion. I would like to do that in Python, as normally I don't have access to the pyx/pxd files of sklearn (so I cannot cimport them). However, when I try to extend ClassificationCriterion, I get the error TypeError: __cinit__() takes exactly 2 positional arguments (0 given). The MWE below reproduces the error and shows that it occurs after __new__ is entered but before __init__ is called.

Is there any way to extend a Cython class like this?

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._criterion import ClassificationCriterion

class MaxChildPrecision(ClassificationCriterion):
    def __new__(cls, *args, **kwargs):
        print('new')
        # Raises the TypeError: no arguments reach the base class __cinit__.
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        print('init')
        super().__init__(*args, **kwargs)

clf = DecisionTreeClassifier(criterion=MaxChildPrecision())
  • Methods need to take self as the first positional argument. But they shouldn't pass it on the super call, which is why it's necessary to capture it outside `*args`. – Daniel Roseman Dec 03 '17 at 22:27
  • @DanielRoseman thanks, I updated the code, but it still gives the same error. –  Dec 03 '17 at 22:30
  • [You need to define `__new__`, not `__cinit__`.](http://cython.readthedocs.io/en/latest/src/userguide/special_methods.html#initialisation-methods-cinit-and-init) – user2357112 Dec 03 '17 at 22:59
  • @user2357112 thanks to you too. I updated the code, but this does not resolve the problem yet. Any other ideas? –  Dec 03 '17 at 23:13
  • You need to define `__new__` *and pass the superclass `__new__` the arguments it expects*. Same with passing the superclass `__init__` the arguments it expects ([although those arguments might not be the ones you expect it to expect](https://stackoverflow.com/questions/41390372/why-cant-i-subclass-tuple-in-python3)). – user2357112 Dec 03 '17 at 23:16

1 Answer

There are two issues. Firstly, ClassificationCriterion requires two specific arguments to its constructor that you aren't passing it. You will have to work out what these arguments represent and pass them to the base class.

Secondly, there's a Cython issue. If we look at the description of how to use __cinit__ then we see:

Any arguments passed to the constructor will be passed to both the __cinit__() method and the __init__() method. If you anticipate subclassing your extension type in Python, you may find it useful to give the __cinit__() method * and ** arguments so that it can accept and ignore extra arguments. Otherwise, any Python subclass which has an __init__() with a different signature will have to override __new__() as well as __init__().
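
The quoted behaviour can be imitated without Cython or sklearn. The following is only a rough sketch: it assumes that a pure-Python `__new__` with a fixed signature is a reasonable stand-in for a `__cinit__` that takes exactly two arguments. All class names here (`StrictBase`, `TolerantBase`, and their children) are hypothetical and exist only for illustration.

```python
class StrictBase:
    """Stand-in for an extension type whose __cinit__ takes exactly two arguments."""

    def __new__(cls, n_outputs, n_classes):
        self = super().__new__(cls)
        self.n_outputs = n_outputs
        self.n_classes = n_classes
        return self


class TolerantBase:
    """Stand-in for an extension type whose __cinit__ accepts and ignores extras."""

    def __new__(cls, *args, **kwargs):
        return super().__new__(cls)


class StrictChild(StrictBase):
    # Different signature from the base constructor, and __new__ is not overridden.
    def __init__(self, label):
        self.label = label


class TolerantChild(TolerantBase):
    def __init__(self, label):
        self.label = label


# The constructor argument is handed to __new__ first, then to __init__.
tolerant = TolerantChild("ok")

try:
    StrictChild("fails")  # the base __new__ rejects the call before __init__ runs
    strict_error = None
except TypeError as exc:
    strict_error = exc

print(tolerant.label)
print(type(strict_error).__name__)
```

In this sketch the strict base fails during allocation, so the subclass `__init__` is never reached, which mirrors the point in the question where the error appears.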

Unfortunately, the writers of sklearn didn't provide * and ** arguments, so you do have to override __new__. Something like this should work:

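import numpy as np
from sklearn.tree._criterion import ClassificationCriterion

class MaxChildPrecision(ClassificationCriterion):
    def __init__(self, *args, **kwargs):
        pass

    def __new__(cls, *args, **kwargs):
        # I have NO IDEA if these arguments make sense!
        # np.intp is used because np.int was removed in NumPy 1.24.
        return super().__new__(cls, n_outputs=5,
                               n_classes=np.ones((2,), dtype=np.intp))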

I pass the necessary arguments to ClassificationCriterion in __new__ and deal with the rest in __init__ as I see fit. I don't need to call the base class __init__ (because the base class doesn't define __init__).
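
The same pattern can be tried out without sklearn installed. This is a minimal sketch under the assumption that a pure-Python `__new__` with a fixed two-argument signature behaves enough like the Cython `__cinit__` for the purpose of the demonstration. `StrictBase`, `Child`, `label`, and the argument values are hypothetical.

```python
class StrictBase:
    """Stand-in for a base type that insists on exactly two constructor arguments."""

    def __new__(cls, n_outputs, n_classes):
        self = super().__new__(cls)
        self.n_outputs = n_outputs
        self.n_classes = n_classes
        return self


class Child(StrictBase):
    def __new__(cls, *args, **kwargs):
        # Swallow whatever the caller passed and give the base what it expects.
        # The values are placeholders chosen only for this illustration.
        return super().__new__(cls, n_outputs=1, n_classes=[2])

    def __init__(self, label="default"):
        # Free to use any signature, because __new__ no longer forwards it.
        self.label = label


child = Child("custom")
print(child.n_outputs, child.n_classes, child.label)
```

The design choice is that `__new__` owns the base class arguments and `__init__` owns the subclass arguments, so the two signatures no longer need to agree.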

DavidW
  • Thank you very much. I hope that this did not cost you too much time, because as you posted it I had *almost* arrived at the same result, but wasn't ready to post an answer yet. It seems that the correct parameters for my case are `n_outputs=1, n_classes=np.array([2], dtype=np.intp)`. When I subclass in Python and override `node_impurity` and `children_impurity`, they don't get called. Am I correct in thinking that is because these methods are `cdef`s and not `cpdef`s? I'm now subclassing in Cython and that seems to work OK. –  Dec 04 '17 at 20:56
  • Yes - you can't access `cdef` methods in Python. It sounds like subclassing in Cython is the way to go if you need to access them. – DavidW Dec 04 '17 at 20:57
  • (2/2) As a side note, for future reference: I added [a feature request to scikit-learn](https://github.com/scikit-learn/scikit-learn/issues/10251) to make extending criteria more straightforward. I will post my code there once it works properly, so that this answer can be the accepted answer here and future readers can look there for the scikit-learn peculiarities. –  Dec 04 '17 at 20:58