
How can I save (pickle) the class returned by a function that builds a wrapper around the method class passed to it? I would like to reuse the same wrapper for many classes similar to the Rocket class (in my case).

The code below produces an error: Can't pickle local object 'sktime_wrapper.<locals>.SKtimeWrapper'

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

    return SKtimeWrapper


# note: this is the wrapper class itself, not an instance of it
model = sktime_wrapper(Rocket)

with open('model.pkl','wb') as f:
    pickle.dump(model, f)
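A minimal sketch that reproduces the same failure without sktime, using only the standard library; Base, make_wrapper and can_pickle are hypothetical names chosen for illustration. It shows that the generated class carries <locals> in its qualified name, and that pickling either the class or an instance of it fails, while an instance of a top-level class pickles fine. The exact exception type and message differ between Python versions, so both AttributeError and pickle.PicklingError are caught.

```python
import pickle


class Base:
    """Hypothetical top-level base class standing in for Rocket."""


def make_wrapper(base):
    # The class statement runs inside the function, so the class is "local"
    class Wrapper(base):
        pass

    return Wrapper


def can_pickle(obj):
    """Return True if pickle.dumps(obj) succeeds, False if it raises a pickling error."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        print(f"{type(exc).__name__}: {exc}")
        return False
    return True


WrapperCls = make_wrapper(Base)
print(WrapperCls.__qualname__)  # make_wrapper.<locals>.Wrapper

print(can_pickle(WrapperCls))    # False: the class itself cannot be found by name
print(can_pickle(WrapperCls()))  # False: an instance refers to its class by name too
print(can_pickle(Base()))        # True: top-level classes can be looked up again
```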

When the class is defined as a top-level object, pickling works just fine. The code below saves the model without any problem:

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

class SKtimeWrapper(Rocket):
    def transform(self, X):
        X = from_2d_array_to_nested(X)
        return super().transform(X)

    def fit(self, X, Y):
        X = from_2d_array_to_nested(X)
        return super().fit(X, Y)

# also the class itself; pickle saves it as a reference (module + class name)
model = SKtimeWrapper


with open('model.pkl','wb') as f:
    pickle.dump(model, f)
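The top-level version works because pickle serializes a class by reference: it writes the class's module and qualified name, not its code. The small sketch below, with a hypothetical TopLevel class, illustrates this. One consequence is that a file saved this way can only be loaded where the same class can be imported again under the same name.

```python
import pickle


class TopLevel:
    """Hypothetical top-level class used only to inspect the pickle payload."""

    def greet(self):
        return "hello"


data = pickle.dumps(TopLevel)

# The payload holds the module and qualified name, not the class body
print(b"TopLevel" in data)  # True
print(b"greet" in data)     # False: methods are not serialized

# Unpickling looks the name up again and returns the very same class object
print(pickle.loads(data) is TopLevel)  # True
```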

  • This might help: [How to pickle an instance of a class which is written inside a function?](https://stackoverflow.com/questions/11807004/how-to-pickle-an-instance-of-a-class-which-is-written-inside-a-function) – sj95126 Sep 30 '21 at 17:15
  • Thank you. Your answer redirected me to this [post](https://stackoverflow.com/a/11526524/6274417) – user6274417 Oct 01 '21 at 07:01

1 Answer

Following the answer in the linked post, I managed to make it work! I hope somebody finds this useful. The trick is to define the __reduce__() method.

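Below is a working example. Beware that the object must be instantiated before saving: __reduce__() is only used when pickling instances, so pickling the class itself (sktime_wrapper(Rocket) without the trailing parentheses) still fails.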

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        PARAM = method_class  # remember the wrapped class so __reduce__ can rebuild the wrapper
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

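        def __reduce__(self):
            # pickle stores this (callable, args, state) tuple instead of a reference
            # to the local class; the callable is a top-level, picklable object
            return (_InitializeParameterized(), (self.PARAM,), self.__dict__)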

    return SKtimeWrapper

class _InitializeParameterized(object):
    """
    When called with the param value as the only argument, returns an
    un-initialized instance of the parameterized class. Subsequent __setstate__
    will be called by pickle.
    """
    def __call__(self, method_class):
        # make a simple object which has no complex __init__ (this one will do)
        obj = _InitializeParameterized()
        obj.__class__ = sktime_wrapper(method_class)
        return obj


# pickle an instance: __reduce__ only customizes how instances are pickled
model = sktime_wrapper(Rocket)()

with open('model.pkl','wb') as f:
    pickle.dump(model, f)
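To check this recipe end to end without installing sktime, here is a sketch of the same __reduce__ approach under stated assumptions: a hypothetical Base class stands in for Rocket, and the hypothetical names make_wrapper and _Rebuild stand in for sktime_wrapper and _InitializeParameterized. It dumps and reloads a fitted instance and checks that its attributes and methods survive the round trip.

```python
import pickle


class Base:
    """Hypothetical stand-in for Rocket, so the sketch runs without sktime."""

    def __init__(self, n=3):
        self.n = n
        self.fitted = False

    def fit(self, X):
        self.fitted = True
        return self


def make_wrapper(base):
    class Wrapper(base):
        PARAM = base  # remember which base class this wrapper was built from

        def fit(self, X):
            X = list(X)  # stands in for from_2d_array_to_nested(X)
            return super().fit(X)

        def __reduce__(self):
            # (callable, args, state): pickle stores these instead of the local class
            return (_Rebuild(), (self.PARAM,), self.__dict__)

    return Wrapper


class _Rebuild:
    """Top-level, hence picklable, callable that recreates a wrapper instance."""

    def __call__(self, base):
        obj = _Rebuild()                    # plain object with a trivial __init__
        obj.__class__ = make_wrapper(base)  # swap in a freshly built wrapper class
        return obj                          # pickle then restores __dict__ from the state


model = make_wrapper(Base)(n=5).fit([1, 2])
restored = pickle.loads(pickle.dumps(model))

print(restored.n, restored.fitted)    # 5 True
print(isinstance(restored, Base))     # True
print(type(restored) is type(model))  # False: each load builds a new class
```

Two consequences of this design: every load calls the factory again, so the restored object's class is a new class object (isinstance checks against the base class still pass), and the factory and the rebuild callable must be importable from the same module path when the file is loaded in another process, because pickle stores them by reference.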