
During prototyping, I often make numerous changes to a PyTorch model. For instance, suppose the first model I am experimenting with is:

class Model(nn.Module):
    def __init__(self, **args):
        super().__init__()
        self.l1 = nn.Linear(128, 1)

then I will add another layer:

class Model(nn.Module):
    def __init__(self, **args):
        super().__init__()
        self.l1 = nn.Linear(128, 32)
        self.l2 = nn.Linear(32, 1)

or maybe add some convolutions, and so on.

The problem is that the more experiments I perform, the more disorganized I get, because I haven't found a straightforward way of saving both the model definition and its weights so that I can load a previous state.

I know I can do:

torch.save({'model': model, 'state': model.state_dict()}, path)
# or save the (pickled) model object directly
torch.save(model, path)

but then loading the model also requires that the model class (here, `Model`) exist in the file that performs the loading.
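For example (a minimal sketch; `model.pth` is a placeholder path and the exact error message depends on the PyTorch version), loading such a pickled model in a fresh script that does not define `Model` fails:

import torch

# new script: the Model class is not defined or importable here
model = torch.load('model.pth')
# AttributeError: Can't get attribute 'Model' on <module '__main__' ...>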

In Keras you can simply do:

model = ...  # Get model (Sequential, Functional Model, or Model subclass)
model.save('path/to/location')

which saves the model's architecture/config and weights, among other things. This means that you can load the model without having defined the architecture:

model = keras.models.load_model('path/to/location')

Referring to the Keras documentation on model saving:

The SavedModel and HDF5 file contains:

  • the model's configuration (topology)
  • the model's weights
  • the model's optimizer's state (if any)

Thus models can be reinstantiated in the exact same state, without any of the code used for model definition or training.

This is what I want to achieve in PyTorch.

Is there a similar approach for PyTorch? What is the best practice for these situations?

Alexandru Dinu

  • Does this answer your question? [Best way to save a trained model in PyTorch?](https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch) – Datsheep Mar 16 '21 at 09:39
  • Sorry to be a little blunt here: best practice for these situations is to remain organised. As you're performing scientific experiments, rigour and organisation are really important. You can use pickle, but that will still demand organisation, as it needs to be able to reach the classes/functions you've defined for the saved model. – D Hudson Mar 16 '21 at 09:40
  • @DHudson - that's a valid answer! I thought that maybe I was ignoring or misusing a feature that PyTorch provides (since in Keras this is possible). – Alexandru Dinu Mar 16 '21 at 09:41
  • @Datsheep - this requires the `Model` class to be defined before loading, and this is what I want to avoid, if possible, since the original definition may no longer be available. – Alexandru Dinu Mar 16 '21 at 09:42
  • You could try saving a file with parameters describing the model in the same directory as the model. That could help you to stay organized. – NNN Mar 16 '21 at 11:09
  • @Nachiket - right, but when the architecture changes drastically, this approach may become infeasible. – Alexandru Dinu Mar 16 '21 at 11:28

3 Answers


Because PyTorch provides a huge amount of flexibility in how a model is defined, it is challenging to save the architecture along with the weights in a single file. Keras models are usually built solely by stacking Keras components, whereas PyTorch models are orchestrated by the library consumer in their own way and can therefore contain arbitrary logic.

I think you have three choices:

  1. Come up with an organised schema for your experiments so that losing the model definition is less likely. You could go for something as simple as one file per model, named according to that schema. I would recommend this approach, as this level of organisation would likely benefit your prototyping in other ways, and the overhead is minimal.

  2. Try to save the code along with the pickle file. Although potentially possible, I think this would lead you down a rabbit hole with a lot of potential problems.

  3. Use a different standardised way of saving the model, such as ONNX. I would recommend this route if you do not want to go with option 1. ONNX does allow you to save a PyTorch model's architecture along with its weights, but it comes with a few drawbacks. For example, it only supports a subset of operations, so completely custom forward methods or the use of non-matrix operations may not work (a minimal export sketch follows after this list).
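As a rough illustration of option 3 (a minimal sketch, not part of the original answer; the file name, architecture, and input shape are placeholders), a PyTorch module can be exported to ONNX and later run without the original Python class definition, e.g. via the separate `onnxruntime` package:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 32), nn.ReLU(), nn.Linear(32, 1))
dummy_input = torch.randn(1, 128)  # example input used to trace the graph

# export architecture + weights into a single, framework-agnostic file
torch.onnx.export(model, dummy_input, "model_v1.onnx")

# later, inference does not need the Python model definition:
# import onnxruntime as ort
# session = ort.InferenceSession("model_v1.onnx")
# output = session.run(None, {session.get_inputs()[0].name: dummy_input.numpy()})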

D Hudson
  • **4. Use `torch.jit` to compile the model into what is called TorchScript.** More details in my answer: https://stackoverflow.com/questions/66652447/pytorch-saving-both-weights-and-model-definition/71392808#71392808 – elgehelge Mar 08 '22 at 09:38
  • check this https://stackoverflow.com/a/75774485/13332582 – Prajot Kuvalekar Mar 18 '23 at 07:53

@D Hudson's answer is the right way to go. However, for future reference, I want to add the following methodology, which worked for me.

Let's assume that the model's forward method is fixed, that is, only the underlying architecture changes while the input and output shapes stay the same. In this case, we are only interested in the `nn.Sequential` attribute that represents the entire architecture:

class Model(nn.Module):
    def __init__(self, **hparams):
        super(Model, self).__init__()
        
        # this attribute is the only thing we care about
        self.net = nn.Sequential(
            # experiment with different layers here ...
        )
        
    def forward(self, x):
        return self.net(x) # this is fixed!

Then, we can save the model architecture (essentially just the `net` attribute) and its weights as follows:

import torch as T

m = Model()
# train/test/valid ...
T.save({'net': m.net, 'weights': m.state_dict()}, './version1.pth')

And finally, loading is performed as follows:

m = Model()
checkpoint = T.load('./version1.pth')
m.net = checkpoint['net']                 # restore the saved (pickled) architecture
m.load_state_dict(checkpoint['weights'])  # restore the trained weights
Alexandru Dinu

PyTorch's way of serializing a model (both architecture and weights) for later inference is to use `torch.jit` to compile the model to TorchScript.

Serialization can happen either through tracing (`torch.jit.trace`) or by compiling the Python model code directly (`torch.jit.script`); see the official TorchScript documentation (https://pytorch.org/docs/stable/jit.html) for details.
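For instance (a minimal sketch, not part of the original answer; the architecture and file name are placeholders), scripting, saving, and then loading without the class definition could look like this:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(128, 32), nn.ReLU(), nn.Linear(32, 1))

    def forward(self, x):
        return self.net(x)

m = Model()
scripted = torch.jit.script(m)       # or: torch.jit.trace(m, torch.randn(1, 128))
scripted.save('model_scripted.pt')   # stores code (architecture) + weights

# in another process, no Model class required:
restored = torch.jit.load('model_scripted.pt')
out = restored(torch.randn(1, 128))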

elgehelge