Suppose I have the following class (which is a PyTorch model in this example, but the question applies to any class in Python):
```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)
```
Suppose I have an instance `model = MyModel()`. I want to modify the `forward` method of this instance `model` so that, before running the forward pass, it first does some pre-processing on the input. For example, if we want to clone (or repeat) `x` along the batch dimension, then we can define the new forward function as follows:
```python
def new_forward(self, x):
    # x has shape (batch_size, ..., d)
    # we clone it to obtain (2 * batch_size, ..., d)
    x = torch.stack([x, x])
    x = torch.flatten(x, start_dim=0, end_dim=1)
    return self.forward(x)
```
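One subtlety with `new_forward` as written: if it is bound to the instance and stored as `model.forward`, the `self.forward(x)` call inside it looks up the instance attribute first, so it finds the patched method itself and keeps calling itself until Python raises `RecursionError`. A minimal pure-Python illustration, using a hypothetical stand-in class `Toy` (not PyTorch) and a numeric stand-in for the pre-processing:

```python
import types


class Toy:
    """Hypothetical stand-in for a model: forward returns its input unchanged."""

    def forward(self, x):
        return x


def new_forward(self, x):
    x = 2 * x  # stand-in pre-processing
    return self.forward(x)  # once patched, this resolves to new_forward itself


toy = Toy()
toy.forward = types.MethodType(new_forward, toy)  # instance attribute shadows Toy.forward

try:
    toy.forward(1)
except RecursionError:
    print("self.forward now refers to the patched method, so it recurses")
```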
More precisely, I want `model.forward(x)` to do `model.new_forward(x)`.
How can I achieve this without modifying the code of `MyModel`?
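A minimal sketch of one possible approach, assuming it is acceptable to patch the instance rather than the class: save the instance's original bound `forward` before replacing it, and have the replacement call the saved method instead of `self.forward`. `Toy` and `patch_forward` are hypothetical names, and plain lists stand in for tensors so the snippet runs without PyTorch.

```python
import types


class Toy:
    """Hypothetical stand-in for MyModel: forward doubles every element."""

    def forward(self, x):
        return [2 * v for v in x]


def patch_forward(obj):
    """Replace obj.forward on this instance only, keeping the old one reachable."""
    original_forward = obj.forward  # bound method captured *before* patching

    def new_forward(self, x):
        x = x + x  # stand-in pre-processing: repeat the batch
        return original_forward(x)  # call the saved original, not self.forward

    obj.forward = types.MethodType(new_forward, obj)
    return obj


toy = patch_forward(Toy())
print(toy.forward([1, 2]))    # [2, 4, 2, 4]
print(Toy().forward([1, 2]))  # other instances are unaffected: [2, 4]
```

As far as I know, `torch.nn.Module.__setattr__` only intercepts parameters, submodules and buffers and stores other values as ordinary attributes, so the same assignment should also work on a real `model`, and `model(x)` should then go through the patched `forward`.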
Thank you very much in advance for your help!
Update: I want to apply the pre-processing to all PyTorch models, not just `MyModel`. More precisely, I am given a variable `model`, which is an instance of `torch.nn.Module`, but its type can be any subclass of `torch.nn.Module` (e.g. `MyModel`).
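For the update, where `model` can be an instance of any `torch.nn.Module` subclass, here is a sketch of a type-agnostic variant: it reads `forward` from the instance (whatever class defines it) and never touches the class. `add_preprocessing`, `Doubler`, `Summer` and `repeat_batch` are hypothetical names; with PyTorch, `repeat_batch` would be the `torch.stack`/`torch.flatten` duplication from the question.

```python
from typing import Any, Callable


def add_preprocessing(model: Any, preprocess: Callable[[Any], Any]) -> Any:
    """Make model.forward apply `preprocess` to its input first.

    Only the instance attribute `forward` is replaced, so this works
    regardless of which class the instance belongs to.
    """
    original_forward = model.forward  # already bound to `model`

    def forward(x):
        # A plain function stored on the instance is not bound, so no `self`.
        return original_forward(preprocess(x))

    model.forward = forward
    return model


class Doubler:  # hypothetical model types standing in for nn.Module subclasses
    def forward(self, x):
        return [2 * v for v in x]


class Summer:
    def forward(self, x):
        return sum(x)


def repeat_batch(x):
    return x + x  # stand-in for the torch.stack/flatten duplication


print(add_preprocessing(Doubler(), repeat_batch).forward([1, 2]))  # [2, 4, 2, 4]
print(add_preprocessing(Summer(), repeat_batch).forward([1, 2]))   # 6
```

If it is enough for the pre-processing to apply when the model is called as `model(x)` rather than `model.forward(x)`, PyTorch's built-in `register_forward_pre_hook` is another option, since a forward pre-hook may return modified inputs; note that hooks are not run on direct `model.forward(x)` calls.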