
Is it possible to load a PyTorch model (from a .pth file containing both the architecture and the state_dict) without torchvision as a dependency?

import os
import torch
assert os.path.exists(r'.\vgg.pth')
model = torch.load(r'.\vgg.pth')

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-4-e26863d95688> in <module>
      2 import torch
      3 assert os.path.exists(r'.\vgg.pth')
----> 4 model = torch.load(r'.\vgg.pth')

~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    590                     opened_file.seek(orig_position)
    591                     return torch.jit.load(opened_file)
--> 592                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    593         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    594 

~\Anaconda3\envs\pytorch_save\lib\site-packages\torch\serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
    849     unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    850     unpickler.persistent_load = persistent_load
--> 851     result = unpickler.load()
    852 
    853     torch._utils._validate_loaded_sparse_tensors()

ModuleNotFoundError: No module named 'torchvision'

I have looked into torch/serialization.py, but I see no reason why it would need torchvision. The imports in this file are as follows:

import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
import copyreg
import pickle
import pathlib
    `torchvision` is a dependency for the `vgg` model which you are trying to load, it is not related to `serialization`. It is not possible to load the `vgg` model without `torchvision`. – Kishore Sampath Aug 17 '21 at 08:08
  • @Kishore this is completely false, there is no need for the `torchvision.vgg` model to load a file named `vgg.pth`. Files loaded with `torch.load` contain model *dict*s **not** module definitions... Why would you even assume it's `torchvision`'s implementation that's contained in `vgg.pth`? This makes no sense. – Ivan Aug 17 '21 at 08:31
  • Are you providing the whole code or is there something else running? – Ivan Aug 17 '21 at 08:37
  • @Ivan actually this torch.load() does load the model definition. In Pytorch, you can choose either to only save the model state_dict, or architecture+state_dict. I needed the latter. Also, there is nothing else running. – Jonas De Schouwer Aug 17 '21 at 08:48
  • *you can choose either to only save the model state_dict, or architecture+state_dict*, how so? How did you save the model's definition into the file? – Ivan Aug 17 '21 at 09:01
    @Ivan You can save the entire model using `torch.save(the_model, PATH)` and then load it using `the_model = torch.load(PATH)`. This is similar to saving the entire model in `tensorflow` or just `pickling` the model. Refer [here](https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch). And answering to your question *"Why would you even assume it's torchvision's implementation that's contained in vgg.pth"* What else do you think is causing the error, it's straightforward. – Kishore Sampath Aug 17 '21 at 09:54
  • @Kishore Fair enough, I wasn't aware of this feature. By *"Why would you even assume it's torchvision's implementation that's contained in vgg.pth"* I mean: how do you know the model that was saved is a torchvision.vgg? Any type of model that was saved using torchvision (by using its model loader *or not*) would throw this error upon loading, correct? So saying '*`vgg` model which you are trying to load*' is just speculating on the content of `vgg.pth`... Anyhow, this only shows why it's generally preferred to save the model's dict, not the model itself, because you don't know its deps. – Ivan Aug 17 '21 at 10:12
    @Ivan From the error stack we can see that `ModuleNotFoundError` is generated from the line `result = unpickler.load()`. When `pytorch` tries to load the saved model from the `.pth` file, the `model` which was saved must have used `torchvision` as a dependency, and as we don't have `torchvision` installed, `ModuleNotFoundError` was raised. And I also accept that saving the model parameters in a dictionary is a better way of serializing a model. – Kishore Sampath Aug 17 '21 at 10:42
  • @Kishore I'm not saying `torchvision` isn't a dependency, it is indeed. Just that it's not possible to infer that the model in question is `torchvision`'s `vgg`. It might as well be any model (using `torchvision` as a dependency). That we don't know. – Ivan Aug 17 '21 at 10:57

1 Answer


What caused my problem

The vgg.pth file in my question was generated as follows:

import torch
from torchvision import models

vgg = models.vgg16(pretrained=True, init_weights=False)
torch.save(vgg, r'.\vgg.pth')

This way, the file vgg.pth contains not only the model parameters but also the model architecture (see PyTorch: save/load entire model). However, as @Kishore pointed out in the comments, the architecture itself pulls in torchvision: the pickle stores a reference to the model's class (torchvision.models.vgg.VGG), so torch.load has to import torchvision to reconstruct it.
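
A quick way to confirm this (a side check, assuming the default zip-based .pth format that the traceback's `opened_zipfile` path indicates) is to disassemble the pickle inside the archive and look for the opcodes that reference torchvision:

import zipfile
import pickletools

# torch.save (PyTorch >= 1.6) writes a zip archive containing a data.pkl member
with zipfile.ZipFile(r'.\vgg.pth') as archive:
    pkl_name = [n for n in archive.namelist() if n.endswith('data.pkl')][0]
    # the disassembly contains opcodes like: GLOBAL 'torchvision.models.vgg VGG',
    # which is exactly the module torch.load tries (and here fails) to import
    pickletools.dis(archive.read(pkl_name))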

How I solved it

  • In an environment with torchvision, I loaded the pretrained VGG model into memory and saved its state_dict:
from torchvision.models.vgg import vgg16
import torch

model = vgg16(pretrained=True)
torch.save(model.state_dict(), r'.\state_dict.pth')
  • In an environment without torchvision, I rebuilt the model by inspecting the torchvision.models.vgg code (a sketch of such a file is shown after the snippet below).
    Then I loaded the saved state_dict into this model.
    Lastly, I saved the model (architecture included) to a .pth file.
import torch

# a file where I pasted the torchvision.models.vgg code
# and commented out the torchvision dependencies I don't need
# in this case: 'from .._internally_replaced_utils import load_state_dict_from_url'
from torch_save import *

model = vgg16()
model.load_state_dict(torch.load(r'.\state_dict.pth'))
torch.save(model, r'.\entire_model.pth')
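
For reference, a minimal sketch of what such a torch_save.py can look like (not my exact file, just a stripped-down copy of the torchvision.models.vgg code; the layer structure and names mirror torchvision's, so load_state_dict accepts the saved weights):

# torch_save.py -- torchvision-free VGG-16 sketch
import torch
import torch.nn as nn

# VGG-16 configuration 'D': numbers are conv output channels, 'M' is max-pooling
_CFG_D = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
          512, 512, 512, 'M', 512, 512, 512, 'M']

def _make_layers(cfg):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                       nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

class VGG(nn.Module):
    def __init__(self, features, num_classes=1000):
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

def vgg16(num_classes=1000):
    return VGG(_make_layers(_CFG_D), num_classes=num_classes)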

When I load this again in a torchvision-free environment, I get no errors.
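
For completeness, that final check is just the load itself. Note that the module defining the VGG class (here torch_save.py) still has to be importable in that environment, because the pickle now references the class by that module path instead of torchvision:

import torch

# torchvision is never imported; torch_save.py just needs to be on the path
model = torch.load(r'.\entire_model.pth')
model.eval()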
