Suppose I have the following class (which is a PyTorch model in this example, but the question applies to any class in Python):

import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
    
    def forward(self, x):
        return self.linear(x)

Given an instance model = MyModel(), I want to modify the forward method of that instance so that it first pre-processes its input and then runs the original forward. For example, to repeat x along the batch dimension, the new forward function could be defined as follows:

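def new_forward(self, x):
    # x has shape (batch_size, ..., d)
    # we repeat it along the batch dimension to obtain (2 * batch_size, ..., d)
    x = torch.stack([x, x])
    x = torch.flatten(x, start_dim=0, end_dim=1)
    # self.forward is meant to be the ORIGINAL forward here; once forward is
    # replaced by new_forward, this call would recurse into new_forward itself
    return self.forward(x)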

More precisely, I want a call to model.forward(x) to behave like model.new_forward(x), with the original forward applied to the pre-processed input.

How can I achieve this without modifying the code of MyModel?
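
Since the question applies to any Python class, here is a minimal sketch of one possible approach on a plain class: capture the original bound method, then assign a wrapper to the instance. The names Doubler, add_preprocessing and repeat are hypothetical and only serve the illustration; they do not come from PyTorch or from the code above.

```python
import functools


class Doubler:
    """Hypothetical stand-in for a class whose code we cannot edit."""

    def forward(self, x):
        return [2 * v for v in x]


def add_preprocessing(obj, preprocess):
    """Patch obj.forward on this one instance so that preprocess runs first."""
    original_forward = obj.forward  # bound method, captured before patching

    @functools.wraps(original_forward)
    def wrapped(x):
        # Call the captured original, not obj.forward, to avoid infinite recursion.
        return original_forward(preprocess(x))

    obj.forward = wrapped  # the instance attribute shadows the class method
    return obj


def repeat(x):
    # Hypothetical pre-processing: the input list repeated twice.
    return x + x


model = add_preprocessing(Doubler(), repeat)
other = Doubler()  # a second instance, left unpatched

result = model.forward([1, 2])     # pre-processing, then the original forward
untouched = other.forward([1, 2])  # the class itself is unchanged
```

Only the patched instance changes; the class and its other instances keep the original forward. Capturing original_forward before the assignment is what avoids the recursion that calling self.forward inside the replacement would cause.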

Thank you very much in advance for your help!

Update: I want to apply the pre-processing to all PyTorch models and not just MyModel. More precisely, I am given a variable model, which is an instance of torch.nn.Module but its type can be any subclass of torch.nn.Module (e.g. MyModel).
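
For an arbitrary torch.nn.Module instance, one option that might fit is a forward pre-hook registered with register_forward_pre_hook. The sketch below assumes the input is the first positional argument and that the model is invoked as model(x); the hook name repeat_batch and the use of torch.nn.Linear as a stand-in model are illustrative assumptions.

```python
import torch


def repeat_batch(module, args):
    """Forward pre-hook: repeat the first positional input along dim 0."""
    x = args[0]
    # Same result as torch.stack([x, x]) followed by flattening dims 0 and 1:
    # (batch_size, ..., d) -> (2 * batch_size, ..., d)
    x = torch.cat([x, x], dim=0)
    return (x, *args[1:])  # the returned tuple replaces the positional inputs


model = torch.nn.Linear(5, 5)  # stand-in for any torch.nn.Module instance
handle = model.register_forward_pre_hook(repeat_batch)

x = torch.randn(3, 5)
out_hooked = model(x)  # the hook runs here, before forward
handle.remove()        # detach the hook to restore the original behaviour
out_plain = model(x)
```

Note that hooks run when the module is called as model(x); a direct call to model.forward(x) bypasses them, so this only matches the request above if callers go through model(x).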

f10w
  • Do you want to do this because `model = MyModel()` is embedded in a function in third-party code, so you can't control its instantiation? Otherwise why not just subclass `MyModel`? – dROOOze Aug 16 '23 at 05:57
  • @dROOOze Because I want to apply the pre-processing to all PyTorch models and not just `MyModel`. More precisely, I am given a variable `model`, which is an instance of `torch.nn.Module` but its type can be any subclass of `torch.nn.Module` (e.g. `MyModel`). – f10w Aug 16 '23 at 06:02
  • So, do the duplicates linked above solve your question? If not, why not? – deceze Aug 16 '23 at 06:12
