
How can I save a PyTorch model without needing the model class to be defined somewhere?


Disclaimer:

In Best way to save a trained model in PyTorch?, there is no working solution for saving the model without access to the model class code.

Michael D

4 Answers

32

If you plan to do inference with the PyTorch library available (i.e. PyTorch in Python, C++, or any other platform it supports), then the best way to do this is via TorchScript.

I think the simplest thing is to use trace = torch.jit.trace(model, typical_input) and then torch.jit.save(trace, path). You can then load the traced model with torch.jit.load(path).

Here's a really simple example. We make two files:

train.py:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))
    traced_cell = torch.jit.trace(model, (x,))
torch.jit.save(traced_cell, "model.pth")

infer.py:

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))

Running these sequentially gives results:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

The results are the same, so we are good. (Note that the printed values will differ from those shown above each time train.py is re-run, due to the random initialisation of the nn.Linear layer.)

TorchScript provides for much more complex architectures and graph definitions (including if statements, while loops, and more) to be saved in a single file, without needing to redefine the graph at inference time. See the docs (linked above) for more advanced possibilities.
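For example, here's a minimal sketch of scripting (rather than tracing) a module with data-dependent control flow; the Gated module and file name are my own illustration, not from the docs:

import torch

class Gated(torch.nn.Module):
    def __init__(self):
        super(Gated, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # Data-dependent branch: torch.jit.script compiles the source and
        # keeps both paths, whereas tracing would only record the path the
        # example input happened to take.
        if x.sum() > 0:
            return torch.relu(self.linear(x))
        else:
            return -self.linear(x)

scripted = torch.jit.script(Gated())
torch.jit.save(scripted, "gated.pth")
loaded = torch.jit.load("gated.pth")  # again, no class definition needed here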

nlml
  • What are cons of using torch script? – Michael D Dec 18 '19 at 17:45
  • Well the main con is you still need a pytorch environment of some sort. Also if you wanted to keep training a trace, I imagine that would be very difficult/impossible. It can also be a bit buggy/hard to debug at times. But it's basically pytorch's answer to ease of saving the entire graph in tensorflow. It's improving with every release and is already very good imo. – nlml Dec 18 '19 at 18:48
  • How do I get the resolution of an input image from a model saved this way? – Mathews Edwirds Nov 07 '22 at 13:16
  • @MathewsEdwirds in general it is not possible. You should save this information as metadata or in a config file that accompanies the model weights you exported. Also try opening the model in netron.app to see what info is available inside the model. – nlml Nov 08 '22 at 14:03
  • Thanks man, I decided to save the input resolution in the model file name. – Mathews Edwirds Nov 09 '22 at 12:27
2

I recommend converting your PyTorch model to ONNX and saving it that way. It's probably the best way to store a model without access to the class.
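A minimal sketch of that route, assuming the onnxruntime package for inference (the placeholder model and file names are mine, not from the answer):

import numpy as np
import onnxruntime
import torch

model = torch.nn.Linear(4, 4)  # stand-in for any trained nn.Module
dummy_input = torch.randn(1, 4)
torch.onnx.export(model, dummy_input, "model.onnx")  # traces and exports the graph

# Inference elsewhere needs only the .onnx file and a runtime, not the class code:
session = onnxruntime.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
print(session.run(None, {input_name: np.random.randn(1, 4).astype(np.float32)}))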

  • I'm not excited about onnx, since it has some limitations... and doesn't support some pytorch functionality... – Michael D Dec 15 '19 at 10:04
1

Supplying an official answer by one of the core PyTorch devs (smth):

There are limitations to loading a pytorch model without code.

First limitation: We only save the source code of the class definition. We do not save beyond that (like the package sources that the class is referring to).

For example:

import foo

class MyModel(...):
    def forward(self, input):
        return foo.bar(input)

Here the package foo is not saved in the model checkpoint.

Second limitation: There are limitations on robustly serializing python constructs. For example the default picklers cannot serialize lambdas. There are helper packages that can serialize more python constructs than the standard ones, but they still have limitations. Dill is one such package.
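For instance, a quick sketch of the lambda case (assuming dill is installed):

import pickle
import dill

f = lambda x: x + 1

try:
    pickle.dumps(f)
except pickle.PicklingError as e:
    print("pickle failed:", e)  # lambdas cannot be pickled by reference

g = dill.loads(dill.dumps(f))  # dill serializes the lambda's code object instead
print(g(1))  # prints 2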

Given these limitations, there is no robust way to have torch.load work without having the original source files.

Mano
0

There is no solution (or at least no working solution) for saving a model without access to the class.

You can save whatever you like.

You can save the whole model with torch.save(model, filepath). This saves the model object itself.
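(As the comments below point out, a caveat worth a quick sketch: pickle stores the class by reference, so torch.load(filepath) still needs the model class importable; this option does not remove the class dependency.)

import torch

class Model(torch.nn.Module):
    pass

torch.save(Model(), "whole_model.pth")
# Works here because Model is defined above; in a fresh process without
# this class definition on the import path, torch.load would fail with
# AttributeError: Can't get attribute 'Model' ...
model = torch.load("whole_model.pth")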

You can save just the model state dict.

torch.save(model.state_dict(), filepath)

Further, you can save anything you like, since torch.save is just a pickle-based save.

state = {
    'hello_text': 'just the optimizer sd will be saved',
    'optimizer': optimizer.state_dict(),
}
torch.save(state, filepath)
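And a matching sketch for loading it back, continuing directly from the snippet above (it assumes an optimizer with the same parameters has been constructed first):

state = torch.load(filepath)
print(state['hello_text'])                     # 'just the optimizer sd will be saved'
optimizer.load_state_dict(state['optimizer'])  # restores the optimizer state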

You may check what I wrote on torch.save some time ago.

prosti
  • I would expect some workaround to exist, since there is such an option in TensorFlow. – Michael D Dec 11 '19 at 18:23
  • @prosti this does not answer the question, which is perhaps a bit ill-formulated. Both your options still require the model class to be defined when calling `torch.load` or `.load_state_dict`. The question is about finding a method that allows to load the saved representation of the model _without_ access to its class definition (which is straightforward in TensorFlow for example). – Thomas Wagenaar Apr 12 '23 at 11:12