I'm using VGG19 for a classification problem. I have access to the campus research computer for training, but the nodes where the computation runs don't have internet access. As a result, a line of code like `self.net = models.vgg19(pretrained=True)` fails with the error `urllib.error.URLError: <urlopen error [Errno 101] Network is unreachable>`.

Is there a way I could cache the model on the head node (where I have internet access), and load the model from the cache instead of the internet on the compute node?

  • Does this answer your question? [Is there any way I can download the pre-trained models available in PyTorch to a specific path?](https://stackoverflow.com/questions/52628270/is-there-any-way-i-can-download-the-pre-trained-models-available-in-pytorch-to-a) – kHarshit Feb 20 '20 at 05:43
  • It's similar, but not the same. That question is more concerned with the path it's downloaded to. I'm concerned with how to load it after it's cached, which is what @unlut explained. But the link did provide good information. I hadn't realized it only fetches the model from the URL if it's pretrained. – Joseph Summerhays Feb 20 '20 at 15:22

1 Answer

If you save the weights of a pretrained network to a file, you can load them the same way you load any other network weights. The example below uses vgg16; the same approach works for vgg19.

Saving:

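import torch
import torchvision

# Run this where internet access is available (e.g. the head node)
model = torchvision.models.vgg16(pretrained=True)
# "Somewhere" is a placeholder for the file path to write the weights to
torch.save(model.state_dict(), "Somewhere")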

Loading:

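import torch
import torchvision

def create_vgg16(dict_path=None):
    # pretrained=False builds the architecture without downloading anything
    model = torchvision.models.vgg16(pretrained=False)
    if dict_path is not None:
        # Load the weights that were saved earlier with torch.save
        model.load_state_dict(torch.load(dict_path))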
    return model

model = create_vgg16("Somewhere")
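
The same save/load pattern can be wrapped in two small helpers that work for any torchvision architecture, including the vgg19 from the question. This is only a minimal sketch: the helper names `save_weights` and `load_weights` and the file name `vgg19_weights.pth` are made up for illustration, and `map_location="cpu"` is an assumption that you want the weights on the CPU first and will move the model to a GPU afterwards.

```python
import torch


def save_weights(model, path):
    """Head node: write the model's weights (its state_dict) to a local file."""
    torch.save(model.state_dict(), path)


def load_weights(build_model, path):
    """Compute node: build the architecture offline, then load the saved weights.

    build_model is any zero-argument callable that returns an untrained model,
    e.g. torchvision.models.vgg19 (called with no arguments, it downloads nothing).
    """
    model = build_model()
    # map_location="cpu" (an assumption) keeps loading independent of which
    # device the weights were saved from.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


# Hypothetical usage, split across the two machines:
#   head node:     save_weights(torchvision.models.vgg19(pretrained=True), "vgg19_weights.pth")
#   compute node:  model = load_weights(torchvision.models.vgg19, "vgg19_weights.pth")
```

The weights file has to be somewhere the compute node can read, such as a shared filesystem or a copy made before the job starts.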