
I am unsure how to use decorators properly; I have drawn reference from Real Python and Try-Except for Multiple Methods. I am coding up a linear regression class, and I realised that you need to call fit before you can call predict or any of the other methods my class has. But it is cumbersome to make each and every method raise an error when the self._fitted flag is False, so I turned to decorators. I am unsure if I am using them correctly: the code does behave the way I want it to, but it swallows any other kind of error, like ValueError. Asking for advice here.

import functools

import numpy as np
from sklearn.exceptions import NotFittedError


def NotFitted(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            raise NotFittedError

    return wrapper

class LinearRegression:
    def __init__(self, fit_intercept: bool = True):
        self.coef_ = None
        self.intercept_ = None
        self.fit_intercept = fit_intercept
        # a flag to turn to true once we called fit on the data
        self._fitted = False

    def check_shape(self, X: np.array, y: np.array):
        # if X is a 1D array, this is simple linear regression; reshape to 2D
        # [1,2,3] -> [[1],[2],[3]] to fit the data
        if X is not None and len(X.shape) == 1:
            X = X.reshape(-1, 1)
        return X, y

    def fit(self, X: np.array = None, y: np.array = None):
        X, y = self.check_shape(X, y)
        n_samples, n_features = X.shape[0], X.shape[1]
        if self.fit_intercept:
            X = np.c_[np.ones(n_samples), X]
        XtX = np.dot(X.T, X)
        XtX_inv = np.linalg.inv(XtX)
        XtX_inv_Xt = np.dot(XtX_inv, X.T)
        _optimal_betas = np.dot(XtX_inv_Xt, y)

        # set attributes from None to the optimal ones
        self.coef_ = _optimal_betas[1:]
        self.intercept_ = _optimal_betas[0]
        self._fitted = True

        return self

    @NotFitted
    def predict(self, X: np.array):
        """
        after calling .fit, you can continue to .predict to get model predictions
        """
        # if self._fitted is False:
        #     raise NotFittedError
        if self.fit_intercept:
            y_hat = self.intercept_ + np.dot(X, self.coef_)
        else:
            y_hat = self.intercept_
        return y_hat
  • If your problem with this approach is that the errors from the wrapped `func` are not propagated, you can change your error handling in the decorator to rethrow the error instead of always using a `NotFittedError`. But I am confused: does your decorator even have any effect other than catching those errors and masking them as `NotFittedError`? I don't think `_fitted` is ever read? – lucidbrot Feb 28 '21 at 15:06
    Yes that is probably what I wanted to ask, do I need to call `_fitted` in the decorator? – ilovewt Mar 01 '21 at 08:10

1 Answer

Let me quickly repeat what you want to do, to make sure I'm not misunderstanding that. You would like to have a decorator @NotFitted such that every function you annotate with it will first check whether self._fitted is True and fail with a NotFittedError if it is False instead of executing the function.

By looking at this question you can get an understanding of how you could pass additional arguments to the decorator.
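As a minimal sketch of that idea (the names require_flag and Model are my own illustration, not from the question): the decorator factory takes the name of the flag attribute as its own argument and returns the actual decorator.

```python
import functools

def require_flag(flag_name):
    """Decorator factory: check a boolean attribute before running the method."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not getattr(self, flag_name, False):
                raise RuntimeError(f"call fit() before {func.__name__}()")
            return func(self, *args, **kwargs)
        return wrapper
    return decorator

class Model:
    def __init__(self):
        self._fitted = False

    def fit(self):
        self._fitted = True
        return self

    @require_flag("_fitted")
    def predict(self, X):
        return X
```

Here Model().predict(3) raises RuntimeError, while Model().fit().predict(3) returns 3.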
I'm not used to using decorators, so I quickly tested this to see what's going on in your code, and why you don't need a self parameter for def wrapper:

>>> def deco1(func):
...   def wrapper(*args, **kwargs):
...     print("Args are {}".format(args))
...   return wrapper

>>> class Foo(object):
...   @deco1
...   def meth(self, a):
...     print("a: "+a)

>>> f = Foo()
>>> f.meth("hello")
Args are (<__main__.Foo object at 0x7f37676a4128>, 'hello')

As you can see here, the first argument the wrapper prints is actually the self. *args simply collects all non-keyword arguments into a tuple, including self, which is the very first argument here. We could be more explicit by doing def wrapper(self, *args, **kwargs) instead if we wanted to (see that linked question).

Do I need to call _fitted in the decorator?

Yes, because self._fitted is how you keep track of whether the model is already fitted or not. You can access it through the first element of *args by doing args[0]._fitted, but I would prefer explicitly taking self. Either way, you can check inside the wrapper whether self._fitted is True, and fail if it is not. So I define this example:

#!/bin/env/python3
# Declaring my own NotFittedError, because I don't want to
# from sklearn.exceptions import NotFittedError
# just for this small example.

class NotFittedError (Exception):
    pass

def NotFitted(foo):
    def wrapper(self, *args, **kwargs):
        if not self._fitted:
            raise NotFittedError()
        # return the result, otherwise the wrapped method
        # would silently return None
        return foo(self, *args, **kwargs)

    return wrapper

class Foo:
    # Set self._fitted to false just to be explicit.
    # The initial value should be False anyway.
    def __init__(self):
        self._fitted = False

    def fit(self):
        self._fitted = True

    @NotFitted
    def predict(self, X):
        # code here that assumes fit was already called
        print ( "Successfully Predicted!" )

And now we can use it. In the following snippet, I imported it as tmp because I had it in a file called tmp.py. You won't have to do that since you have it all in the same file.

>>> import tmp
>>> f = tmp.Foo()
>>> f.predict("a")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/generic/Downloads/tmp.py", line 12, in wrapper
    raise NotFittedError()
tmp.NotFittedError
>>> f.fit()
>>> f.predict("a")
Successfully Predicted!
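One detail worth carrying over from the code in the question: wrapping wrapper with functools.wraps preserves the decorated method's __name__ and __doc__, which my simplified decorator above does not. A minimal illustration:

```python
import functools

def deco(func):
    @functools.wraps(func)  # copies __name__, __doc__, etc. onto the wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@deco
def predict(X):
    """Return the input unchanged."""
    return X
```

Without functools.wraps, predict.__name__ here would be 'wrapper' instead of 'predict'.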

A few further comments:

  • If your only goal is to raise a NotFittedError, maybe you wouldn't need to do any of this. The sklearn.NotFittedError would be raised anyway, I think.
  • If you want to distinguish between different kinds of errors in that case, it could also be useful for you to know that you can have multiple except clauses.
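A small sketch of that pattern (classify_error is my own illustrative helper, not from the question):

```python
def classify_error(fn):
    """Run fn() and report which kind of error, if any, occurred."""
    try:
        fn()
        return "ok"
    except ValueError:
        return "value error"   # e.g. int("x")
    except TypeError:
        return "type error"    # e.g. len(3)
    except Exception:
        return "other error"   # anything else, e.g. ZeroDivisionError
```

The except clauses are tried in order, so put the most specific exception types first.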
lucidbrot