
I have a model that uses a custom LambdaLayer as follows:

import torch
from pytorch_lightning import LightningModule


class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)


class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()

The model runs perfectly fine on CPU or 1 GPU. However, when running it with PyTorch Lightning over 2+ GPUs, this error happens:

AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'

The purpose of using a lambda function here is that given an inputs tensor, I want to pass only inputs[:, start:end] to the embedding layer.
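The error comes from how multi-GPU strategies spawn worker processes: the model object gets pickled, and pickle refuses any function that isn't importable by name at module level. A minimal sketch of the failure, independent of PyTorch (the `make_slicer` factory and the `start`/`end` values here are just illustrative stand-ins for the code above):

```python
import pickle

def make_slicer(start, end):
    # mirrors the lambda created in TorchCatEmbedding.__init__
    return lambda x: x[:, start:end]

try:
    pickle.dumps(make_slicer(0, 2))
except (pickle.PicklingError, AttributeError) as e:
    # same kind of error the multi-GPU trainer raises
    print(e)
```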

My questions:

  • is there an alternative to using a lambda in this case?
  • if not, what should be done to get the lambda function to work in this context?
Jivan
  • Of course, a `lambda` function is **never** required. It is *always* possible to use a regular function definition statement, i.e. `lambda args: expr` is **always** equivalent to `def function(args): return expr` followed by using `function` – juanpa.arrivillaga Jan 06 '22 at 14:45
  • That being said, you'd have to define this function at the module level for it to be "pickleable" (pickle serializes a function basically just as a reference to a name in the global scope). So just put something like `def function(x): return x[:, start:end]` in the module, then in your `__init__` you can do `LambdaLayer(function)` – juanpa.arrivillaga Jan 06 '22 at 14:46
  • @juanpa.arrivillaga I've tried to do just that, using a `def function(x): return x[:, start:end]` but then where do `start` and `end` come from, if `function` is declared at the module level? Both `start` and `end` are arguments to the model's `__init__` method. – Jivan Jan 06 '22 at 14:50
  • 1
    ah, this will be a problem, note, the problem isn't a lambda function per se, it's that pickle only likes to use module-level functions, here, the solution would be to use a higher order function, a function factory that takes those as an input and returns a corresponding function, but you'll be right back to where you started – juanpa.arrivillaga Jan 06 '22 at 14:54
  • Perhaps the best solution here is to define a custom callable object that takes those values as arguments – juanpa.arrivillaga Jan 06 '22 at 14:56
  • @juanpa.arrivillaga I see — are you the person who closed the question? The already-answered question has nothing to do with this one and none of the answers are satisfactory. – Jivan Jan 06 '22 at 14:56
  • 1
    No, it says who closed it, but I re-opened it – juanpa.arrivillaga Jan 06 '22 at 14:57
  • ah ok sorry, thank you for reopening – Jivan Jan 06 '22 at 14:58
  • @juanpa.arrivillaga I closed it. If you look carefully, the answer there is very much like your answer, because the question is basically the same: `pytorch`'s default pickling mechanism has trouble with `lambda` functions, and thus these functions need to be implemented as classes or explicit functions. – Shai Jan 06 '22 at 15:04
  • @Jivan in what way is the `Slicer` class proposed here different from the `LRPolicy` class proposed on the other thread? It's fundamentally the same. – Shai Jan 06 '22 at 15:06
  • @Shai ah, sorry, I just took the OP's comment at face value – juanpa.arrivillaga Jan 06 '22 at 15:06
  • My question is also asking what would be an alternative for slicing `inputs` — that part of the question hasn't been answered yet, but I'd be curious to get an answer on it. – Jivan Jan 06 '22 at 15:07

1 Answer


So the problem isn't the lambda function per se; it's that pickle doesn't work with functions that aren't module-level functions (pickle serializes a function simply as a reference to a module-level name). So, unfortunately, since you need to capture the start and end arguments, you won't be able to use a closure either, even though that's what you'd normally reach for:

def function_maker(start, end):
    def function(x):
        return x[:, start:end]
    return function

But this will get you right back to where you started, as far as the pickling problem is concerned.
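To make that dead end concrete, here is the factory above with a pickling attempt; the inner closure fails for exactly the same reason the lambda does:

```python
import pickle

def function_maker(start, end):
    def function(x):
        return x[:, start:end]
    return function

try:
    pickle.dumps(function_maker(0, 2))
except AttributeError as e:
    # Can't pickle local object 'function_maker.<locals>.function'
    print(e)
```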

So, try something like:

class Slicer:
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __call__(self, x):
        return x[:, self.start:self.end]

Then you can use:

LambdaLayer(Slicer(start, end))

I'm not familiar with PyTorch, but I'm surprised it doesn't offer the ability to use a different serialization backend. The pathos/dill project can pickle arbitrary functions, for example, and it is often easier to just use that. But I believe the above should solve the problem.
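As a sanity check, the `Slicer` approach does survive pickling: instances of a module-level class serialize as a reference to the class name plus the instance's attributes. A quick round-trip (no PyTorch needed for the pickling itself):

```python
import pickle

class Slicer:
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __call__(self, x):
        return x[:, self.start:self.end]

# round-trips cleanly, unlike the lambda or the closure
clone = pickle.loads(pickle.dumps(Slicer(2, 5)))
print(clone.start, clone.end)  # 2 5
```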

juanpa.arrivillaga