
I am training a model in pytorch for which I have made a class like so:

from torch import nn

class myNN(nn.Module):
    def __init__(self, dense1=128, dense2=64, dense3=32, ...):
        super().__init__()
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1)
        )
        ...

In order to save it I am using:

torch.save(model.state_dict(), checkpoint_model_path)

and to load it I am using:

model = myNN()   # or with specified parameters
model.load_state_dict(torch.load(model_file))

However, in order for this method to work I have to use the right values in myNN()'s constructor. That means that I would need to somehow remember or store which parameters (layer sizes) I have used in each case in order to properly load different models.
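
For example (the sizes and the exact error text below are just illustrative), a checkpoint saved from a model built with non-default layer sizes cannot be loaded into a default myNN():

import torch

trained = myNN(dense1=256)   # the model that was actually trained
torch.save(trained.state_dict(), "checkpoint.pt")

model = myNN()               # default dense1=128
model.load_state_dict(torch.load("checkpoint.pt"))
# RuntimeError: size mismatch for MLP.0.weight: copying a param with shape
# torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([64, 128])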

Is there a flexible way to save/load models in pytorch where I would also read the size of the layers?

E.g. by loading a myNN() object directly or somehow reading the layer sizes from the saved pickle file?

I am hesitant to try the second method in Best way to save a trained model in PyTorch? due to the warnings mentioned there. Is there a better way to achieve what I want?

Michael

1 Answer


Indeed, serializing the whole Python object is quite a drastic move. Instead, you can always add user-defined items to the saved file: save the model's state_dict along with the arguments used to construct it. Something like this would work:

  1. First, store the constructor arguments on the instance so that they can be serialized when saving the model:

    class myNN(nn.Module):
        def __init__(self, dense1=128, dense2=64, dense3=32):
            super().__init__()
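            # keep the constructor arguments so they can be saved alongside the state_dict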
            self.kwargs = {'dense1': dense1, 'dense2': dense2, 'dense3': dense3}
            self.MLP = nn.Sequential(
                nn.Linear(dense1, dense2),
                nn.ReLU(),
                nn.Linear(dense2, dense3),
                nn.ReLU(),
                nn.Linear(dense3, 1))
    
  2. Then we can save the model's parameters (its state_dict) along with its initializer arguments:

    >>> torch.save([model.kwargs, model.state_dict()], path)
    
  3. Then load it:

    >>> kwargs, state = torch.load(path)
    >>> model = myNN(**kwargs)
    >>> model.load_state_dict(state)
    <All keys matched successfully>
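
If you do this for several models, the pattern above can be wrapped in two small helpers; this is just a sketch (the names save_checkpoint/load_checkpoint are not part of any API):

    import torch

    def save_checkpoint(model, path):
        # bundle the constructor arguments together with the weights in one file
        torch.save([model.kwargs, model.state_dict()], path)

    def load_checkpoint(path, model_cls=myNN):
        # rebuild the model from the stored arguments, then restore its weights
        kwargs, state = torch.load(path)
        model = model_cls(**kwargs)
        model.load_state_dict(state)
        return model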
    
Ivan