
I've seen the question "Using scikit Pipeline for testing models but preprocessing data only once", but the approach described there isn't working for me. I'm using scikit-learn 1.0.2.

Example:

from sklearn.base import BaseEstimator, TransformerMixin

from sklearn.pipeline import Pipeline
from tempfile import mkdtemp
from joblib import Memory
import time
from shutil import rmtree

class Test(BaseEstimator, TransformerMixin):
    def __init__(self, col):
        self.col = col

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        for t in range(5):
            # just to slow it down / check caching.
            print(".")
            time.sleep(1)
        print(self.col)

cachedir = mkdtemp()
memory = Memory(location=cachedir, verbose=10)


pipline = Pipeline(
    [
        ("test", Test(col="this_column")),
    ],
    memory=memory,
)

pipline.fit_transform(None)

Which will display:

.
.
.
.
.
this_column

When I call it a second time, I expect the result to come from the cache, so the five dots should not be printed before this_column.

That isn't what happens, though: the second call still prints the output from the for loop with time.sleep.

Why is this happening?

desertnaut
baxx

1 Answer


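The last step of the pipeline is not cached: the memory option only caches the fitted transformers that come before the final step. Here is a slightly modified version of your script, with a second step added and the print moved into fit so that each actual fit is visible.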

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
import time

class Test(BaseEstimator, TransformerMixin):
    def __init__(self, col):
        self.col = col

    def fit(self, X, y=None):
        print(self.col)
        return self

    def transform(self, X, y=None):
        for t in range(5):
            # just to slow it down / check caching.
            print(".")
            time.sleep(1)
        #print(self.col)
        return X

pipline = Pipeline(
    [
        ("test", Test(col="this_column")),
        ("test2", Test(col="that_column"))
    ],
    memory="tmp/cache",
)

pipline.fit(None)
pipline.fit(None)
pipline.fit(None)

# Output:
#this_column
#.
#.
#.
#.
#.
#that_column
#that_column
#that_column
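
If the expensive transformer has to be the one that is cached, one possible workaround is to make sure it is no longer the final step, for example by appending a "passthrough" step after it. The following is a minimal sketch under that assumption, not the only way to do this; the names SlowTransformer, CALLS, and the step names "slow" and "final" are made up for illustration, and the sleep is shortened.

```python
import time
from shutil import rmtree
from tempfile import mkdtemp

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline

CALLS = []  # hypothetical helper: records every real (uncached) transform call


class SlowTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, col):
        self.col = col

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        time.sleep(0.1)  # stand-in for expensive preprocessing
        CALLS.append(self.col)
        return X


cachedir = mkdtemp()
pipeline = Pipeline(
    [
        ("slow", SlowTransformer(col="this_column")),
        # "passthrough" becomes the final (uncached) step, so "slow" is
        # now an intermediate step and its fit_transform can be cached.
        ("final", "passthrough"),
    ],
    memory=cachedir,
)

X = np.arange(6).reshape(3, 2)
first = pipeline.fit_transform(X)   # runs SlowTransformer.transform
second = pipeline.fit_transform(X)  # should be served from the cache

print(len(CALLS))  # number of real transform calls

rmtree(cachedir)  # clean up the temporary cache directory
```

Only fit and fit_transform go through the cache; calling pipeline.transform afterwards still invokes the transformer's transform method directly.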
Kota Mori