I have a model that uses a custom LambdaLayer as follows:
class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)


class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()
The model runs perfectly fine on CPU or 1 GPU. However, when running it with PyTorch Lightning over 2+ GPUs, this error happens:
AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'
The purpose of using a lambda function here is that, given an inputs tensor, I want to pass only inputs[:, start:end] to the embedding layer.
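For reference, the error can be reproduced without torch or Lightning at all: Python's pickle cannot serialize objects defined inside a function body (a lambda created in `__init__` is such a "local object"). A minimal plain-Python sketch of the failure, with the torch parts stripped out:

```python
import pickle

class TorchCatEmbedding:
    def __init__(self, start, end):
        # This lambda lives in __init__'s local scope, so pickle cannot
        # look it up by qualified name and refuses to serialize it.
        self.lb = lambda x: x[:, start:end]

try:
    pickle.dumps(TorchCatEmbedding(0, 3))
except (AttributeError, pickle.PicklingError) as e:
    print(e)  # e.g. Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'
```

Multi-GPU training spawns worker processes and pickles the model to send it to them, which is why the problem only appears with 2+ GPUs.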
My questions:
- is there an alternative to using a lambda in this case?
- if not, what should be done to get the lambda function to work in this context?