
I am trying to create a scikit-learn Pipeline object with fixed steps, i.e. a PipelineWithFixedSteps(Pipeline) object that inherits from Pipeline, so that I can instantiate it with a simple call PipelineWithFixedSteps() and keep my code clean.

I noticed that if I create several instances of PipelineWithFixedSteps() and I set the parameters of one of them, the parameters of all instances are modified.

Is this an intended behaviour or am I missing something? What could be an alternative way of defining a shortcut for a Pipeline with fixed steps?

I am using sklearn 0.22.1

from sklearn.pipeline import Pipeline
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

class PipelineWithFixedSteps(Pipeline):    
    def __init__(
        self,
        steps = [
            ('scaler', StandardScaler()),
            ('linear', LinearRegression()),
        ]
    ):
        super().__init__(steps=steps)

a = PipelineWithFixedSteps()
print(a.get_params())

a.set_params(scaler__with_std=False)
print(a.get_params())

# Create a new instance of PipelineWithFixedSteps()
# The new instance already has the modified parameters of a
b = PipelineWithFixedSteps()
print(b.get_params())

# Set the parameters of b
# The parameters of a are also changed
b.set_params(scaler__with_mean=False)
print(a.get_params())
alfar

1 Answer


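This has nothing to do with sklearn; it comes down to how Python handles default parameter values (cf. e.g. this question). A default value is evaluated once, when the function is defined, so every PipelineWithFixedSteps() call that omits steps receives the same list holding the same StandardScaler and LinearRegression objects, and set_params on one instance changes them for all instances. It sounds like you are trying to do something along the lines of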

class PipelineWithFixedSteps(Pipeline):    
    def __init__(self, steps=None):
        if steps is None:
            steps = [('scaler', StandardScaler()), ('linear', LinearRegression())]
        super().__init__(steps=steps)
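
The underlying Python behaviour can be reproduced without sklearn. The following is a minimal sketch; the function names append_shared and append_fresh are hypothetical and exist only for illustration. The first function uses a mutable default, the second uses the None sentinel pattern shown above.

```python
def append_shared(item, bucket=[]):
    # The default list is built once, when the def statement executes,
    # and is reused by every call that omits `bucket`.
    bucket.append(item)
    return bucket


def append_fresh(item, bucket=None):
    # A new list is built inside the call, so nothing is shared.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


shared_first = append_shared(1)
shared_second = append_shared(2)
fresh_first = append_fresh(1)
fresh_second = append_fresh(2)

print(shared_first is shared_second)  # True: both names refer to the same list
print(shared_second)                  # [1, 2]
print(fresh_first, fresh_second)      # [1] [2]
```

In the question's class the shared object is the steps list, together with the estimator instances inside it, which is why set_params on one pipeline shows up in the other.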
fuglede
  • Thanks. I get the point, and your solution would work in principle. Unfortunately, if I want PipelineWithFixedSteps() to be fully compatible with sklearn, then __init__() cannot contain any logic (see https://scikit-learn.org/stable/developers/develop.html). – alfar Feb 14 '20 at 08:46
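
One possible way around the constraint raised in the comment is to skip the subclass and use a plain factory function. This is only a sketch and is not something the answer proposes; make_fixed_pipeline is a hypothetical name. Because the steps are constructed inside the function body, each call gets its own estimator objects, and the result is an ordinary Pipeline.

```python
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def make_fixed_pipeline():
    """Return a new Pipeline whose steps are freshly constructed on each call."""
    return Pipeline(steps=[
        ("scaler", StandardScaler()),
        ("linear", LinearRegression()),
    ])


a = make_fixed_pipeline()
b = make_fixed_pipeline()

# Changing a parameter on one pipeline leaves the other untouched.
a.set_params(scaler__with_std=False)
print(a.get_params()["scaler__with_std"])  # False
print(b.get_params()["scaler__with_std"])  # True
```

The trade-off is that there is no dedicated class name to instantiate, but no custom __init__() is involved, so the estimator conventions in the linked developer guide do not come into play.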